利用 Resnet50 重新训练,完成宠物数据集的识别,附源代码。。

devtools/2025/3/31 22:42:50/

如果你对深度学习有所了解,知道神经网络可以识别图片,但还没自己动手训练过模型,这篇文章会非常适合你。

这篇文章将使用 PyTorch 和 ResNet50,基于 Oxford-IIIT Pet 数据集(37 类宠物)完成一个完整的训练过程。

这个方法也可以应用到你自己的数据集上,比如识别不同种类的花或物体。

接下来,带你一步步完成这个任务。

Attention:全网最全的 AI 小白到 AI 大神的天梯成长学习路线,几十万原创专栏和硬核视频,点击这里查看:AI小白到AI大神的天梯之路

什么是 ResNet50,为什么选择它?

ResNet50 是一个深度卷积神经网络,包含 50 层,设计用来处理图像分类任务。

它在 ImageNet 数据集上表现优异,能识别 1000 种物体。

我们今天的目标是重新训练它,让它学会识别新的类别——37 种宠物

选择 ResNet50 的理由很简单——

  • 成熟的结构,它已经被广泛验证,适合大多数图像分类任务。
  • 开箱即用:PyTorch 提供了现成的实现,省去自己设计的麻烦。
  • 高效性:即使从零开始训练,也能得到不错的结果。

下面,我们将训练过程拆成几个关键步骤,逐步讲解。

训练 ResNet50 的四大步骤

步骤 1:准备数据

模型训练的第一步是准备数据。

Oxford-IIIT Pet 数据集包含大量宠物照片,我们需要调整它们的格式,确保模型能正确处理。

代码是这样实现的:

transform = transforms.Compose([transforms.Resize((224, 224)),  # 将图像调整为 224x224 像素transforms.ToTensor(),          # 将图像转换为 Tensor 格式transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 标准化像素值])   
  • ResizeResNet50 的输入要求是 224x224,所有图像需要统一到这个尺寸。
  • ToTensor将图片从普通格式转为模型能处理的数字格式(范围 0 到 1)。
  • Normalize用 ImageNet 的均值和标准差标准化数据,帮助模型更快收敛。

接着,用 DataLoader 将数据分成小批次:

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

这里 batch_size=32 表示每次处理 32 张图片,shuffle=True 打乱顺序,避免模型记住数据的排列。

步骤 2:搭建模型——调整 ResNet50 的结构

ResNet50 是一个现成的模型,但我们需要根据任务调整它。

原始的 ResNet50 输出 1000 类,而我们的数据集只有 37 类,因此需要修改最后一层。

代码实现如下:

model = torchvision.models.resnet50(weights=None)  # 初始化 ResNet50,不使用预训练权重
model.fc = nn.Linear(model.fc.in_features, 37)     # 将全连接层改为 37 类输出
model = model.to(device)      # 转移到 GPU 或 CPU
  • weights=None表示将从零开始训练模型。
  • model.fc这一行代码修改了模型最后一层(全连接层),将输出特征数改为 37 个,对应 37 类宠物。如果你有自己的数据集,且分类数量与原始模型不一致,也需要进行类似的修改。
  • to(device)根据设备(GPU 或 CPU)运行模型,GPU 会显著加速训练。

步骤 3:定义学习方式

模型需要知道如何学习以及学习步长是什么样的,这样才能优化模型参数的调整过程。

这个过程主要涉及损失函数和优化器。

损失函数衡量的是模型预测值与真实答案之间的差距,优化器则负责调整模型的参数。

用代码中是这样定义的:

criterion = nn.CrossEntropyLoss()          # 交叉熵损失,用于分类任务
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器,学习率 0.001
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # 学习率调度器
  • 损失函数采用的是交叉熵损失函数,该函数是多分类任务的标准选择。
  • 优化器Adam 是一种高效的优化算法,lr=0.001 是初始学习率。
  • 调度器每 5 个 epoch,学习率乘以 0.1,逐步降低以稳定训练。

步骤 4:训练与测试——让模型学习和验证

训练其实就是让模型反复调整自己参数的过程,验证则是检查训练的效果。

训练和验证的逻辑分别在两个函数中实现。

训练函数:

def train(epoch):model.train()  # 进入训练模式for inputs, targets in train_loader:inputs, targets = inputs.to(device), targets.to(device)optimizer.zero_grad()  # 清零梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, targets)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数
  • train()激活模型的训练模式(启用 dropout/BN 层的全局统计功能)。
  • 流程模型预测 -> 计算损失 -> 调整参数。

测试函数:

def test(epoch):model.eval()  # 进入测试模式with torch.no_grad():  # 关闭梯度计算for inputs, targets in test_loader:outputs = model(inputs)# 计算准确率...
  • eval()切换到测试模式,关闭训练时的随机性(Dropout, BN 不再进行全局统计)。
  • no_grad()节省内存,提高测试效率。

主循环运行 20 个 epoch,每次训练后测试,并保存最佳模型:

for epoch in range(1, 21):train(epoch)test_acc = test(epoch)if test_acc > best_acc:best_acc = test_acctorch.save(model.state_dict(), "best_pet_model.pth")

训练效果

运行完整的代码后,你会看到类似这样的结果:

Epoch 1 | Train Acc: 50.23%Epoch 1 | Test Acc: 52.10%...Best Test Accuracy: 85.67%

这表示模型在测试集上的最高准确率达到 85.67%。

如果效果不理想,可以尝试下面的改进方法。

改进建议

使用预训练权重

weights=None 改为 weights='DEFAULT',利用 ImageNet 的经验加速训练。

数据增强

transform 中加入 transforms.RandomHorizontalFlip(),增加数据多样性。

调整参数

尝试不同的学习率(如 0.0001)或 batch_size(如 64),找到最佳组合。

通过以上的四个步骤——准备数据、搭建模型、设定规则、训练测试,你就可以用 ResNet50 训练自己的数据集了。

这个过程并不复杂,只要理解每个部分的逻辑,就能灵活应用到其他任务上。

如果你有自己的数据集,不妨试一试。

宠物训练的完整代码见这里:https://github.com/dongdongcan/ai_model_samples/tree/main/resnet50_train_oxford_iiit_pet

备注,本文的完整代码最好在 GPU 环境下运行。


http://www.ppmy.cn/devtools/169245.html

相关文章

【SpringBoot】MorningBox小程序的完整后端接口文档

以下是「晨光宅配」小程序的完整接口文档,涵盖了所有12个表的接口。 每个接口包括请求方法、URL、请求参数、响应格式和示例 接口文档 1. 用户模块 1.1 获取用户信息 URL: /user/{userId}方法: GET请求参数: userId (路径参数): 用户ID响应格式:{"userId": 1,&qu…

Android 接 Twitter Share ,常见问题及解决方案

1. 应用未授权或授权失败 问题描述:当尝试分享内容到 Twitter 时,应用提示未授权,或者在授权过程中出现错误,无法获取授权码或访问令牌。解决方案 检查 Twitter API 密钥和密钥密码:确保在 Twitter 开发者平台创建应用后,获取的 API 密钥(Consumer Key)和 API 密钥密码…

FRP在远程办公中的实战应用

远程办公场景中,FRP可穿透企业防火墙,安全访问内网资源。以下是典型用例: SSH远程连接 配置示例: 客户端配置SSH映射,将本地22端口映射至公网服务器的6000端口,用户通过ssh -p 6000 user公网IP即可连接内网…

如何借助es的snapshot跨集群迁移部分索引

1.创建源集群的快照仓库 使用fs方式,首先需要在所有节点挂载文件系统 然后在elasticsearch.yaml中新增配置path.repo 必须确保对应目录具备读写权限 path.repo: /mount/backups 修改配置重启完之后,开始创建快照仓库 PUT /_snapshot/my_repository…

Java面试高频问题深度解析:JVM、锁机制、SQL优化与并发处理

问题列表 Java中如何实现一个工作流引擎?Bean的作用域有哪些?JVM中的锁机制是如何工作的?三个方法分别被 synchronized 锁住,方法 a 调用方法 b,b 能获取到 a 的锁吗?会有什么问题?SQL优化时,EXPLAIN 中需要关注哪些关键点?什么是覆盖索引?SELECT * 一定不会命中索引…

rust Send Sync 以及对象安全和对象不安全

开头:菜鸟小明的疑惑 小明: “李哥,我最近学 Rust,感觉它超级严谨,啥 Send、Sync、对象安全、静态分发、动态分发的,我都搞晕了!为啥 Rust 要设计得这么复杂啊?” 小李(…

Mac使用pycharm+基于Kaggle的社交媒体情绪分析数据集,用python做词云的可视化

pycharm版本 一开始用的专业版,但是太久没有写代码就账户过期了,找半天Activation Code也没有找到,重新下载一个社区版,我点进去是社区版的页面,但是下载结果是专业版,后面仔细看,mad社区版在下…

【HDLBits】Procedures合集

Always blocks 基础定义 由于数字电路是由用导线连接的逻辑门组成的,因此任何电路都可以表示为模块和赋值语句的某种组合。然而,有时这不是最方便的方式来描述电路。过程(以always块为例)为描述电路提供了另一种语法。 对于综合…