使用 PyTorch 实现并训练 VGGNet 用于 MNIST 分类

server/2024/11/23 21:47:37/

        本文将展示如何使用 PyTorch 实现一个经典的 VGGNet 网络,并在 MNIST 数据集上进行训练和测试。我们将从模型构建开始,涵盖数据预处理、模型训练、评估、保存与加载模型,以及可视化预测结果等全过程。


1. VGGNet 模型的实现

        首先,我们实现一个标准的 VGGNet 网络。VGGNet 是一个深度卷积神经网络,它由多个卷积层和全连接层组成,广泛应用于图像分类任务。

VGGNet 模型结构:
  • 卷积层:VGGNet 采用了简单的结构,使用多个卷积层,每层卷积后跟一个 ReLU 激活函数和一个 最大池化 层。
  • 全连接层:经过卷积层提取特征后,VGGNet 会将特征图展平,并通过全连接层进行分类
import torch.nn as nnclass VGG(nn.Module):def __init__(self, num_classes=10, input_channels=1):"""VGG 网络的初始化方法,包含卷积层和全连接层。参数:- num_classes (int): 分类的类别数量,默认 10 (适用于 MNIST)- input_channels (int): 输入图片的通道数,默认 1 (适用于灰度图像)"""super(VGG, self).__init__()# 构建卷积层部分self.features = self._make_layers(input_channels)# 构建分类器部分self.classifier = self._make_classifier(num_classes)def _make_layers(self, input_channels):"""构建卷积层部分,通过堆叠卷积层、ReLU 激活和池化层来构建特征提取部分参数:- input_channels (int): 输入图像的通道数,默认为 1(灰度图)返回:- features (nn.Sequential): 包含卷积层和池化层的神经网络模块"""layers = []# 卷积块 1layers += self._conv_block(input_channels, 64)# 卷积块 2layers += self._conv_block(64, 128)# 卷积块 3layers += self._conv_block(128, 256)# 卷积块 4layers += self._conv_block(256, 512)# 将所有卷积块和池化层堆叠在一起return nn.Sequential(*layers)def _conv_block(self, in_channels, out_channels):"""创建一个卷积块,包含两个卷积层和一个最大池化层参数:- in_channels (int): 输入通道数- out_channels (int): 输出通道数返回:- block (list): 卷积块 [卷积层 + ReLU + 卷积层 + ReLU + 最大池化层]"""block = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2)]return blockdef _make_classifier(self, num_classes):"""构建全连接层部分,最后的输出层为分类层。参数:- num_classes (int): 分类类别数返回:- classifier (nn.Sequential): 包含全连接层和 Dropout 层的网络模块"""return nn.Sequential(nn.Linear(512 * 1 * 1, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, num_classes))def forward(self, x):"""前向传播方法,输入图像通过卷积层提取特征后再通过全连接层进行分类。参数:- x (Tensor): 输入的图像数据返回:- x (Tensor): 分类结果"""# 通过卷积层提取特征x = self.features(x)# 将特征图展平为一维向量x = x.view(x.size(0), -1)  # 这里将 4D 张量转换为 2D,保留 batch_size# 通过分类器进行最终分类x = self.classifier(x)return x

2. 训练模型

        使用 PyTorch 实现的 VGGNet 网络后,我们需要对模型进行训练。在这个过程中,我们会使用 AdamW 优化器、交叉熵损失 以及 混合精度训练 来提升训练效率。

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocastdef get_data_loader(batch_size=64, num_workers=2):""" 获取 MNIST 数据加载器 """transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='D:/workspace/data', train=True, download=True, transform=transform)return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)def initialize_model(device, num_classes=10):""" 初始化模型、优化器和损失函数 """model = VGG(num_classes=num_classes).to(device)optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)criterion = torch.nn.CrossEntropyLoss()return model, optimizer, criteriondef train_epoch(model, train_loader, device, criterion, optimizer, scaler):""" 训练一个 epoch,并返回该 epoch 的平均损失和准确率 """model.train()running_loss = 0.0correct = 0total = 0with tqdm(train_loader, desc="Training", unit="batch", ncols=100) as pbar:for data, target in pbar:data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)optimizer.zero_grad()# 混合精度训练with autocast():output = model(data)loss = criterion(output, target)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()running_loss += loss.item()_, predicted = torch.max(output, 1)total += target.size(0)correct += (predicted == target).sum().item()# 更新进度条pbar.set_postfix(loss=running_loss / (total // len(data)), accuracy=100 * correct / total)return running_loss / len(train_loader), 100 * correct / total


3. 保存与加载模型

        在训练完成后,我们将保存模型,并在后续的测试过程中加载模型以进行评估。

def save_model(model, filepath='vggnet_mnist.pth'):""" 保存训练的模型到指定文件(覆盖之前的文件) """torch.save(model.state_dict(), filepath)print(f"Model saved to {filepath}")def load_model(model_path='vggnet_mnist.pth', num_classes=10):""" 加载预训练模型 """model = VGG(num_classes=num_classes)model.load_state_dict(torch.load(model_path))return model


