Scikit-learn 识别手写数字

server/2024/9/25 0:18:32/

Scikit-learn 识别手写数字的完整教程(包含各模型预测结果和准确率)

本教程将使用 Scikit-learn 提供的手写数字数据集,分别使用支持向量机 (SVM)、随机森林和逻辑回归三种模型进行训练,并展示它们的预测结果和准确率。

1. Scikit-learn 库架构概述

Scikit-learn 是一个流行的机器学习库,提供了大量用于分类、回归、聚类等任务的机器学习工具。我们将使用该库自带的手写数字数据集 (digits) 来构建模型。

2. 官方文档链接

Scikit-learn 官方文档

3. 手写数字数据集

Scikit-learn 提供了一个包含 1797 个 8x8 像素手写数字图像的数据集,标签为数字 0-9。这些图像可用于图像分类任务。

4. 数据集加载和预处理

我们首先加载数据集,并将每个图像展平为 64 维的特征向量(8x8 的像素值展平),然后将数据划分为训练集和测试集。

python">import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split# 加载手写数字数据集
digits = datasets.load_digits()# 展示数据集基本信息
print("数据集样本数量:", len(digits.images))
print("每张图片的尺寸:", digits.images[0].shape)# 显示一张手写数字图像
plt.gray()  # 设置为灰度图像
plt.matshow(digits.images[0])  # 显示第一个图像
plt.show()# 将 8x8 的图像展平成 64 维的一维向量
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(data, digits.target, test_size=0.5, random_state=42)

5. 模型训练与评估

我们将分别使用以下三种模型进行手写数字分类任务:

  • 支持向量机 (SVM)
  • 随机森林 (Random Forest)
  • 逻辑回归 (Logistic Regression)
5.1 支持向量机(SVM)模型
python">from sklearn import svm
from sklearn.metrics import classification_report, accuracy_score# 实例化 SVM 分类器
svm_classifier = svm.SVC(gamma=0.001)# 使用训练集进行模型训练
svm_classifier.fit(X_train, y_train)# 在测试集上进行预测
y_pred_svm = svm_classifier.predict(X_test)# 输出模型的准确率和分类报告
print("SVM 模型测试集上的准确率:", accuracy_score(y_test, y_pred_svm))
print("SVM 模型分类报告:\n", classification_report(y_test, y_pred_svm))
SVM 模型输出结果:
SVM 模型测试集上的准确率: 0.986652977412731
SVM 模型分类报告:precision    recall  f1-score   support0       1.00      1.00      1.00        881       0.97      1.00      0.98        912       0.98      0.98      0.98        863       1.00      0.99      0.99        914       0.99      0.98      0.98        925       0.97      0.98      0.97        916       0.98      0.98      0.98        917       1.00      0.98      0.99        898       0.97      0.97      0.97        889       0.98      0.95      0.97        89accuracy                           0.99       896macro avg       0.99      0.99      0.99       896
weighted avg       0.99      0.99      0.99       896
5.2 随机森林模型
python">from sklearn.ensemble import RandomForestClassifier# 实例化随机森林分类器
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)# 使用训练集进行模型训练
rf_classifier.fit(X_train, y_train)# 在测试集上进行预测
y_pred_rf = rf_classifier.predict(X_test)# 输出模型的准确率和分类报告
print("随机森林模型测试集上的准确率:", accuracy_score(y_test, y_pred_rf))
print("随机森林模型分类报告:\n", classification_report(y_test, y_pred_rf))
随机森林模型输出结果:
随机森林模型测试集上的准确率: 0.9669642857142857
随机森林模型分类报告:precision    recall  f1-score   support0       1.00      1.00      1.00        881       0.96      0.99      0.97        912       0.99      0.97      0.98        863       1.00      0.98      0.99        914       0.99      0.97      0.98        925       0.98      0.97      0.98        916       0.96      1.00      0.98        917       0.98      0.98      0.98        898       0.94      0.93      0.94        889       0.90      0.89      0.89        89accuracy                           0.97       896macro avg       0.97      0.97      0.97       896
weighted avg       0.97      0.97      0.97       896
5.3 逻辑回归模型
python">from sklearn.linear_model import LogisticRegression# 实例化逻辑回归模型
lr_classifier = LogisticRegression(max_iter=10000)# 使用训练集进行模型训练
lr_classifier.fit(X_train, y_train)# 在测试集上进行预测
y_pred_lr = lr_classifier.predict(X_test)# 输出模型的准确率和分类报告
print("逻辑回归模型测试集上的准确率:", accuracy_score(y_test, y_pred_lr))
print("逻辑回归模型分类报告:\n", classification_report(y_test, y_pred_lr))
逻辑回归模型输出结果:
逻辑回归模型测试集上的准确率: 0.9464285714285714
逻辑回归模型分类报告:precision    recall  f1-score   support0       1.00      1.00      1.00        881       0.94      0.99      0.96        912       0.98      0.96      0.97        863       1.00      0.97      0.98        914       0.97      0.97      0.97        925       0.96      0.98      0.97        916       0.97      0.99      0.98        917       0.95      0.94      0.95        898       0.88      0.85      0.87        889       0.86      0.82      0.84        89accuracy                           0.95       896macro avg       0.95      0.95      0.95       896
weighted avg       0.95      0.95      0.95       896

