实战2. 利用Pytorch解决 CIFAR 数据集中的图像分类为 10 类的问题——提高精度

news/2025/3/19 11:57:08/

实战2. 利用Pytorch解决 CIFAR 数据集中的图像分类为 10 类的问题——提高精度

  • 前期准备
  • 加载数据
  • 建立模型
  • 模型训练
  • 质量指标

让我们回到图像分类问题 CIFAR。

你的主要任务:实现整个模型训练流程,并在测试样本上获得良好的准确度指标值。 任务积分:

  • 0,如果测试样本的准确度<0.5;
  • 如果测试样本的准确度 >0.5 且 <0.6,则为 0.5;
  • 1,如果测试样本的准确度>0.6;

本任务中用于训练模型的代码已完整实现。您需要做的就是为神经网络类编写代码并试验参数以获得良好的质量。除此之外,你要保证模型不超过300M的内存使用!

前期准备

python">import numpy as npimport torch
from torch import nn
from torch.nn import functional as Fimport torchvision
from torchvision import datasets, transformsfrom matplotlib import pyplot as plt
from IPython.display import clear_output
python">import numpy as np
np.random.seed(42)class LinearRegression:def init(self, **kwargs):self.coef_ = Nonepassdef fit(self, x: np.array, y: np.array):# 添加一列全为1的偏置项x = np.concatenate([np.ones((x.shape[0], 1)), x], axis=1)# 使用最小二乘法计算系数self.coef_ = np.linalg.inv(x.T @ x) @ x.T @ yreturn selfdef predict(self, x: np.array):# 添加一列全为1的偏置项x = np.concatenate([np.ones((x.shape[0], 1)), x], axis=1)# 根据计算出的系数进行预测y_pred = x @ self.coef_return y_pred

加载数据

数据加载代码与我们之前课程中的相同。没有必要改变任何东西。

