使用pytorch构建ResNet50模型训练猫狗数据集

news/2024/10/21 19:04:53/

数据集

1.导包

python">import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条

2.数据预处理

ResNet50模型适合的图片大小为224x244

python"># 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

3.加载数据集和模型构建

python"># 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

4.训练

python"># 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')

训练过程

5.预测

python">import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt# 定义模型的类别数量
num_classes = 2# 加载模型
model = torchvision.models.resnet50(pretrained=False)
# 修改模型的fc层以匹配训练时的结构
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# 加载保存的权重
model.load_state_dict(torch.load('mg_ResNet50model.pth'))
model.eval()# 图像预处理
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 测试图片
img_path = 'mao_1.jpg'  # 替换为你的图片路径
img = Image.open(img_path)
img_t = preprocess(img)# 扩展维度,因为模型需要4维输入(Batch, Channels, Height, Width)
batch_t = torch.unsqueeze(img_t, 0)# 预测
with torch.no_grad():out = model(batch_t)# 获取最高分数的类别
_, index = torch.max(out, 1)# 可视化结果
plt.imshow(img)
plt.title(f'Predicted: {index.item()}')
plt.show()

预测效果

0就是猫咪,1就是小狗

全部代码

python">import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条# 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')


http://www.ppmy.cn/news/1464869.html

相关文章

前端基础入门三大核心之HTML篇:Webpack的八种常用Loader用途及加载方式全解析

前端基础入门三大核心之HTML篇:Webpack的八种常用Loader用途及加载方式全解析 一、Loader基础与配置加载顺序与链式调用 二、八大Loader详解1. **css-loader & style-loader**2. **less-loader & sass-loader**3. **babel-loader**4. **file-loader & u…

MySQL(进阶)--索引

目录 一.存储引擎 1.MySQL体系结构​编辑 2.存储引擎简介 3.存储引擎特点 (1.InnoDB (2.MyISAM (3.Memory 4.存储引擎选择 二.索引 1.索引概述 2.索引结构 3.索引分类 4.索引语法 (1.创建索引 (2.查看索引 (3.删除索引 5.SQL性能分析 (1.SQL执行频率 (2.慢查…

代码随想录二刷 Day05 | 242.有效的字母异位词,349. 两个数组的交集,202. 快乐数,1. 两数之和,454.四数相加II,383. 赎金信

题目与题解 参考资料:哈希表理论基础 Tips: 一般哈希表都是用来快速判断一个元素是否出现集合里哈希表生成原理:先通过哈希函数将变量映射为hashcode,如果二者hashcode相同,再通过哈希碰撞方法(拉链法&…

数据挖掘与机器学习——回归分析

目录 回归分析定义: 案例: 线性回归 预备知识: 定义: 一元线性回归: 如何找出最佳的一元线性回归模型: 案例: python实现: 多元线性回归 案例: 线性回归的优缺点…

网络故障与排除(一)

一、Router-ID冲突导致OSPF路由环路 路由器收到相同Router-ID的两台设备发送的LSA,所以查看路由表看到的OSPF缺省路由信息就会不断变动。而当C1的缺省路由从C2中学到,C2的缺省路由又从C1中学到时,就形成了路由环路,因此出现路由不…

vue3学习(四)

前言 接上篇学习笔记&#xff0c;分享3个内置组件&#xff1a;动态组件、缓存组件、分发组件基本用法。大家一起通过code的示例&#xff0c;从现象理解,注意再次理解生命周期。 一、code示例 组件A&#xff1a;CompA <script setup> import {onMounted, onUnmounted} f…

NXP i.MX8系列平台开发讲解 - 3.12 Linux 之Audio子系统(一)

专栏文章目录传送门&#xff1a;返回专栏目录 目录 1. Audio 基础介绍 1.1 音频信号 1.2 音频的处理过程 1.3 音频硬件接口 1.3 音频专业术语解释 2. Linux Audio子系统介绍 3. Linux Audio子系统框架 Linux嵌入式系统中的音频子系统扮演着至关重要的角色&#xff0c;它涉…

day21二叉树part07|530.二叉搜索树的最小绝对差 501.二叉搜索树中的众数 236. 二叉树的最近公共祖先

**530.二叉搜索树的最小绝对差 ** 遇到在二叉搜索树上求什么最值&#xff0c;求差值之类的&#xff0c;都要思考一下二叉搜索树可是有序的&#xff0c;要利用好这一特点。 class Solution { public:void trival(TreeNode* node, vector<int>& nums) {if (node nul…