利用 Resnet50 重新训练,完成宠物数据集的识别,附源代码。。

news/2025/3/20 14:47:27/

如果你对深度学习有所了解,知道神经网络可以识别图片,但还没自己动手训练过模型,这篇文章会非常适合你。

这篇文章将使用 PyTorch 和 ResNet50,基于 Oxford-IIIT Pet 数据集(37 类宠物)完成一个完整的训练过程。

这个方法也可以应用到你自己的数据集上,比如识别不同种类的花或物体。

接下来,带你一步步完成这个任务。

Attention:全网最全的 AI 小白到 AI 大神的天梯成长学习路线,几十万原创专栏和硬核视频,点击这里查看:AI小白到AI大神的天梯之路

什么是 ResNet50,为什么选择它?

ResNet50 是一个深度卷积神经网络,包含 50 层,设计用来处理图像分类任务。

它在 ImageNet 数据集上表现优异,能识别 1000 种物体。

我们今天的目标是重新训练它,让它学会识别新的类别——37 种宠物

选择 ResNet50 的理由很简单——

  • 成熟的结构,它已经被广泛验证,适合大多数图像分类任务。
  • 开箱即用:PyTorch 提供了现成的实现,省去自己设计的麻烦。
  • 高效性:即使从零开始训练,也能得到不错的结果。

下面,我们将训练过程拆成几个关键步骤,逐步讲解。

训练 ResNet50 的四大步骤

步骤 1:准备数据

模型训练的第一步是准备数据。

Oxford-IIIT Pet 数据集包含大量宠物照片,我们需要调整它们的格式,确保模型能正确处理。

代码是这样实现的:

