PyTorch入门之【CNN】

news/2024/11/29 10:48:50/

参考:https://www.bilibili.com/video/BV1114y1d79e/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
书接上回的MLP故本章就不详细解释了

目录

  • train
  • test

train

import torch
from torchvision.transforms import ToTensor
from torchvision import datasets
import torch.nn as nn# load MNIST dataset
training_data = datasets.MNIST(root='../02_dataset/data',train=True,download=True,transform=ToTensor()
)train_data_loader = torch.utils.data.DataLoader(training_data, batch_size=64, shuffle=True)# define a CNN model
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1),nn.BatchNorm2d(32),nn.ReLU())self.conv_2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1),nn.BatchNorm2d(64),nn.ReLU(),)self.maxpool = nn.MaxPool2d(2)self.flatten = nn.Flatten()self.fc_1 = nn.Sequential(nn.Linear(9216, 128),nn.BatchNorm1d(128),nn.ReLU())self.fc_2 = nn.Linear(128, 10)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc_1(x)logits = self.fc_2(x)return logits# create a CNN model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN().to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()# train the model
num_epochs = 20for epoch in range(num_epochs):print(f'Epoch {epoch+1}\n-------------------------------')for idx, (img, label) in enumerate(train_data_loader):size = len(train_data_loader.dataset)img, label = img.to(device), label.to(device)# compute prediction errorpred = cnn(img)loss = loss_fn(pred, label)# backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if idx % 400 == 0:loss, current = loss.item(), idx*len(img)print(f'loss: {loss:>7f} [{current:>5d}/{size:>5d}]')# save the model
torch.save(cnn.state_dict(), 'cnn.pth')
print('Saved PyTorch Model State to cnn.pth')

test

import torch
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
import torch.nn as nn# load test data
test_data = datasets.MNIST(root='../02_dataset/data',train=False,download=True,transform=ToTensor()
)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)transform = transforms.Compose([transforms.Grayscale(),transforms.RandomRotation(10),transforms.ToTensor()
])
my_mnist = ImageFolder(root='../02_dataset/my-mnist', transform=transform)
my_mnist_loader = torch.utils.data.DataLoader(my_mnist, batch_size=64, shuffle=True)# define a CNN model
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1),nn.BatchNorm2d(32),nn.ReLU())self.conv_2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1),nn.BatchNorm2d(64),nn.ReLU(),)self.maxpool = nn.MaxPool2d(2)self.flatten = nn.Flatten()self.fc_1 = nn.Sequential(nn.Linear(9216, 128),nn.BatchNorm1d(128),nn.ReLU())self.fc_2 = nn.Linear(128, 10)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc_1(x)logits = self.fc_2(x)return logits# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN()
cnn.load_state_dict(torch.load('cnn.pth', map_location=device))
cnn.eval().to(device)# test the pretrained model on MNIST test data
size = len(test_data_loader.dataset)
correct = 0with torch.no_grad():for img, label in test_data_loader:img, label = img.to(device), label.to(device)pred = cnn(img)correct += (pred.argmax(1) == label).type(torch.float).sum().item()correct /= size
print(f'Accuracy on MNIST: {(100*correct):>0.1f}%')# test the pretrained model on my MNIST test data
size = len(my_mnist_loader.dataset)
correct = 0with torch.no_grad():for img, label in my_mnist_loader:img, label = img.to(device), label.to(device)pred = cnn(img)correct += (pred.argmax(1) == label).type(torch.float).sum().item()correct /= size
print(f'Accuracy on my MNIST: {(100*correct):>0.1f}%')

http://www.ppmy.cn/news/1141023.html

相关文章

商城小程序代客下单程序开发演示

一款专为传统电商、实体商家开发的商城系统小程序,做私域、做留存、做社交必备功能全都有。 1、丰富的营销玩法:拼团、秒杀、定金预售、分销、社区团购、积分商城、支付有礼等主流获客玩法都有。 2、强大的会员体系:普通会员、付费会员、会…

[论文笔记]Poly-encoder

引言 本文是Poly-encoder1的阅读笔记,论文题目为基于预训练模型的快速准确多句评分模型。 也是本系列第一篇基于Transformer架构的模型,对于进行句子对之间比较的任务,有两种常用的途经:Cross-encoder在句子对上进行交互完全自注意力;Bi-encoder单独地编码不同的句子。前…

回顾Softing 2023工博之旅精彩瞬间

2023年9月23日,为期5天的第23届中国国际工业博览会(CIIF)于上海国家会展中心圆满落幕。Softing作为PROFIBUS创始人之一,德国工业4.0的领军企业之一,在本次展会上向大家呈现了众多工业自动化及IT网络方面的领先产品及方…

几道web题目

总结几道国庆写的web题目 [ACTF2020 新生赛]Include1 点进去发现就一个flag.php,源代码和抓包都没拿到好东西 结合题目猜是文件包含,构建payload ?filephp://filter/readconvert.base64-encode/resourceflag.php 得到base64编码过的flag,解码即可 此题…

【数据恢复篇】记一次Winhex镜像还原(恢复)到磁盘测试记录

【数据恢复篇】记一次Winhex镜像还原(恢复)到磁盘测试记录 镜像恢复到磁盘,怎么操作?会不会对磁盘有影响,是恢复到空磁盘?还是恢复到有数据的磁盘也可以?有数据的盘磁盘空间很多,恢…

LeetCode算法心得——有序三元组中的最大值 II (简单的动规思想)

大家好,我是晴天学长,枚举+简单的动态规划思想,需要的小伙伴可以关注支持一下哦!后续会继续更新的。 1) .有序三元组中的最大值 II 有序三元组中的最大值 II 给你一个下标从 0 开始的整数数组 nums 。 请你从所有满足 …

SwiftUI Spacer() onTapGesture 无法触发

问题:点击这个黑色区域不会 print,黑色区域看上去刚好是 Spacer() 占据的区域 解决办法:不使用 onTapGesture,用 Button 包裹一下 Code: import SwiftUIstruct TestTap: View {var body: some View {NavigationStack {List {Sect…

洗地机哪个好?2023最好用的洗地机

随着科技的进步,洗地机已成为家庭清洁的好帮手,不仅能减少体力消耗,还能有更加出色的清洁表现,不过面对鱼龙混杂的洗地机市场,如果不了解洗地机很容易买错,今天笔者教大家快速随了解市场主流洗地机的配置信…