机器学习(2)-- KNN算法之手写数字识别

embedded/2024/9/23 22:14:41/

文章目录

  • KNN算法
  • 数字识别
    • 训练模型
      • 完整代码展示
    • 测试模型
    • 测试新的数据
  • 总结

KNN算法

KNN(K-Nearest Neighbor,K最近邻)算法是一种用于分类和回归的非参数统计方法,尤其在分类问题中表现出色。在手写数字识别领域,KNN算法通过比较测试样本与训练样本之间的距离,找到最近的K个邻居,并根据这些邻居的类别来预测测试样本的类别。

接下来,让我们详细了解了解,knn怎么进行手写数字识别:

数字识别

对于数字识别我们进行三个方面来完成它:

  1. 训练模型:得到模型
  2. 测试模型:测试模型识别的准确率
  3. 测试新的数据:查看实用效果

训练模型

  1. 收集数据

在这里插入图片描述

  1. 读取图片数据

使用opencv处理图片,将图片的像素数值读取进来,并返回的是一个三维(高,宽,颜色)numpy数组

 pip install opencv-python==3.4.11.45
import cv2
img = cv2.imread("digits.png")
  1. 转化灰度图

将图片转化为灰度图,从而让三维数组变成二位的数组:

gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
  1. 处理图像

对图片进行处理:将其先垂直切分(横向)成50份,再将每一份水平切分(竖向)成100份,这样我们的每份图片的像素值都为20*20(训练的图片比较规范)共500个,比如:

在这里插入图片描述

cells = [np.hsplit(row,100) for row in np.vsplit(gray,50)] #列表生成式
  1. 装进array数组

将切分的每一份图片像素数据都装进array数组中:

x = np.array(cells)
  1. 分隔数据

将数据竖着分隔一半,一半作为训练集,一般作为测试集:

train = x[:,:50]
test = x[:,50:100]
  1. 调整数据结构

由于我们最后要将数据放在KNN算法中训练,我们得将数据结构调整为适合KNN算法训练的结构,KNN要求输入的数据为二维数组,那么我们就来改变每份图片数组的维度:reshape:

train_new = train.reshape(-1,400).astype(np.float32)
  1. 分配标签

我们训练着那么多的数据,却没有给他们具体的类别标签(图像的实际值),因为我们之前的图像处理都是在寻找图像特征,但是并没有给他们一个具体对应的类别,只有空荡荡的特征,无法分类,所以我们得给切分的每份图片打上它们对应的标签:

#repeat用于重复数组中数值,此处重复250次,因为训练集中表示每个类别的图片只有250个,要将它们一一对应打上标签
#np.newaxis用于在数组中创建一个新的维度,即将每个标签单独放
#原本[00000……1111……] ----> [[0][0]……[1][1]……]
k = np.arange(10)
train_mark = np.repeat(k,250)[:,np.newaxis]
  1. 训练模型

在训练时,将训练集与标签一一对应训练:

#通过cv2创建一个knn模型
knn = cv2.ml.KNearest_create()
#cv2.ml.ROW_SAMPLE:告诉opencv将训练的数据与类别按行一一对应训练
knn.train(train_new,cv2.ml.ROW_SAMPLE,train_mark)

完整代码展示

import numpy as np
import cv2
#总结:收集数据 -- 读取图片数据 -- 转化灰度图 -- 处理图像 -- 装进array数组 -- 调整数据结构 -- 分配标签 -- 训练模型#读取训练集图片
img = cv2.imread("digits.png")#将图片转化为灰度图
gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)#对图片进行处理
cells = [np.hsplit(row,100) for row in np.vsplit(gray,50)]#将每一份图片都装进array数组中
x = np.array(cells)#分隔数据
train = x[:,:50]
test = x[:,50:100]#将数据构造为符合KNN的输入,KNN要求输入的数据为二维数组
#那么我们就来改变每份图片数组的维度:reshape
train_new = train.reshape(-1,400).astype(np.float32)#分配标签
k = np.arange(10)
train_mark = np.repeat(k,250)[:,np.newaxis]#通过cv2创建一个knn模型
knn = cv2.ml.KNearest_create()
#cv2.ml.ROW_SAMPLE:告诉opencv将训练的数据与类别按行一一对应训练
knn.train(train_new,cv2.ml.ROW_SAMPLE,train_mark)

这样我们就训练好的这份简单的数据内容,训练了一个可以识别数字0~9的模型,模型训练完了,我们总得知道它到底能不能识别数字吧?识别数字成功的准确率能达到多少呢?

测试模型

  1. 评估性能:测试模型帮助评估模型的准确性、效率、鲁棒性和其他性能指标。
  2. 识别问题:通过测试,可以及早发现和定位模型或产品中的缺陷、错误或不足之处。
  3. 优化和改进:测试模型提供的数据和反馈是优化和改进模型或产品的关键依据。基于测试结果,可以调整模型参数、改进算法设计、优化系统架构等,以提升模型或产品的性能和质量。

那么我们来测试我们刚刚训练出的模型:

