项目4:识别手写数字图片
上一节
下一节
一、安装第三方库
安装numpy库
pip install numpy
安装matplotlib库（可选：下面两段代码并未用到它）
pip install matplotlib
安装scikit-learn库（代码中的导入名为 sklearn）
pip install scikit-learn
安装OpenCV库（代码中的导入名为 cv2）
pip install opencv-python
安装joblib库
pip install joblib
二、训练模型
Python
"""Train a logistic-regression digit classifier, evaluate it, and save it.

Reads ``digits_training.csv`` / ``digits_testing.csv`` (first row is a
header; column 0 is the label, the remaining columns are the features),
fits a multinomial logistic regression, prints the test accuracy, and
writes the model plus the per-column training means into ``testModel/``.
"""
import os

import joblib  # formerly `from sklearn.externals import joblib`; that alias was removed from scikit-learn
import numpy as np
from sklearn.linear_model import LogisticRegression

# ----------------------------------------------------------------------
# Data preprocessing
# ----------------------------------------------------------------------
# Pass the path directly so NumPy opens *and closes* the file itself.
# ndmin=2 keeps the result 2-D even if the CSV has only one data row.
trainData = np.loadtxt('digits_training.csv', delimiter=",", skiprows=1, ndmin=2)
MTrain, NTrain = np.shape(trainData)  # rows, columns
print("训练集:", MTrain, NTrain)
xTrain = trainData[:, 1:NTrain]
xTrain_col_avg = np.mean(xTrain, axis=0)  # mean of every feature column
# Centre each column on its training mean, then scale by 255.
# The SAME transform, with the SAME means, must be applied at prediction time.
xTrain = (xTrain - xTrain_col_avg) / 255
yTrain = trainData[:, 0]

# ----------------------------------------------------------------------
# Train the model
# ----------------------------------------------------------------------
# With solver='lbfgs' scikit-learn fits a multinomial (softmax) model when
# there are 3+ classes. The explicit `multi_class='multinomial'` argument is
# therefore redundant; it is deprecated since scikit-learn 1.5, so it is omitted.
model = LogisticRegression(solver='lbfgs', max_iter=500)
model.fit(xTrain, yTrain)
print("训练完毕")

# ----------------------------------------------------------------------
# Evaluate on the test set
# ----------------------------------------------------------------------
testData = np.loadtxt('digits_testing.csv', delimiter=",", skiprows=1, ndmin=2)
MTest, NTest = np.shape(testData)
print("测试集:", MTest, NTest)
xTest = testData[:, 1:NTest]
xTest = (xTest - xTrain_col_avg) / 255  # normalise with the TRAINING column means
yTest = testData[:, 0]
yPredict = model.predict(xTest)
errors = np.count_nonzero(yTest != yPredict)  # number of misclassified rows
print("预测完毕。错误:", errors, "条")
print("测试数据正确率:", (MTest - errors) / MTest)

# ----------------------------------------------------------------------
# Save the model
# ----------------------------------------------------------------------
dirs = 'testModel'
os.makedirs(dirs, exist_ok=True)
joblib.dump(model, os.path.join(dirs, 'model.pkl'))
# Persist the normalisation statistics as well. Without them a prediction
# script cannot reproduce the preprocessing the model was trained with.
np.save(os.path.join(dirs, 'xTrain_col_avg.npy'), xTrain_col_avg)
print("模型已保存")
三、使用模型识别单张图片
Python
"""Recognise the handwritten digit in one image using the saved model.

Loads ``a0.png`` and binarises and shrinks it to 28x28. It then applies the
same normalisation used in training, ``(x - training column mean) / 255``,
predicts with ``testModel/model.pkl``, and shows the image annotated with
the predicted label.
"""
import os

import cv2
import joblib  # formerly `from sklearn.externals import joblib`; that alias was removed from scikit-learn
import numpy as np
from sklearn import svm  # NOTE(review): not used by this script

MODEL_DIR = 'testModel'
IMAGE_PATH = r"a0.png"

# cv2.imread returns None (instead of raising) when the file is missing or
# unreadable; fail early with a clear message.
src = cv2.imread(IMAGE_PATH)
if src is None:
    raise FileNotFoundError(f"cannot read image: {IMAGE_PATH}")

# Grayscale, then inverted binary threshold: pixels <= 127 (dark ink) become
# 255 and brighter pixels (background) become 0. Then shrink to 28x28.
gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
_, thresh2 = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
small = cv2.resize(thresh2, (28, 28))
z = np.asarray(small, dtype=np.float64).reshape(1, -1)  # one sample row

# Apply exactly the preprocessing the model was trained with. Feeding raw
# 0-255 pixels puts the input on a completely different scale than training.
avg_path = os.path.join(MODEL_DIR, 'xTrain_col_avg.npy')
if os.path.exists(avg_path):
    xTrain_col_avg = np.load(avg_path)
else:
    # Means were not saved alongside the model: recompute them from the
    # training CSV (column 0 is the label, the rest are features).
    trainData = np.loadtxt('digits_training.csv', delimiter=",", skiprows=1, ndmin=2)
    xTrain_col_avg = np.mean(trainData[:, 1:], axis=0)
if z.shape[1] != xTrain_col_avg.size:
    raise ValueError(
        f"image yields {z.shape[1]} features but the training data has "
        f"{xTrain_col_avg.size}"
    )
z = (z - xTrain_col_avg) / 255

# ----------------------------------------------------------------------
# Predict
# ----------------------------------------------------------------------
# joblib.load unpickles the file: only load model files you created yourself.
model = joblib.load(os.path.join(MODEL_DIR, 'model.pkl'))
yPredict = model.predict(z)
print(yPredict)

# Labels come back as floats (the CSV was read as floats). ':g' renders 3.0
# as "3" instead of drawing the array repr "[3.]" on the image.
label = f"{yPredict[0]:g}"
cv2.putText(src, label, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA)
cv2.imshow("map", src)
cv2.waitKey(0)
cv2.destroyAllWindows()

