OpenCV机器学习运用之KNN辨认手写数据集
作者头像
  • 雷军校友
  • 2020-05-12 19:51:12 2

回顾

在《OpenCV机器学习运用之KNN简单实现》一文中,我们通过一个实例深入了解了KNN算法,并学会了如何利用OpenCV提供的API来应用KNN算法。本节我们将进一步探讨如何利用KNN算法构建一个基本的OCR程序,从而识别手写数据集,加深对KNN算法的理解。

数据集准备

通过对之前SVM与KNN简单实现的例子,我们已经大致掌握了机器学习的基本流程:设定训练数据、初始化模型、训练模型以及进行预测。为了识别手写数据集,我们需要准备好训练数据集和测试数据集。

获取digits.png

如上图所示,在OpenCV的samples文件夹下的data文件夹中有一张名为digits.png的图像,该图像包含5000个手写数字,每个数字从0到9各500个,每个数字大小为20x20像素。我们利用这张图像来准备数据集。

首先,我们需要将图像中的5000个数字分割出来,作为5000个样本。观察图像可以发现,每行有100个数字,每列有50个数字。因此,我们使用NumPy的hsplit()vsplit()函数将图像分割成5000个20x20像素的单元格。具体代码如下:

python img = cv.imread('digits.png') gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY) cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)] x = np.array(cells)

分割出5000个手写数字后,我们开始准备训练数据集和测试数据集。为了确保训练集和测试集都包含0至9的所有数字,我们将digits.png图像的前50列作为训练集,后50列作为测试集。

python train = x[:,:50].reshape(-1,400).astype(np.float32) # 大小为(2500,400) test = x[:,50:100].reshape(-1,400).astype(np.float32) # 大小为(2500,400)

接下来,我们需要为训练数据和测试数据创建相应的标签,以便模型进行训练。由于图像和分割都非常规则,我们使用NumPy的repeat()函数为数据集创建标签。

python k = np.arange(10) train_labels = np.repeat(k,250)[:,np.newaxis] test_labels = train_labels.copy()

至此,我们已经完成了数据集的准备。

模型训练与预测

在OpenCV中初始化KNN模型非常简单,只需调用ml.KNearest_create()方法即可创建KNN模型对象,然后使用该对象的train()方法传入训练数据和标签进行训练。

python knn = cv.ml.KNearest_create() knn.train(train, cv.ml.ROW_SAMPLE, train_labels)

训练完成后,我们使用KNN模型对象的findNearest()方法获取预测结果。这里我们将测试数据传入,得到预测结果,并将其与真实标签进行对比,计算预测的准确率。完整代码如下:

完整代码

运行结果表明,预测的准确率为91.76%。

总结

通过本节的学习,我们使用KNN算法构建了一个基本的OCR程序来识别手写数据集,进一步加深了对KNN算法的理解。

    本文来源:图灵汇
责任编辑: : 雷军校友
声明:本文系图灵汇原创稿件,版权属图灵汇所有,未经授权不得转载,已经协议授权的媒体下载使用时须注明"稿件来源:图灵汇",违者将依法追究责任。
    分享
辨认手写运用机器数据OpenCV学习KNN
    下一篇