Python机器学习——利用Keras和基础神经网络进行手写数字识别(MNIST数据集)

server/2024/9/23 3:43:13/

Python机器学习——利用Keras和基础神经网络进行手写数字识别(MNIST数据集)

  • 配置环境
    • 创建虚拟环境
    • 安装功能包并进环境
  • 编程
    • 1. 导入功能包
    • 2. 加载数据集
    • 3. 数据预处理
    • 4. 构建神经网络
    • 5. 神经网络训练
    • 6. 测试模型训练效果

配置环境

首先安装Anaconda,随便找个视频或者教程按照下

创建虚拟环境

conda env list 查看虚拟环境 (*代表在哪个环境下)
conda create -n 环境名字 python=版本
conda activate yixuepytorch 进入我们创建好的虚拟环境
conda list 查看当下环境下,有哪些功能包
conda remove -n 虚拟环境名字 --all 删除所选环境

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

安装功能包并进环境

通过pip install xxx按照下我们需要的功能包

pip install numpy
pip install pandas
pip install keras
pip install tensorflow

在这里插入图片描述

输入jupyter notebook进入notebook并创建新Notebook进行编程
在这里插入图片描述

在这里插入图片描述

编程

1. 导入功能包

python"># 导入功能包
import numpy as np # 数学工具箱
import pandas as pd # 数据处理工具箱
from keras.datasets import mnist # 从 Keras中导入 mnist数据集

2. 加载数据集

python"># 查看数据集
mnist.load_data()
python">(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print('训练集图片: ', train_images.shape)
print('训练集标签: ', train_labels.shape)
print('测试集图片: ', test_images.shape)
print('测试集标签: ', test_labels.shape)

keras 中自带的mnist模块,加载数据集load_data进来,分别赋值给四个变量。
其中:train_images保存用来训练的图像,train_labels是与之对应的标签。如果图像中的数字是1,那么标签就是1。test_images和test_labels分别为用来验证的图像和标签,也就是验证集。训练完神经网络后,可以使用验证集中的数据进行验证。

3. 数据预处理

python"># 用keras.utils工具箱的类别转换工具,作用是将样本标签转为one-hot编码
from keras.utils import to_categorical
# 给标签增加维度,使其满足模型的需要
# 原始标签,比如训练集标签的维度信息是[60000, 28, 28, 1]
train_images = train_images.reshape((60000, 28*28)).astype('float') # 60000张训练图像,每张图像的长宽均为28个像素
test_images = test_images.reshape((10000, 28*28)).astype('float') # 10000张验证图像,每张图像的长宽均为28个像素
# 特征转换为one-hot编码
train_labels = to_categorical(train_labels, 10)
test_labels = to_categorical(test_labels, 10)

one-hot 编码:
对于输出 0-9 这10个标签而言,每个标签的地位应该是相等的,并不存在标签数字2大于数字1的情况。因此,在大部分情况下,都需要将标签转换为 one-hot 编码,也就独热编码,这样标签之间便没有任何大小而言。
这个例子中,数字 0-9 转换为的独热编码为:
array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]
每一行的向量代表一个标签。

4. 构建神经网络

python"># 从keras中导入模型,神经元等操作
from keras import models, layers, regularizers
# 构建一个最基础的连续的模型,所谓连续,就是一层接着一层(也就是指的神经网络一层一层穿起来)
network = models.Sequential()
# 隐藏层, 设置128个神经元,使用relu作为激活函数,输入尺寸是784,进行l1正则化进行泛化处理
network.add(layers.Dense(units=128, activation='relu', input_shape=(28*28, ), kernel_regularizer=regularizers.l1(0.0001)))
# 隐藏层, 设置32个神经元,使用relu作为激活函数,进行l1正则化进行泛化处理
network.add(layers.Dense(units=32, activation='relu', kernel_regularizer=regularizers.l1(0.0001)))
# 输出层是10个神经元,用softmax进行多分类
network.add(layers.Dense(units=10, activation='softmax'))

5. 神经网络训练

python">from keras.optimizers import RMSprop
# 设置编译,optimizer优化器为RMSprop自适应学习率,损失函数使用的是交叉熵,模型评估标准是获取模型准确率
network.compile(optimizer=RMSprop(0.001), loss='categorical_crossentropy', metrics=['accuracy'])
# 训练网络,用fit函数, epochs表示训练多少个回合, batch_size表示每次训练给多大的数据,verbose=2是指输出更详细的训练信息,包括每一轮迭代的损失值
network.fit(train_images, train_labels, epochs=20, batch_size=128, verbose=2)

6. 测试模型训练效果

python"># 测试集上测试效果
test_loss, test_accuracy = network.evaluate(test_images, test_labels)
print("test_loss:", test_loss, "test_accuracy:", test_accuracy)

输出:在这里插入图片描述


http://www.ppmy.cn/server/114531.html

相关文章

python怎么输入中文

解决中文输入的两种应用: 在脚本中加语言编码声明 “-*- coding: uft-8 -*-” 应用一:print中出现中文 方法一:用unicode( , encoding utf-8 ) 或者 unicode(" ", encoding "utf-8" )。 方法二:用u 或者…

手机到了外地ip地址就变了吗

手机到了外地IP地址就变了吗?随着智能手机的普及,人们越来越频繁地使用手机进行各种网络活动。然而,关于手机IP地址是否会随着地理位置的变化而改变,许多用户仍心存疑惑。本文将深入探讨这一问题,揭示IP地址变化的奥秘…

SpringBoot2.7 + Nacos + GateWay

1. pom包&#xff0c;主要是记录版本 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http:/…

Spring-di基本使用

SpringDI 1 基础环境准备 流程如下 1.在自己的工程中建一个module用于SpringDi注入 2.导入spring相关的依赖 <dependencies><!--导入spring-context依赖--><dependency><groupId>org.springframework</groupId><artifactId>spring-cont…

Linux内核的调试(TODO)

&#xff08;TODO&#xff09; 参考&#xff1a; 万字长文&#xff0c;汇总 Linux 内核调试的方法

[项目][WebServer][项目介绍及知识铺垫][下]详细讲解

目录 1.HTTP请求与相应1.整体2.细节说明请求响应 3.请求方法GET[重点]POST[重点]PUTHEADDELETEOPTIONSTRACECONNECT总结方法 4.HTTP响应 -- 状态码及其描述5.常见状态码2XX 成功 表明请求结果被正确处理了3XX 成功 浏览器需要执行某些特殊的处理以正确处理请求4XX 表明客户端发…

【专业解析】电脑文件夹打不开的深层原因与高效数据恢复策略

一、初探文件夹无法打开的困境 在日常使用电脑的过程中&#xff0c;遇到文件夹突然无法打开的情况&#xff0c;往往令人感到困扰。这种情况可能由多种因素引发&#xff0c;包括但不限于文件系统损坏、磁盘错误、软件冲突、权限设置不当或是文件夹本身遭遇未知逻辑错误等。当急…

使用paddlerocr识别固定颜色验证码

1 引言 本文使用opencv和paddlerocr识别出固定颜色的验证码&#xff0c;原理不解释&#xff0c;安装包的方法自行查找&#xff0c;只提供代码和思路。 1 使用opencv对特定颜色区域进行提取2 使用paddlerocr识别并输出验证码 2 代码 2.1 读取图片&#xff0c;提取蓝色区域 …