4. 评估模型与可视化结果

        我们可以加载训练好的模型并对其在测试集上的表现进行评估。我们还可以通过 matplotlib 可视化前六张测试图像的预测结果。

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transformsdef get_test_loader(batch_size=64, data_dir='D:/workspace/data'):""" 获取 MNIST 测试数据加载器 """transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])test_dataset = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)def evaluate_model(model, test_loader, device):""" 评估模型并返回准确率和前六张图片的预测与标签 """model.eval()correct = 0total = 0images, labels, preds = [], [], []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = torch.max(output, 1)total += target.size(0)correct += (predicted == target).sum().item()# 记录前六张图片及其标签和预测if len(images) < 6:batch_size = data.size(0)for i in range(min(6 - len(images), batch_size)):images.append(data[i].cpu())labels.append(target[i].cpu())preds.append(predicted[i].cpu())accuracy = 100 * correct / totalreturn accuracy, images, labels, predsdef display_images(images, labels, preds):""" 可视化前六张图片及其真实标签和预测标签 """fig, axes = plt.subplots(2, 3, figsize=(10, 6))axes = axes.ravel()for i in range(6):axes[i].imshow(images[i][0].squeeze(), cmap='gray')  # MNIST 是单通道灰度图像axes[i].set_title(f"True: {labels[i].item()}, Pred: {preds[i].item()}")axes[i].axis('off')  # 不显示坐标轴plt.show()


5. 总结

        通过以上步骤,我们成功实现并训练了一个 VGGNet 网络,并在 MNIST 数据集上进行了测试与评估。我们使用了混合精度训练来加速训练过程,并通过可视化展示了模型的预测效果。

        这种方法可以推广到其他数据集和任务中,例如 CIFAR-10、CIFAR-100 或其他图像分类问题。

完整项目:

qxd-ljy/VGGNet-PyTorch: 使用PyTorch实现VGGNet进行MINST图像分类icon-default.png?t=O83Ahttps://github.com/qxd-ljy/VGGNet-PyTorchVGGNet-PyTorch: 使用PyTorch实现VGGNet进行MINST图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/vggnet-py-torch


http://www.ppmy.cn/server/144359.html

相关文章

Node.js笔记(三)局域网聊天室构建1

目标 用户与服务端建立通信&#xff0c;服务端能检测到用户端的连接信息 代码 JS部分<chatroom.js> const express require(express) const http require(http) const {Server} require(socket.io)const app express() const se…

table元素纯css无限滚动,流畅过度

<template><div class"monitor-table-container"><table class"monitor-table"><thead><th>标题</th><th>标题</th><th>标题</th><th>标题</th></thead><tbody ref&quo…

RabbitMQ高可用延迟消息惰性队列

目录 生产者确认 消息持久化 消费者确认 TTL延迟队列 TTL延迟消息 惰性队列 生产者确认 生产者确认就是&#xff1a;发送消息的人&#xff0c;要确保消息发送给了消息队列&#xff0c;分别是确保到了交换机&#xff0c;确保到了消息队列这两步。 1、在发送消息服务的ap…

基于 NCD 与优化函数结合的非线性优化 PID 控制

基于 NCD 与优化函数结合的非线性优化 PID 控制 1. 引言 NCD&#xff08;Normalized Coprime Factorization Distance&#xff09;优化是一种用于非线性系统的先进控制方法。通过将 NCD 指标与优化算法结合&#xff0c;可以在动态调整控制参数的同时优化控制器性能。此方法特别…

Applied Intelligence投稿

一、关于手稿格式&#xff1a; 1、该期刊是一个二区的&#xff0c;模板使用Springer nature格式&#xff0c; 期刊投稿要求&#xff0c;详细期刊投稿指南&#xff0c;大部分按Soringernature模板即可&#xff0c;图片表格声明参考文献命名要求需注意。 2、参考文献&#xff…

自动驾驶系列—探索自动驾驶数据管理的核心技术与平台

&#x1f31f;&#x1f31f; 欢迎来到我的技术小筑&#xff0c;一个专为技术探索者打造的交流空间。在这里&#xff0c;我们不仅分享代码的智慧&#xff0c;还探讨技术的深度与广度。无论您是资深开发者还是技术新手&#xff0c;这里都有一片属于您的天空。让我们在知识的海洋中…

Kotlin return与return@forEachIndexed

Kotlin return与returnforEachIndexed fun main() {val data arrayOf(0, 1, 2, 3, 4)println("a")data.forEachIndexed { index, v ->if (v 2) {//类似while循环中的continue//跳过&#xff0c;继续下一个forEachIndexed迭代returnforEachIndexed}println("…

【手写一个spring】spring源码的简单实现--初始化机制,回调机制

文章目录 A. 初始化机制实现初始化机制的方法1.实现InitializingBean接口 B. 回调机制(Aware)初始化机制和回调机制之间的区别 A. 初始化机制 在Spring框架中&#xff0c;初始化机制是Bean生命周期管理的一个重要组成部分。它确保了Bean在创建和依赖注入完成后&#xff0c;能够…