项目4:识别手写数字图片
上一节
下一节
一、安装第三方库
安装numpy库
pip install numpy
安装 matplotlib库
pip install matplotlib
安装scikit-learn库(导入时写作 sklearn)
pip install scikit-learn
安装OpenCV库(导入时写作 cv2)
pip install opencv-python
安装joblib库
pip install joblib
二、训练模型
import os
from functools import partial

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer, StandardScaler
# joblib is imported as a standalone package; sklearn.externals.joblib was
# removed in scikit-learn 0.23.

# Train a logistic-regression digit classifier on digits_training.csv, report
# its accuracy on digits_testing.csv and save it to testModel/model.pkl.
# Each CSV row is: label in column 0, pixel values in the remaining columns;
# the first (header) row is skipped.

TRAIN_CSV = 'digits_training.csv'
TEST_CSV = 'digits_testing.csv'
MODEL_DIR = 'testModel'
MODEL_PATH = os.path.join(MODEL_DIR, 'model.pkl')

# Data preprocessing.
# np.loadtxt opens and closes the file itself (the previous bare open() calls
# were never closed); ndmin=2 keeps a one-row file two-dimensional so the
# shape unpacking below still works.
trainData = np.loadtxt(TRAIN_CSV, delimiter=",", skiprows=1, ndmin=2)
MTrain, NTrain = trainData.shape  # rows, columns
print("训练集:", MTrain, NTrain)
xTrain = trainData[:, 1:]  # pixel columns
yTrain = trainData[:, 0]   # label column
'''================================='''
# Train the model.
# The normalisation (subtract each column's training mean, then divide by 255)
# is part of the saved pipeline, so any code that loads model.pkl and calls
# predict() on raw pixel values gets exactly the preprocessing used here.
# Previously the column mean was computed only in this script and never saved,
# so a separate prediction script could not reproduce it.
model = make_pipeline(
    # x - per-column mean of the training data
    StandardScaler(with_std=False),
    # ... / 255. A partial over a NumPy ufunc is picklable by reference; a lambda
    # or a function defined in this script could not be unpickled elsewhere.
    FunctionTransformer(partial(np.multiply, 1 / 255)),
    # lbfgs already fits a multinomial model when there are more than two
    # classes; the explicit multi_class='multinomial' argument is deprecated
    # since scikit-learn 1.5.
    LogisticRegression(solver='lbfgs', max_iter=500),
)
model.fit(xTrain, yTrain)
print("训练完毕")
'''================================='''
# Evaluate on the held-out test set (raw pixels; the pipeline applies the
# training-set normalisation itself).
testData = np.loadtxt(TEST_CSV, delimiter=",", skiprows=1, ndmin=2)
MTest, NTest = testData.shape
print("测试集:", MTest, NTest)
xTest = testData[:, 1:]
yTest = testData[:, 0]
yPredict = model.predict(xTest)
errors = np.count_nonzero(yPredict != yTest)  # number of misclassified rows
print("预测完毕。错误:", errors, "条")
print("测试数据正确率:", (MTest - errors) / MTest)
'''================================='''
# Save the model, creating the output directory if needed.
# (The previous `if not os.path.exists(...)` body was not indented, which is a
# syntax error.)
os.makedirs(MODEL_DIR, exist_ok=True)
joblib.dump(model, MODEL_PATH)
print("模型已保存")
三、用训练好的模型识别图片中的手写数字
import os

import cv2
import joblib
import numpy as np
from sklearn.pipeline import Pipeline
# joblib is imported as a standalone package; sklearn.externals.joblib was
# removed in scikit-learn 0.23.

# Recognise the handwritten digit in a0.png with the model saved by the
# training script, print the prediction and show it drawn on the image.

IMAGE_PATH = r"a0.png"
MODEL_PATH = os.path.join('testModel', 'model.pkl')
TRAIN_CSV = 'digits_training.csv'

# Turn the picture into one 28x28 feature row: grayscale, inverse threshold
# (dark ink -> 255, light background -> 0), then shrink to 28x28.
# NOTE(review): assumes the training CSV has 28*28 = 784 pixel columns with the
# same white-on-black convention — confirm against digits_training.csv.
image = cv2.imread(IMAGE_PATH)  # renamed from `map`, which shadowed the builtin
if image is None:
    # cv2.imread returns None instead of raising for a missing/unreadable file.
    raise FileNotFoundError(f"cannot read image: {IMAGE_PATH}")
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
_, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
digit = cv2.resize(binary, (28, 28))
z = digit.reshape(1, -1).astype(np.float64)
'''================================================'''
# model.pkl is unpickled: only load model files you created yourself.
model = joblib.load(MODEL_PATH)
if not isinstance(model, Pipeline):
    # A bare classifier saved by the training script above was fitted on
    # (pixels - training column mean) / 255, so raw 0-255 pixels must get the
    # same transformation here or the predictions are meaningless. A Pipeline
    # that contains its own preprocessing needs no extra work.
    train_pixels = np.loadtxt(TRAIN_CSV, delimiter=",", skiprows=1, ndmin=2)[:, 1:]
    z = (z - train_pixels.mean(axis=0)) / 255
yPredict = model.predict(z)
print(yPredict)
# Labels were loaded from CSV as floats, so predict() returns e.g. [3.];
# draw just the digit on the image.
label = str(int(yPredict[0]))
cv2.putText(image, label, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA)
cv2.imshow("map", image)
cv2.waitKey(0)
cv2.destroyAllWindows()