python"># 从 torchvision 加载数据集
train_data = datasets.CIFAR10(root="./cifar10_data", train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.CIFAR10(root="./cifar10_data", train=False, download=True, transform=transforms.ToTensor())# 将训练部分分为train和val# 我们将把 80% 的图片纳入训练样本
train_size = int(len(train_data) * 0.8)
# 进行验证 - 剩余 20%
val_size = len(train_data) - train_sizetrain_data, val_data = torch.utils.data.random_split(train_data, [train_size, val_size])# 我们启动将生成批次的数据加载器
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

输出:
100%|██████████| 170M/170M [00:04<00:00, 40.6MB/s]

让我们看一下数据集中的一些图像:

python"># 绘制图像的函数
def show_images(images, labels):f, axes= plt.subplots(1, 10, figsize=(30,5))for i, axis in enumerate(axes):# 将图像从张量转换为numpyimg = images[i].numpy()# 将图像转换为尺寸(长度、宽度、颜色通道)img = np.transpose(img, (1, 2, 0))axes[i].imshow(img)axes[i].set_title(labels[i].numpy())plt.show()# 获取一批图像
for batch in train_loader:images, labels = batchbreakshow_images(images, labels)

在这里插入图片描述

建立模型

下面是用于构建模型的单元格。您不应该立即制作具有大量层的大型复杂模型:这样的网络将需要很长时间来训练,并且很可能会过度训练。

您的主要任务是训练模型并在延迟(测试样本)上获得至少 60% 准确度的质量。

注意:你的模型必须由“模型”变量表示。

您可以尝试以下方法来改善网络结果:

  • 尝试不同数量的卷积层和全连接层;
  • 在卷积层中尝试不同数量的过滤器;
  • 在隐藏的全连接层中尝试不同数量的神经元;
  • 尝试在完全连接层和卷积层之后添加 BatchNorm。请注意,nn.BatchNorm2d 用于卷积层。 num_features 参数等于卷积层的过滤器(out_channels)的数量;
  • 尝试添加/删除 max_pooling;
  • 改变学习率;
  • 对网络进行更多次的训练。

如果您的模型过度拟合(验证指标开始变得更糟),请尝试减少模型参数的数量。如果模型没有过拟合,但是效果不佳,请尝试增加模型参数的数量。

模板为:

python"># 此处输入您的代码
# 声明一个卷积神经网络类class ConvNet(nn.Module):def __init__(self):super().__init__()# 此处输入您的代码# 定义网络的层...def forward(self, x):# 维度 x ~ [64, 3, 32, 32]# 此处输入您的代码# 实现前向传递网络 ...

我们按照要求进行解决

python">class ConvNet(nn.Module):def __init__(self):super().__init__()# 第一个卷积层,输入通道 3(RGB 图像),输出通道 16,卷积核大小 3self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)# 第一个 BatchNorm 层,对应卷积层的输出通道数 16self.bn1 = nn.BatchNorm2d(16)# 最大池化层,池化核大小 2self.pool1 = nn.MaxPool2d(2)# 第二个卷积层,输入通道 16,输出通道 32,卷积核大小 3self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)# 第二个 BatchNorm 层,对应卷积层的输出通道数 32self.bn2 = nn.BatchNorm2d(32)# 最大池化层,池化核大小 2self.pool2 = nn.MaxPool2d(2)# 展平层self.flatten = nn.Flatten()# 第一个全连接层,输入维度根据前面卷积和池化层输出计算得到(8 * 8 * 32),输出维度 64self.fc1 = nn.Linear(8 * 8 * 32, 64)# 第一个全连接层的 BatchNorm 层self.bn_fc1 = nn.BatchNorm1d(64)# 第二个全连接层,输入维度 64,输出维度 10(CIFAR - 10 有 10 个类别)self.fc2 = nn.Linear(64, 10)def forward(self, x):# 第一个卷积层 -> BatchNorm -> ReLU 激活 -> 最大池化x = F.relu(self.bn1(self.conv1(x)))x = self.pool1(x)# 第二个卷积层 -> BatchNorm -> ReLU 激活 -> 最大池化x = F.relu(self.bn2(self.conv2(x)))x = self.pool2(x)# 展平x = self.flatten(x)# 第一个全连接层 -> BatchNorm -> ReLU 激活x = F.relu(self.bn_fc1(self.fc1(x)))# 第二个全连接层得到最终输出x = self.fc2(x)return x
python">model = ConvNet()

下面的单元检查 GPU 是否可用,如果有,则将神经网络传输到 GPU。

python">device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

模型训练

网络训练功能(无需更改)。

该函数每50次训练迭代输出一次训练样本上的损失和准确率的当前值。此外,每次迭代之后都会计算并显示验证样本的损失和准确度。这些值让你了解你的模型学习得如何。

python">def evaluate(model, dataloader, loss_fn):losses = []num_correct = 0num_elements = 0for i, batch in enumerate(dataloader):# 这就是我们获取当前批次的方法X_batch, y_batch = batchnum_elements += len(y_batch)with torch.no_grad():logits = model(X_batch.to(device))loss = loss_fn(logits, y_batch.to(device))losses.append(loss.item())y_pred = torch.argmax(logits, dim=1)num_correct += torch.sum(y_pred.cpu() == y_batch)accuracy = num_correct / num_elementsreturn accuracy.numpy(), np.mean(losses)def train(model, loss_fn, optimizer, n_epoch=3):# 网络训练周期for epoch in range(n_epoch):print("Epoch:", epoch+1)model.train(True)running_losses = []running_accuracies = []for i, batch in enumerate(train_loader):# 这就是我们获取当前批次的方法X_batch, y_batch = batch# 前向传递(获取对一批图像的响应)logits = model(X_batch.to(device))# 计算网络给出的答案和批次的正确答案的损失loss = loss_fn(logits, y_batch.to(device))running_losses.append(loss.item())loss.backward() # backpropagation (梯度计算)optimizer.step() # 更新网络权重optimizer.zero_grad() # 重置权重# 计算当前训练批次的准确率model_answers = torch.argmax(logits, dim=1)train_accuracy = torch.sum(y_batch == model_answers.cpu()) / len(y_batch)running_accuracies.append(train_accuracy)# 记录结果if (i+1) % 100 == 0:print("Average train loss and accuracy over the last 50 iterations:",np.mean(running_losses), np.mean(running_accuracies), end='\n')# 每个时期之后,我们都会得到验证样本的质量指标model.train(False)val_accuracy, val_loss = evaluate(model, val_loader, loss_fn=loss_fn)print("Epoch {}/{}: val loss and accuracy:".format(epoch+1, n_epoch,),val_loss, val_accuracy, end='\n')return model

我们正在开始训练

python"># 再次声明模型
model = ConvNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 选择损失函数
loss_fn = torch.nn.CrossEntropyLoss()# 选择优化算法和学习率。
# 你可以尝试不同的 learning_rate 值
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
python"># 让我们开始训练模型
# 参数 n_epoch 可以变化
model = train(model, loss_fn, optimizer, n_epoch=3)

输出:
在这里插入图片描述

质量指标

获取测试样本的质量指标

python">test_accuracy, _ = evaluate(model, test_loader, loss_fn)
print('Accuracy on the test', test_accuracy)

输出:
Accuracy on the test 0.6913

检查是否满足所需的阈值:

python">if test_accuracy <= 0.5:print("测试质量低于0.5,0分")
elif test_accuracy < 0.6:print("测试质量在 0.5 至 0.6 之间,得 0.5 分")
elif test_accuracy >= 0.6:print("测试质量高于 0.6,得 1 分")

输出:
测试质量高于 0.6,得 1 分

下面的单元格包含使用训练过的网络获取文件的代码,获取pth文件

python">model.eval()
x = torch.randn((1, 3, 32, 32))
torch.jit.save(torch.jit.trace(model.cpu(), (x)), "model.pth")

http://www.ppmy.cn/news/1580312.html

相关文章

idea 编译打包nacos2.0.3源码,生成可执行jar 包常见问题

目录 问题1 问题2 问题3 问题4 简单记录一下nacos2.0.3&#xff0c;编译打包的步骤&#xff0c;首先下载源码&#xff0c;免积分下载&#xff1a; nacos源码&#xff1a; https://download.csdn.net/download/fyihdg/90461118 protoc 安装包 https://download.csdn.net…

出租车数据可视化分析-大数据-实训大作业

第1章 项目绪论 1.1项目的总体说明背景及意义 在纽约&#xff0c;游客们往往把自由女神象、帝国大厦、中央公园等视为纽约的象征, 但穿梭在人海中的出租车也是纽约靓丽的人文景观之一, 是其流动的风景线, 在纽约公共文化中别具魅力。本项目利用之前从seaborn上下载的数据tax…

基于Asp.net的物流配送管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…

DeepSeek技术解析:MoE架构实现与代码实战

以下是一篇结合DeepSeek技术解析与代码示例的技术文章&#xff0c;重点展示其核心算法实现与落地应用&#xff1a; DeepSeek技术解析&#xff1a;MoE架构实现与代码实战 作为中国AI领域的创新代表&#xff0c;DeepSeek在混合专家模型&#xff08;Mixture of Experts, MoE&…

Chat2DB:自然语言生成 SQL 的时代来临,数据库管理更简单

作者&#xff1a;后端小肥肠 目录 1. 前言 2. 数据库管理工具对比 3. Chat2DB安装及实际测评 3.1. Chat2DB安装 3.2. AI功能测评 3.2.1. 自然语言创建表 3.2.2. 自然语言查询 4. 结语 1. 前言 提到数据库管理工具&#xff0c;Navicat 曾经是大家的首选&#xff0c;但随…

【C语言】:学生管理系统(多文件版)

一、文件框架 二、Data data.txt 三、Inc 1. list.h 学生结构体 #ifndef __LIST_H__ #define __LIST_H__#include <stdio.h> #include <stdlib.h> #include <string.h> #include <stdbool.h> #include <time.h>#define MAX_LEN 20// 学生信息…

Nuxt2 vue 给特定的页面 body 设置 background 不影响其他页面

首先认识一下 BODY_ATTRS 他可以在页面单独设置 head () {return {bodyAttrs: {form: form-body}};},设置完效果是只有这个页面会加上 接下来在APP.vue中添加样式

顺序表和链表的对比(一)

前言 今天给小伙伴们分享的是在数据结构中顺序表和链表的对比。它们在计算机科学和软件开发中具有广泛的应用&#xff0c;是理解更复杂数据结构&#xff08;如栈、队列、树、图等&#xff09;的基础。这次将会给大家从定义初始化&#xff0c;以及功能增删查改上进行详细对比&a…