pytorch训练和使用resnet

server/2024/10/20 0:53:48/

pytorchresnet_0">pytorch训练和使用resnet

使用 CIFAR-10数据集

训练 resnet

resnet-train.py

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim# 在CIFAR-10数据集中
# 训练集:包含50000张图像,用于训练模型。
# 测试集:包含10000张图像,用于评估模型的性能。
TRAIN_SIZE=50000
TEST_SIZE=10000# 批量大小
BATCH_SIZE=128# 数据预处理
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 使用预训练的ResNet模型 , 不从默认url下载预训练的模型
model = torchvision.models.resnet18(weights=None)
# 从当前路径加载预训练权重
model_path = './model/resnet18-f37072fd.pth'
model.load_state_dict(torch.load(model_path))# 修改最后一层以适应CIFAR-10的10个类别
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)# 将模型移到GPU(如果有)
if torch.cuda.is_available() :print('Using GPU')device = torch.device("cuda:0")
else :print('Using CPU')device = torch.device("cpu")   model = model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)# 学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)# 训练网络
num_epochs = 50print('start Training')for epoch in range(num_epochs):model.train()running_loss = 0.0#总迭代次数 = 训练集大小 / 批量大小 =  向上取整(TRAIN_SIZE=50000 / BATCH_SIZE=128) = 391 次循环for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 梯度清零optimizer.zero_grad()# 前向传播 + 向后传播 + 优化outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印统计信息running_loss += loss.item()if i % 100 == 99:    # 每100个小批量打印一次print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')running_loss = 0.0# 更新学习率scheduler.step()print('Finished Training')# 测试网络
model.eval()
correct = 0
total = 0
with torch.no_grad():# 总迭代次数 = 测试集 / 批量大小 向上取整(TEST_SIZE=10000/BATCH_SIZE=128) = 79 次循环for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy_test = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy_test:.2f}%')# [Epoch 50, Batch 300] loss: 0.142
# Finished Training
# Accuracy of the network on the 10000 test images: 84.53%# 准确率>0.8保存模型
if(accuracy_test > 0.8):print("Accuracy  > 0.8 ,save model")model_path = './model/trained_resnet18_cifar10.pth'torch.save(model.state_dict(), model_path)print(f'Model saved to {model_path}')

使用训练后的 resnet

评估数据
1.jpeg :

请添加图片描述

2.jpeg:

请添加图片描述

restnet-eval.py

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image# 模型路径
model_path = './model/trained_resnet18_cifar10.pth'# 类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 数据预处理
transform = transforms.Compose([transforms.Resize((32, 32)),  # 调整图像大小为32x32transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])# 加载预训练的ResNet模型
model = torchvision.models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model.load_state_dict(torch.load(model_path))
model.eval()  # 设置模型为评估模式# 将模型移到GPU(如果有)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)def predict_image(image_path):# 加载并预处理图像image = Image.open(image_path).convert('RGB')image = transform(image).unsqueeze(0)  # 添加批次维度image = image.to(device)# 进行预测with torch.no_grad():outputs = model(image)_, predicted = torch.max(outputs.data, 1)# 输出预测结果predicted_class = classes[predicted.item()]print(f'Predicted class: {predicted_class}')# img is in classes
predict_image('./data/1.jpeg')# img is not in classes
predict_image('./data/2.jpeg')

http://www.ppmy.cn/server/133195.html

相关文章

滚雪球学Redis[3.3讲]:Redis数据持久化深入探讨:从 AOF 到混合持久化的演进

全文目录: 前言混合持久化1. RDB 与 AOF 之间的权衡2. 混合持久化的工作原理工作机制详解 3. 配置与实践实例演示 4. 实际应用中的案例分析5. 深入探讨混合持久化的优势与局限6. 扩展思考:如何选择 Redis 的持久化策略? 总结附:案…

字节跳动实习生投毒自家大模型细节曝光 影响到底有多大?

10月19日,字节跳动大模型训练遭实习生攻击一事引发广泛关注。据多位知情人士透露,字节跳动某技术团队在今年6月遭遇了一起内部技术袭击事件,一名实习生因对团队资源分配不满,使用攻击代码破坏了团队的模型训练任务。 据悉&#xf…

驱动开发系列21 - 编译内核模块的Makefile解释

一:内核模块Makefile #这一行定义了要编译的内核模块目标文件。obj-m表示目标模块对象文件(.o文件), #并指定了两个模块源文件:helloworld-params.c 和 helloworld.c。最终会生成这 #这两个.c文件的.o对象文件。 obj-m := helloworld-params.o helloworld.o#这行定义了内核…

H.264视频,HEVC视频,VP9视频,AV1视频小知识

H.264、HEVC(H.265)、VP9和AV1是不同的视频编码格式,它们的主要区别在于压缩效率、支持的分辨率、编码技术以及专利和授权费用等方面。以下是这些编码格式的主要区别: H.264(AVC): 压缩效率&…

python 猜数字游戏

要求: 设计一个猜数字游戏,程序会随机生成一个1~100之间的整数,然后让用户猜这个数字是多少。 解答: import randomprint("大家一起来猜数!") print("*"*50) print("系统生成随机数中...&…

高级java每日一道面试题-2024年10月15日-JVM篇-说一下JVM的主要组成部分?及其作用?

如果有遗漏,评论区告诉我进行补充 面试官: 说一下JVM的主要组成部分?及其作用? 我回答: Java 虚拟机(JVM)是 Java 运行时环境的核心组件,它负责执行 Java 字节码。JVM 的主要组成部分及其作用如下: 类加载器子系统 (Class L…

golang ws升级为wss

首先需要一份openssl证书 1.安装openssl windows安装openssl 的下载地址在 https://slproweb.com/products/Win32OpenSSL.html 无脑点安装就行,记得最后安装完成的页面取消勾选 安装完成后记得配置环境变量 2.生成证书 openssl req -x509 -days 36500 -nodes …

【秋招笔试-支持在线评测】10.12美团(已改编)秋招-三语言题解

🍭 大家好这里是 春秋招笔试突围,一起备战大厂笔试 💻 ACM金牌团队🏅️ | 多次AK大厂笔试 | 大厂实习经历 ✨ 本系列打算持续跟新 春秋招笔试题 👏 感谢大家的订阅➕ 和 喜欢💗 和 手里的小花花🌸 ✨ 笔试合集传送们 -> 🧷春秋招笔试合集 🍒 本专栏已收集…