前面说了,图片中的数据一半作为训练集,一半作为测试集,将测试集数据也进行以上操作:

test_new = test.reshape(-1,400).astype(np.float32) #调整数据结构
test_labels = np.repeat(k,250)[:,np.newaxis] #分配标签

处理好测试集的数据之后,我们来测试模型:

#将测试集放入模型测试
ret,result,neighbours,dist = knn.findNearest(test_new,k=3)#ret:表示操作是否成功#result:表示测试样本的预测标签(浮点数组)#neighbours:表示与测试样本最近的k个邻居的索引(整数数组)#dist:表示测试样本与每个最近邻居之间的距离(浮点数组)
#通过测试集校验准确率
matches = result==test_labels #将模型对测试集的预测结果(result)与实际的测试标签(test_labels)进行比较。
correct = np.count_nonzero(matches) #计算 matches 数组中 True(即正确预测)的数量
accuracy = correct*100.0/result.size #result.size 返回 result 数组中的元素总数
print("当前准确率为:",accuracy)
----------------
当前准确率为: 91.64

模型测试完成后,我们要尝试它在实际中的使用效果,查看其实用性。

测试新的数据

在画图软件中,画几个像素值20*20的图片,让其进入模型看看测试结果:比如:

在这里插入图片描述

这个测试数据已经进行了一部分的处理:

#处理图片
try_img = cv2.imread("4.png")  #读取图片
try_gray = cv2.cvtColor(try_img,cv2.COLOR_BGR2GRAY) #转为灰度图,二维
z = np.array(try_gray) #装入二维数组
try_new = z.reshape(-1,400).astype(np.float32) #调整结构,适用于KNN
#测试结果
ret,result,neighbours,dist = knn.findNearest(try_new,k=3)
print(result)  #查看测试结果,显示分类类别
------------------
[[4.]]  #测试结果正确

总结

本篇介绍了如何使用KNN算法进行手写数字识别:

  1. 训练模型:收集数据 – 读取图片数据 – 转化灰度图 – 处理图像 – 装进array数组 – 调整数据结构 – 分配标签 – 训练模型

  2. 测试模型:评估性能 – 识别问题 – 优化和改进

  3. 测试数据:查看实用性


http://www.ppmy.cn/embedded/97908.html

相关文章

Apache HOP (Hop Orchestration Platform) VS Data Integration (通常被称为 Kettle)

Apache HOP (Hop Orchestration Platform) 和 Data Integration (通常被称为 Kettle) 都是强大的 ETL (Extract, Transform, Load) 工具, 它们都由 Hitachi Vantara 开发和支持。尽管它们有着相似的目标,即帮助用户进行数据集成任务,但它们在…

差分法(Differencing),多变量差分对多个时间序列进行联合分析

介绍 差分法(Differencing)是时间序列分析中的一种重要技术,主要用于使非平稳时间序列变得平稳,以便能够应用诸如ARIMA(AutoRegressive Integrated Moving Average)模型等线性模型。非平稳时间序列通常具有…

二:《Python基础语法汇总》— 条件判断与循环结构

一:条件判断 1.程序执行的三大流程: ​ 顺序流程:无缩进代码,从上往下依次执行 ​ 分支流程:选择性执行某块代码,或跳过某行代码去执行,与缩进(TAB)有关 ​ 循环流程&…

超实用的短链接使用指南

当你突然灵感爆棚,有一个自认为能火出地球的创意推广,正准备把链接分享出去的时候,可不能让那些长长的链接坏了事! 链接太长,会让人觉得不安全可靠,大大影响点击量; 字节太多,短信…

数据结构入门——08排序

1.排序 1.1什么是排序 排序是一种操作,通过比较记录中的关键字,将一组数据按照特定顺序(递增或递减)排列起来。排序在计算机科学中非常重要,因为它不仅有助于数据的快速检索,还能提高其他算法的性能。 1…

【前端】VUE动态引入组件 通过字符串动态渲染模板 动态生成组件

【前端】VUE动态引入组件 通过字符串动态渲染模板 应用场景&#xff1a; js增强 动态代码 扩展一类的 可以配合低代码平台等 灵活配置在线表单 在线列表等 适合灵活性 扩展性比较强的组件 VUE2 <template><div><textarea v-model"templateString"…

Axios请求使用params参数导致后端获取数据嵌套

问题重述&#xff1a; 首先看前端的axios请求这里我使用params参数将data数据传给后端 let data JSON.stringify(this.posts);axios.post("/blog_war_exploded/insertPost", {params: {data: data}}).then((res) > {if (res.data "success") {alert(…

ElasticSearch 高级查询语法Query DSL实战

文章目录 1.ES高级查询Query DSL1.1 match_all1.2 术语级别查询1.3 全文检索1.4 bool query布尔查询1.5 highlight高亮 2. ES 深度分页问题及针对不同需求下的解决方案2.1 什么是深度分页2.2 深度分页会带来什么问题2.3 深度分页问题的常见解决方案2.4 总结 1.ES高级查询Query …