手写数字识别(识别纸上手写的数字) |
您所在的位置:网站首页 › 数字五的照片 › 手写数字识别(识别纸上手写的数字) |
说明
使用pytorch框架,实现对MNIST手写数字数据集的训练和识别。重点是,自己手写数字,手机拍照后传入电脑,使用你自己训练的权重和偏置能够识别。数据预处理过程的代码是重点。 分析要识别自己用手在纸上写的数字,从特征上来看,手写数字相比于普通的电脑上的数字最大的 不同就是数字的边缘会发生不同幅度的抖动。而且,在MNIST数据集中的数字是边缘为黑色的,然后数字是不同灰度的白色的,如下所示: 至此,对手写数字网络的训练已经结束,且训练的准确性为: 因为我们手机拍的照片和训练集的图片有很大的区别,所以无法将手机上拍的照片直接丢到训练好的网络模型中进行识别,需要先对图片进行预处理。有几点需要对原图进行改变: 图片的大小:肯定得将拍摄到的图片转换成 28 ∗ 28 28*28 28∗28尺寸大小的图片。图片的通道数:由于MNIST是灰度图,所以原图的channel也得转换成1。图片的背景:图片的背景得转换成MNIST相同的黑色,这样识别结果准确性更高。数字的颜色:毋庸置疑,数字的颜色得变成MNIST相同的白色。数字颜色中间深边缘前:观察MNIST的白色部分并不都是255全白,而是有渐变色的,这个渐变色模拟起来比较困难,算是难度最大的一点了。 接下来直接上代码了: import cv2 import numpy as np def image_preprocessing(): # 读取图片 img = cv2.imread("picture/test8.jpeg") # =====================图像处理======================== # # 转换成灰度图像 gray_img = cv2.cvtColor(img , cv2.COLOR_BGR2GRAY) # 进行高斯滤波 gauss_img = cv2.GaussianBlur(gray_img, (5,5), 0, 0, cv2.BORDER_DEFAULT) # 边缘检测 img_edge1 = cv2.Canny(gauss_img, 100, 200) # ==================================================== # # =====================图像分割======================== # # 获取原始图像的宽和高 high = img.shape[0] width = img.shape[1] # 分别初始化高和宽的和 add_width = np.zeros(high, dtype = int) add_high = np.zeros(width, dtype = int) # 计算每一行的灰度图的值的和 for h in range(high): for w in range(width): add_width[h] = add_width[h] + img_edge1[h][w] # 计算每一列的值的和 for w in range(width): for h in range(high): add_high[w] = add_high[w] + img_edge1[h][w] # 初始化上下边界为宽度总值最大的值的索引 acount_high_up = np.argmax(add_width) acount_high_down = np.argmax(add_width) # 将上边界坐标值上移,直到没有遇到白色点停止,此为数字的上边界 while add_width[acount_high_up] != 0: acount_high_up = acount_high_up + 1 # 将下边界坐标值下移,直到没有遇到白色点停止,此为数字的下边界 while add_width[acount_high_down] != 0: acount_high_down = acount_high_down - 1 # 初始化左右边界为宽度总值最大的值的索引 acount_width_left = np.argmax(add_high) acount_width_right = np.argmax(add_high) # 将左边界坐标值左移,直到没有遇到白色点停止,此为数字的左边界 while add_high[acount_width_left] != 0: acount_width_left = acount_width_left - 1 # 将右边界坐标值右移,直到没有遇到白色点停止,此为数字的右边界 while add_high[acount_width_right] != 0: acount_width_right = acount_width_right + 1 # 求出宽和高的间距 width_spacing = acount_width_right - acount_width_left high_spacing = acount_high_up - acount_high_down # 求出宽和高的间距差 poor = width_spacing - high_spacing # 将数字进行正方形分割,目的是方便之后进行图像压缩 if poor > 0: tailor_image = img[acount_high_down - poor // 2 - 5:acount_high_up + poor - poor // 2 + 5, acount_width_left - 5:acount_width_right + 5] else: tailor_image = img[acount_high_down - 5:acount_high_up + 5, acount_width_left + poor // 2 - 5:acount_width_right - poor + poor // 2 + 5] # ==================================================== # # ======================小图处理======================= # # 将裁剪后的图片进行灰度化 gray_img = cv2.cvtColor(tailor_image , cv2.COLOR_BGR2GRAY) # 高斯去噪 gauss_img = cv2.GaussianBlur(gray_img, (5,5), 0, 0, cv2.BORDER_DEFAULT) # 将图像形状调整到28*28大小 zoom_image = cv2.resize(gauss_img, (28, 28)) # 获取图像的高和宽 high = zoom_image.shape[0] wide = zoom_image.shape[1] # 将图像每个点的灰度值进行阈值比较 for h in range(high): for w in range(wide): # 若灰度值大于100,则判断为背景并赋值0,否则将深灰度值变白处理 if zoom_image[h][w] > 100: zoom_image[h][w] = 0 else: zoom_image[h][w] = 255 - zoom_image[h][w] # ==================================================== # return zoom_image在此,我在纸上写了个6,如下图所示: 预测代码如下: import torch # pretreatment.py为上面图片预处理的文件名,导入图片预处理文件 import pretreatment as PRE # 加载网络模型 net = torch.load('weight/test.pkl') # 得到返回的待预测图片值,就是pretreatment.py中的zoom_image img = PRE.image_preprocessing() # 将待预测图片转换形状 inputs = img.reshape(-1, 784) # 输入数据转换成tensor张量类型,并转换成浮点类型 inputs = torch.from_numpy(inputs) inputs = inputs.float() # 丢入网络进行预测,得到预测数据 predict = net(inputs) # 打印对应的最后的预测结果 print("The number in this picture is {}".format(torch.argmax(predict).detach().numpy()))最后得到结果如图所示: |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |