使用 PyTorch 实现并训练 VGGNet 用于 MNIST 分类

embedded/2024/11/23 21:50:43/

        本文将展示如何使用 PyTorch 实现一个经典的 VGGNet 网络,并在 MNIST 数据集上进行训练和测试。我们将从模型构建开始,涵盖数据预处理、模型训练、评估、保存与加载模型,以及可视化预测结果等全过程。


1. VGGNet 模型的实现

        首先,我们实现一个标准的 VGGNet 网络。VGGNet 是一个深度卷积神经网络,它由多个卷积层和全连接层组成,广泛应用于图像分类任务。

VGGNet 模型结构:
  • 卷积层:VGGNet 采用了简单的结构,使用多个卷积层,每层卷积后跟一个 ReLU 激活函数和一个 最大池化 层。
  • 全连接层:经过卷积层提取特征后,VGGNet 会将特征图展平,并通过全连接层进行分类
import torch.nn as nnclass VGG(nn.Module):def __init__(self, num_classes=10, input_channels=1):"""VGG 网络的初始化方法,包含卷积层和全连接层。参数:- num_classes (int): 分类的类别数量,默认 10 (适用于 MNIST)- input_channels (int): 输入图片的通道数,默认 1 (适用于灰度图像)"""super(VGG, self).__init__()# 构建卷积层部分self.features = self._make_layers(input_channels)# 构建分类器部分self.classifier = self._make_classifier(num_classes)def _make_layers(self, input_channels):"""构建卷积层部分,通过堆叠卷积层、ReLU 激活和池化层来构建特征提取部分参数:- input_channels (int): 输入图像的通道数,默认为 1(灰度图)返回:- features (nn.Sequential): 包含卷积层和池化层的神经网络模块"""layers = []# 卷积块 1layers += self._conv_block(input_channels, 64)# 卷积块 2layers += self._conv_block(64, 128)# 卷积块 3layers += self._conv_block(128, 256)# 卷积块 4layers += self._conv_block(256, 512)# 将所有卷积块和池化层堆叠在一起return nn.Sequential(*layers)def _conv_block(self, in_channels, out_channels):"""创建一个卷积块,包含两个卷积层和一个最大池化层参数:- in_channels (int): 输入通道数- out_channels (int): 输出通道数返回:- block (list): 卷积块 [卷积层 + ReLU + 卷积层 + ReLU + 最大池化层]"""block = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2)]return blockdef _make_classifier(self, num_classes):"""构建全连接层部分,最后的输出层为分类层。参数:- num_classes (int): 分类类别数返回:- classifier (nn.Sequential): 包含全连接层和 Dropout 层的网络模块"""return nn.Sequential(nn.Linear(512 * 1 * 1, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, num_classes))def forward(self, x):"""前向传播方法,输入图像通过卷积层提取特征后再通过全连接层进行分类。参数:- x (Tensor): 输入的图像数据返回:- x (Tensor): 分类结果"""# 通过卷积层提取特征x = self.features(x)# 将特征图展平为一维向量x = x.view(x.size(0), -1)  # 这里将 4D 张量转换为 2D,保留 batch_size# 通过分类器进行最终分类x = self.classifier(x)return x

2. 训练模型

        使用 PyTorch 实现的 VGGNet 网络后,我们需要对模型进行训练。在这个过程中,我们会使用 AdamW 优化器、交叉熵损失 以及 混合精度训练 来提升训练效率。

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocastdef get_data_loader(batch_size=64, num_workers=2):""" 获取 MNIST 数据加载器 """transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='D:/workspace/data', train=True, download=True, transform=transform)return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)def initialize_model(device, num_classes=10):""" 初始化模型、优化器和损失函数 """model = VGG(num_classes=num_classes).to(device)optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)criterion = torch.nn.CrossEntropyLoss()return model, optimizer, criteriondef train_epoch(model, train_loader, device, criterion, optimizer, scaler):""" 训练一个 epoch,并返回该 epoch 的平均损失和准确率 """model.train()running_loss = 0.0correct = 0total = 0with tqdm(train_loader, desc="Training", unit="batch", ncols=100) as pbar:for data, target in pbar:data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)optimizer.zero_grad()# 混合精度训练with autocast():output = model(data)loss = criterion(output, target)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()running_loss += loss.item()_, predicted = torch.max(output, 1)total += target.size(0)correct += (predicted == target).sum().item()# 更新进度条pbar.set_postfix(loss=running_loss / (total // len(data)), accuracy=100 * correct / total)return running_loss / len(train_loader), 100 * correct / total


3. 保存与加载模型

        在训练完成后,我们将保存模型,并在后续的测试过程中加载模型以进行评估。

def save_model(model, filepath='vggnet_mnist.pth'):""" 保存训练的模型到指定文件(覆盖之前的文件) """torch.save(model.state_dict(), filepath)print(f"Model saved to {filepath}")def load_model(model_path='vggnet_mnist.pth', num_classes=10):""" 加载预训练模型 """model = VGG(num_classes=num_classes)model.load_state_dict(torch.load(model_path))return model