transform = transforms.Compose([transforms.Resize((224, 224)),  # 将图像调整为 224x224 像素transforms.ToTensor(),          # 将图像转换为 Tensor 格式transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 标准化像素值])   
  • ResizeResNet50 的输入要求是 224x224,所有图像需要统一到这个尺寸。
  • ToTensor将图片从普通格式转为模型能处理的数字格式(范围 0 到 1)。
  • Normalize用 ImageNet 的均值和标准差标准化数据,帮助模型更快收敛。

接着,用 DataLoader 将数据分成小批次:

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

这里 batch_size=32 表示每次处理 32 张图片,shuffle=True 打乱顺序,避免模型记住数据的排列。

步骤 2:搭建模型——调整 ResNet50 的结构

ResNet50 是一个现成的模型,但我们需要根据任务调整它。

原始的 ResNet50 输出 1000 类,而我们的数据集只有 37 类,因此需要修改最后一层。

代码实现如下:

model = torchvision.models.resnet50(weights=None)  # 初始化 ResNet50,不使用预训练权重
model.fc = nn.Linear(model.fc.in_features, 37)     # 将全连接层改为 37 类输出
model = model.to(device)      # 转移到 GPU 或 CPU
  • weights=None表示将从零开始训练模型。
  • model.fc这一行代码修改了模型最后一层(全连接层),将输出特征数改为 37 个,对应 37 类宠物。如果你有自己的数据集,且分类数量与原始模型不一致,也需要进行类似的修改。
  • to(device)根据设备(GPU 或 CPU)运行模型,GPU 会显著加速训练。

步骤 3:定义学习方式

模型需要知道如何学习以及学习步长是什么样的,这样才能优化模型参数的调整过程。

这个过程主要涉及损失函数和优化器。

损失函数衡量的是模型预测值与真实答案之间的差距,优化器则负责调整模型的参数。

用代码中是这样定义的:

criterion = nn.CrossEntropyLoss()          # 交叉熵损失,用于分类任务
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器,学习率 0.001
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # 学习率调度器
  • 损失函数采用的是交叉熵损失函数,该函数是多分类任务的标准选择。
  • 优化器Adam 是一种高效的优化算法,lr=0.001 是初始学习率。
  • 调度器每 5 个 epoch,学习率乘以 0.1,逐步降低以稳定训练。

步骤 4:训练与测试——让模型学习和验证

训练其实就是让模型反复调整自己参数的过程,验证则是检查训练的效果。

训练和验证的逻辑分别在两个函数中实现。

训练函数:

def train(epoch):model.train()  # 进入训练模式for inputs, targets in train_loader:inputs, targets = inputs.to(device), targets.to(device)optimizer.zero_grad()  # 清零梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, targets)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数
  • train()激活模型的训练模式(启用 dropout/BN 层的全局统计功能)。
  • 流程模型预测 -> 计算损失 -> 调整参数。

测试函数:

def test(epoch):model.eval()  # 进入测试模式with torch.no_grad():  # 关闭梯度计算for inputs, targets in test_loader:outputs = model(inputs)# 计算准确率...
  • eval()切换到测试模式,关闭训练时的随机性(Dropout, BN 不再进行全局统计)。
  • no_grad()节省内存,提高测试效率。

主循环运行 20 个 epoch,每次训练后测试,并保存最佳模型:

for epoch in range(1, 21):train(epoch)test_acc = test(epoch)if test_acc > best_acc:best_acc = test_acctorch.save(model.state_dict(), "best_pet_model.pth")

训练效果

运行完整的代码后,你会看到类似这样的结果:

Epoch 1 | Train Acc: 50.23%Epoch 1 | Test Acc: 52.10%...Best Test Accuracy: 85.67%

这表示模型在测试集上的最高准确率达到 85.67%。

如果效果不理想,可以尝试下面的改进方法。

改进建议

使用预训练权重

weights=None 改为 weights='DEFAULT',利用 ImageNet 的经验加速训练。

数据增强

transform 中加入 transforms.RandomHorizontalFlip(),增加数据多样性。

调整参数

尝试不同的学习率(如 0.0001)或 batch_size(如 64),找到最佳组合。

通过以上的四个步骤——准备数据、搭建模型、设定规则、训练测试,你就可以用 ResNet50 训练自己的数据集了。

这个过程并不复杂,只要理解每个部分的逻辑,就能灵活应用到其他任务上。

如果你有自己的数据集,不妨试一试。

宠物训练的完整代码见这里:https://github.com/dongdongcan/ai_model_samples/tree/main/resnet50_train_oxford_iiit_pet

备注,本文的完整代码最好在 GPU 环境下运行。


http://www.ppmy.cn/news/1580621.html

相关文章

嵌入式八股ARM篇

前言 ARM篇主要介绍一下寄存器和中断机制,至于汇编这一块…还请大家感兴趣自行学习 1.寄存器 R0 - R3 R4 - R11 寄存器 R0 - R3一般用作函数传参 R4 - R11用来保存程序运算的中间结果或函数的局部变量 在函数调用过程中 注意在发生异常的时候 cortex-M0架构会自动将R0-R3压入…

【Python机器学习】3.2. 决策树理论(进阶):ID3算法、信息熵原理、信息增益

喜欢的话别忘了点赞、收藏加关注哦(关注即可查看全文),对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 本文承接 3.1. 决策树理论(基础),没看过的建议先看前文。 3.2.1. ID3算法数学原理…

云钥科技工业相机定制服务,助力企业实现智能智造

在工业自动化、智能制造和机器视觉快速发展的今天,工业相机作为核心感知设备,其性能直接决定了检测精度、生产效率和产品质量。然而,标准化工业相机往往难以满足复杂多样的应用场景需求,‌工业相机定制‌逐渐成为企业突破技术瓶颈…

面试总结之 Glide自定义的三级缓存策略

一、为什么需要三级缓存? 在移动应用开发中,图片加载性能直接影响用户体验。根据 Google 统计,图片加载延迟超过 1 秒会导致 32% 的用户流失。传统图片加载方案存在以下痛点: 内存占用高:未压缩的大图直接占用大量内…

WindowsAD域服务权限提升漏洞

WindowsAD 域服务权限提升漏洞(CVE-2021-42287, CVE-2021-42278) 1.漏洞描述 Windows域服务权限提升漏洞(CVE-2021-42287, CVE-2021-42278)是由于Active Directory 域服务没有进行适当的安全限制,导致可绕过安…

09 python函数(上)

一、函数的介绍 什么是函数? 函数的诞生为了解决两个问题:可读性、重复性。使用函数可以将一些代码放在一起成为一个功能,方便调用,出现了函数也方便用户阅读代码。 函数是组织好的,可重复使用的,用来实现…

《C#上位机开发从门外到门内》3-5:基于FastAPI的Web上位机系统

文章目录 一、项目概述二、系统架构设计三、前后端开发四、数据可视化五、远程控制六、系统安全性与稳定性七、性能优化与测试八、实际应用案例九、结论 随着互联网技术的快速发展,Web上位机系统在工业自动化、智能家居、环境监测等领域的应用日益广泛。基于FastAPI…

腾讯云MySQL数据库架构分析与使用场景

TDSQL-C for MySQL TDSQL-C MySQL 版(TDSQL-C for MySQL)是腾讯云自研的新一代云原生关系型数据库。融合了传统数据库、云计算与新硬件技术的优势,为用户提供具备高弹性、高性能、海量存储、安全可靠的数据库服务。TDSQL-C MySQL 版100%兼容…