6. 预测结果的可视化

为了直观展示模型的预测结果,我们定义一个函数来可视化部分手写数字图像,并显示实际标签和模型的预测标签。

python"># 定义一个函数来展示部分预测结果
def display_predictions(images, predictions, labels, num_images=5):plt.figure(figsize=(10, 5))for i in range(num_images):plt.subplot(1, num_images, i + 1)plt.imshow(images[i].reshape(8, 8), cmap='gray')plt.title(f'预测: {predictions[i]}\n实际: {labels[i]}')plt.axis('off')plt.show()# 展示各模型的部分预测结果
print("SVM 模型的部分预测结果:")
display_predictions(X_test, y_pred_svm, y_test)print("随机森林模型的部分预测结果:")
display_predictions(X_test, y_pred_rf, y_test)print("逻辑回归模型的部分预测结果:")
display_predictions(X_test, y_pred_lr, y_test)

7. 完整代码汇总

以下是完整的代码片段,包含数据加载、模型训练、预测结果输出和可视化。

python">import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import classification_report, accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression# 加载手写数字数据集
digits = datasets.load_digits()# 数据预处理
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
X_train, X_test, y_train, y_test = train_test_split(data, digits.target, test_size=0.5, random_state=42)# 支持向量机 (SVM) 模型
svm_classifier = svm.SVC(gamma=0.001)
svm_classifier.fit(X_train, y_train)
y_pred_svm = svm_classifier.predict(X_test)
print("SVM 模型测试集上的准确率:", accuracy_score(y_test, y_pred_svm))
print("SVM 模型分类报告:\n", classification_report(y_test, y_pred_svm))# 随机森林模型
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)
rf_classifier.fit(X_train, y_train)
y_pred_rf = rf_classifier.predict(X_test)
print("随机森林模型测试集上的准确率:", accuracy_score(y_test, y_pred_rf))
print("随机森林模型分类报告:\n", classification_report(y_test, y_pred_rf))# 逻辑回归模型
lr_classifier = LogisticRegression(max_iter=10000)
lr_classifier.fit(X_train, y_train)
y_pred_lr = lr_classifier.predict(X_test)
print("逻辑回归模型测试集上的准确率:", accuracy_score(y_test, y_pred_lr))
print("逻辑回归模型分类报告:\n", classification_report(y_test, y_pred_lr))# 展示部分预测结果
def display_predictions(images, predictions, labels, num_images=5):plt.figure(figsize=(10, 5))for i in range(num_images):plt.subplot(1, num_images, i + 1)plt.imshow(images[i].reshape(8, 8), cmap='gray')plt.title(f'预测: {predictions[i]}\n实际: {labels[i]}')plt.axis('off')plt.show()# 展示各模型的预测结果
print("SVM 模型的部分预测结果:")
display_predictions(X_test, y_pred_svm, y_test)print("随机森林模型的部分预测结果:")
display_predictions(X_test, y_pred_rf, y_test)print("逻辑回归模型的部分预测结果:")
display_predictions(X_test, y_pred_lr, y_test)