4. 评估模型与可视化结果

        我们可以加载训练好的模型并对其在测试集上的表现进行评估。我们还可以通过 matplotlib 可视化前六张测试图像的预测结果。

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transformsdef get_test_loader(batch_size=64, data_dir='D:/workspace/data'):""" 获取 MNIST 测试数据加载器 """transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])test_dataset = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)def evaluate_model(model, test_loader, device):""" 评估模型并返回准确率和前六张图片的预测与标签 """model.eval()correct = 0total = 0images, labels, preds = [], [], []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = torch.max(output, 1)total += target.size(0)correct += (predicted == target).sum().item()# 记录前六张图片及其标签和预测if len(images) < 6:batch_size = data.size(0)for i in range(min(6 - len(images), batch_size)):images.append(data[i].cpu())labels.append(target[i].cpu())preds.append(predicted[i].cpu())accuracy = 100 * correct / totalreturn accuracy, images, labels, predsdef display_images(images, labels, preds):""" 可视化前六张图片及其真实标签和预测标签 """fig, axes = plt.subplots(2, 3, figsize=(10, 6))axes = axes.ravel()for i in range(6):axes[i].imshow(images[i][0].squeeze(), cmap='gray')  # MNIST 是单通道灰度图像axes[i].set_title(f"True: {labels[i].item()}, Pred: {preds[i].item()}")axes[i].axis('off')  # 不显示坐标轴plt.show()


5. 总结

        通过以上步骤,我们成功实现并训练了一个 VGGNet 网络,并在 MNIST 数据集上进行了测试与评估。我们使用了混合精度训练来加速训练过程,并通过可视化展示了模型的预测效果。

        这种方法可以推广到其他数据集和任务中,例如 CIFAR-10、CIFAR-100 或其他图像分类问题。

完整项目:

qxd-ljy/VGGNet-PyTorch: 使用PyTorch实现VGGNet进行MINST图像分类icon-default.png?t=O83Ahttps://github.com/qxd-ljy/VGGNet-PyTorchVGGNet-PyTorch: 使用PyTorch实现VGGNet进行MINST图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/vggnet-py-torch


http://www.ppmy.cn/embedded/139951.html

相关文章

Docker nginx容器高可用(Keepalived)

概述 Keepalived主要作用&#xff1a;在多个服务器上安装Keepalived并且为各个服务器的Keepalived指定相同的虚拟IP。该虚拟IP根据服务器上Keepalived配置的角色、优先级 决定出现在其中一台服务器上&#xff0c;当拥有虚拟IP的服务器Keepalived进程被杀死后&#xff0c;那么此…

「Mac玩转仓颉内测版27」基础篇7 - 字符串类型详解

本篇将介绍 Cangjie 中的字符串类型&#xff0c;包括字符串的定义、字面量形式、插值表达、常用操作及应用场景&#xff0c;帮助开发者熟练掌握字符串的使用。 关键词 字符串类型定义字符串字面量插值字符串字符串拼接常用操作 一、字符串类型概述 在 Cangjie 中&#xff0c;…

长文解读:OSAID 1.0,全球首个开源AI标准,审视探讨其对AI行业实践开源的影响

引言 在人工智能&#xff08;AI&#xff09;的快速发展中&#xff0c;开源已经成为推动技术创新和知识共享的重要力量。随着AI技术的广泛应用&#xff0c;确保其开放性、透明性和可访问性变得至关重要。在这样的背景下&#xff0c;OSAID 1.0&#xff08;Open Source AI Defini…

深度学习:GPT-1的MindSpore实践

GPT-1简介 GPT-1&#xff08;Generative Pre-trained Transformer&#xff09;是2018年由Open AI提出的一个结合预训练和微调的用于解决文本理解和文本生成任务的模型。它的基础是Transformer架构&#xff0c;具有如下创新点&#xff1a; NLP领域的迁移学习&#xff1a;通过最…

RAG与微调:大模型落地的最佳路径选择(文末赠书)

一、大模型技术发展现状 自2022年底ChatGPT掀起AI革命以来&#xff0c;大语言模型&#xff08;LLM&#xff09;技术快速迭代发展&#xff0c;从GPT-4到Claude 2&#xff0c;从文心一言到通义千问&#xff0c;大模型技术以惊人的速度发展。然而&#xff0c;在企业实际应用场景中…

圣诞节秘诀

&#x1f570;️你想在2024年圣诞节脱颖而出吗&#xff1f;利用我们的数据洞察&#xff0c;发现今年最受欢迎的礼物&#xff01;无论是在亚马逊、速卖通、Shopify还是直销平台上&#xff0c;我们的排行榜都将帮助您找到最畅销和最受欢迎的产品。立即优化您的库存&#xff0c;以…

Nexus搭建go私有仓库,加速下载go依赖包

一、搭建go私库 本文我们梳理一下go依赖包的私库搭建以及使用。 它只分为proxy和group两种仓库&#xff0c;这一点和maven仓库有所不同。 1、创建Blob Stores 为了区分不同的私库依赖包&#xff0c;存储的位置分隔开。 2、新建go proxy官网 Remote storage&#xff1a;htt…

TM1可视化解决方案:企业增效降本的智控大脑

您是否还费时费力整合从各部门收集不同来源的数据资料&#xff0c;或是分析财务数据时在Excel和各可视化软件之间来回切换&#xff1f; 让我们看看咨询顾问小C (Cubewiser) 如何使用 TM1 系统的展示平台—— Apliqo UX 对企业运营成本及费用进行智能管控。 预实分析&#xff…