PyTorch VGG16手写数字识别教程

server/2024/9/25 11:16:39/

手写数字识别教程:使用PyTorch和VGG16

1. 环境准备

确保你已安装以下库:

pip install torch torchvision
2. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
3. 数据预处理

我们需要对MNIST数据集进行转换,使其适合输入VGG16模型。由于VGG16的输入要求为224x224的图像,因此我们需要调整图像大小,并进行标准化处理。

transform = transforms.Compose([transforms.Resize((224, 224)),  # 将图像大小调整为224x224transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))  # 标准化处理,均值和标准差
])# 下载并加载训练和测试数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
4. 定义VGG16模型

VGG16由多个卷积层和全连接层组成。我们将调整输入通道以适应单通道的MNIST数据。

class VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()# 定义卷积层self.vgg = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, padding=1),  # 将输入通道设置为1(灰度图)nn.ReLU(),  # 激活函数nn.MaxPool2d(kernel_size=2, stride=2),  # 最大池化层,减小特征图尺寸nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)# 定义全连接层self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),  # 第一个全连接层nn.ReLU(),nn.Dropout(),  # 随机失活,防止过拟合nn.Linear(4096, 4096),  # 第二个全连接层nn.ReLU(),nn.Dropout(),nn.Linear(4096, 10)  # 输出层,10个类(数字0-9))def forward(self, x):x = self.vgg(x)  # 通过卷积层x = x.view(x.size(0), -1)  # 展平特征图x = self.classifier(x)  # 通过全连接层return x
5. 训练模型

我们将使用交叉熵损失函数和Adam优化器,并训练模型。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 检测可用的设备
model = VGG16().to(device)  # 实例化模型并移动到设备上
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器# 训练循环
for epoch in range(5):  # 训练5个epochmodel.train()  # 设置为训练模式for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')  # 输出当前epoch的损失
6. 测试模型

在测试阶段,我们将计算模型的准确率。

model.eval()  # 设置为评估模式
with torch.no_grad():  # 禁用梯度计算correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备outputs = model(images)  # 前向传播_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)  # 统计总样本数correct += (predicted == labels).sum().item()  # 统计正确预测的数量print(f'Accuracy: {100 * correct / total:.2f}%')  # 输出准确率

总结

这个教程详细介绍了如何使用VGG16模型对MNIST数据集进行手写数字识别。通过调整网络参数和训练轮数,你可以进一步提高模型的性能。希望这个教程能帮助你更好地理解PyTorch及深度学习的应用!


http://www.ppmy.cn/server/121799.html

相关文章

uniapp实现触底分页加载

uniapp实现触底分页加载 一、页面结构 假设我们有一个简单的列表页面&#xff0c;包含一个用于展示列表项的<view>和一个用于触发加载更多操作的触底事件的区域。 <template><view><view v-for"item in listData" :key"item.id"&g…

清华大学开源 CogVideoX-5B-I2V 模型,以支持图生视频

CogVideoX 是源于清影的开源视频生成模型。 下表列出了我们在此版本中提供的视频生成模型的相关信息。 Model NameCogVideoX-2BCogVideoX-5BCogVideoX-5B-I2V (This Repository)Model DescriptionEntry-level model, balancing compatibility. Low cost for running and second…

K8s安装部署(v1.28)--超详细(cri-docker作为运行时)

1、准备环境 ip角色系统主机名cpumem192.168.40.129mastercentos7.9k8smaster48192.168.40.130node1centos7.9k8snode148192.168.40.131node2centos7.9k8snode248192.168.40.132node3centos7.9k8snode348 2、系统配置&#xff08;所有节点&#xff09; 重要&#xff1a;首先…

docker仓库

一、创建 Docker 镜像仓库 启动私有仓库容器 通过 Docker 启动一个私有镜像仓库。这里使用官方的 registry 镜像。 #docker run -d&#xff1a;在后台运行一个容器。 #-p 5000:5000&#xff1a;将容器内的 5000 端口映射到宿主机的 5000 端口&#xff0c;私有仓库将通过此端口…

【系统架构设计师】专题:基于架构的软件开发方法 ABSD(详细知识点及历年真题)

更多内容请见: 备考系统架构设计师-核心总结索引 文章目录 一、ABSD概述二、ABSD开发过程1、架构需求2、架构设计3、架构文档化4、架构复审5、架构实现6、对架构进行改变一、ABSD概述 基于体系结构(架构)的软件设计(Architecture-Based Software Design,ABSD) 方法是体系结构…

springframework Ordered接口学习

Ordered接口介绍 完整路径&#xff1a; org.springframework.core.Ordered Ordered 接口是 Spring 框架中的一个核心接口&#xff0c;用于定义对象的顺序。这个接口通常用于需要排序的组件&#xff0c;例如 Spring 中的 Bean、过滤器&#xff08;Filters&#xff09;、拦截器…

Java 之泛型详解

1. 泛型是什么&#xff1f; 泛型是 Java 中一种强大的机制&#xff0c;它允许你编写可以与多种数据类型一起工作的代码&#xff0c;而无需在编译时指定具体的类型。这样可以提高代码的灵活性、可读性和安全性。 2. 为什么要使用泛型&#xff1f; 泛型可以帮助我们编写更安全…

chorme浏览器 您的连接不是私密连接

‌当浏览器显示“您的连接不是私密连接&#xff0c;攻击者可能会试图从 localhost 窃取您的信息&#xff08;例如&#xff1a;密码、消息或信用卡信息&#xff09;”的警告时&#xff0c;这通常意味着您正在尝试访问的网站的安全证书存在问题&#xff0c;可能是因为它使用的是自…