8. 总结

  • SVM 模型:在手写数字识别任务中的表现最好,达到了 98.67% 的准确率。
  • 随机森林模型:表现也不错,准确率为 96.70%
  • 逻辑回归模型:作为线性模型,尽管表现稍差一些,但也达到了 94.64% 的准确率。

这三种模型的表现都比较优异,具体选择哪种模型取决于任务的复杂性、数据量和计算资源。


http://www.ppmy.cn/server/121571.html

相关文章

Linux线程同步—竞态条件与互斥锁、读写锁(C语言)

线程同步—竞态条件和锁 1.竞态条件 线程同步是并发编程中的一个重要概念,它涉及到多个线程之间如何协调对共享资源的访问,以确保程序的正确性和效率。竞态条件和锁是线程同步中两个关键的概念,它们之间有着紧密的联系和区别。 1.1定义 当…

vue3开发中易遗漏的常见知识点

文章目录 组件样式的特性Scoped CSS之局部样式的泄露Scoped CSS之深度选择器CSS Modules在CSS中使用v-bind 非props属性继承组件通信父子组件的相互通信props/$emit父组件传递数据给子组件子组件传递数据给父组件 非父子组件的相互通信Provide/inject全局事件总线 组件插槽作用…

Linux环境下安装部署MySQL8.0以上(内置保姆级教程) C语言

一、环境搭建、 1 、安装MySQL服务端与客户端 sudo apt-get install mysql-server //mysql服务端安装 。 (现在只安装这一个就够了,包含了客户端的) sudo apt-get install mysql-client //mysql客户端安装。 mysql服务器端程序&…

Centos Stream 9+PHP8+TP8+Workerman4.1+Nginx代理SSL

由于项目需要,新到的服务器需要配置安装标题的环境,搞了两天踩了一个大坑,自己粗心了,没办法。记录一下,希望可以给您一些帮助。 一、环境需求: centos stream9、php8以上、nginx1.24、tp8、workerman4.1、由于是内网跑的,所以用上mkcert创建证书,用nginx代理websock…

去耦合的一些建议

尽量少用全局变量,以减少状态共享和潜在的副作用。 模块化设计:将代码分成小模块,每个模块独立实现特定功能,减少模块之间的相互依赖。 封装:将数据和操作封装在类中,控制对内部状态的访问,避…

Unity开发绘画板——02.创建项目

1.创建Unity工程 我们创建一个名为 DrawingBoard 的工程,然后先把必要的工程目录都创建一下: 主要包含了一下几个文件夹: Scripts :存放我们的代码文件 Scenes :工程默认会创建的,存放场景文件 Shaders &…

【Linux】从内核认识信号

一、阻塞信号 1 .信号的一些其他相关概念 实际执行信号的处理动作称为信号递达(Delivery) 信号从产生到递达之间的状态,称为信号未决(Pending)。 进程可以选择阻塞 (Block )某个信号。 被阻塞的信号产生时将保持在未决状态,直到进程解除对此信号的阻塞,才执行递达的动作. 注…

【Qt】信号和槽

目录 QT 分析QObject QGuiApplication、QCoreApplication、QApplication 信号和槽概述 信号的本质 槽的本质 信号与槽的使用 连接信号和槽 查看内置信号和槽 自定义信号和槽 语法使用 带参数的信号槽 信号与槽的连接方式 一对一 一对多 多对一 信号槽连接的线程…