pytorch训练过程搭建及模型的保存与加载

server/2024/9/22 23:02:33/

1. 训练一个分类器

python">import torch.utils.data
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable# 定义网络结构
class myNet(nn.Module):def __init__(self):super(myNet, self).__init__()self.conv1 = nn.Conv2d(3,6,5)self.conv2 = nn.Conv2d(6,16,5)self.fc1 = nn.Linear(16*5*5, 120)  # 两次卷积、两次2*2的max_pool2d之后尺寸变成5*5self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))x = F.max_pool2d(F.relu(self.conv2(x)), (2,2))x = x.view(x.size()[0], -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = myNet()
#print(net)# 定义损失函数和优化器(loss and optimizer)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)# 预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
# 加载训练集
trainset = tv.datasets.CIFAR10(root=r'D:\02.Work\06.LearnPyTorch\001\data', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True,num_workers=2)
# 加载验证集
testset = tv.datasets.CIFAR10(r'D:\02.Work\06.LearnPyTorch\001\data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False,num_workers=2)class_names = ('airplane', 'automobile', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')#(data, label) = trainset[100]
#print(class_names[label])
#print(type(data))show = ToPILImage()  # 可以把Tensor转成Image,方便可视化
if __name__ == '__main__':#img_forshow = show((data + 1)/2).resize((100,100))#img_forshow.show()#dataiter = iter(trainloader)#images, labels = next(dataiter)#imgs_forshow = show(tv.utils.make_grid((images + 1)/2)).resize((400,100))#imgs_forshow.show()# 训练for epoch in range(15):running_loss = 0.0for i, data in enumerate(trainloader, 0):# 输入数据inputs, labels = datainputs, labels = Variable(inputs), Variable(labels)if torch.cuda.is_available():net.cuda()inputs = inputs.cuda()labels = labels.cuda()# 每次前向都要清零梯度optimizer.zero_grad()# 前向传播、反向传播outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()# 更新参数optimizer.step()running_loss += loss.item()# 打印lossif i % 2000 == 0:print('[epoch %d, step %d], loss = %f' % (epoch, i , running_loss / 2000))running_loss = 0.0print('Training finished.')# 测试一个test batch的结果dataiter = iter(testloader)images, labels = next(dataiter)print('groundtruth: ', ' '.join('%08s'%class_names[labels[j]] for j in range(4)))show(tv.utils.make_grid(images / 2 - 0.5)).resize((400, 100))if torch.cuda.is_available():images = images.cuda()outputs = net(Variable(images))_, predict = torch.max(outputs.data, 1)print('predicted results:', ' '.join('%5s'%class_names[predict[j]] for j in range(4)))# 测试整个测试集的效果correct = 0total = 0for data in testloader:images, labels = dataif torch.cuda.is_available():images = images.cuda()labels = labels.cuda()outputs = net(Variable(images))_, predict = torch.max(outputs.data, 1)total += labels.size(0)correct += (predict == labels).sum()print('1W张测试集中的准确率为: %d %%'%(100 * correct / total))print()

2. 训练一个线性回归

python"># y = 2*x + 3
import torch
import torch.nn as nn
from torch import optim
from matplotlib import pyplot as plt
#from IPython import displayclass LinearNet(nn.Module):def __init__(self):super(LinearNet, self).__init__()self.fc = nn.Linear(1, 1)def forward(self, input):return self.fc(input)#torch.manual_seed(1)def getFakeData(batch_size = 18):x = torch.rand(batch_size, 1)*30y = 2 * x + 3 + (torch.randn(batch_size, 1))  # 尾部是随机噪声return x,yif __name__ == '__main__':net = LinearNet()print(net.state_dict().keys())  # 打印所有参数名称# 随机初始化的参数print("随机初始化的参数:")print('net.fc.weight = ', net.fc.weight.detach().numpy())print('net.fc.bias = ', net.fc.bias.detach().numpy())criterion = nn.MSELoss()optimizer = optim.SGD(net.parameters(), lr=1e-3)for i in range(20000):x, y = getFakeData()y_pred = net(x)loss = criterion(y, y_pred)optimizer.zero_grad()loss.backward()optimizer.step()print('*'*30)print("训练后的参数:")print('net.fc.weight = ', net.fc.weight.detach().numpy())print('net.fc.bias = ', net.fc.bias.detach().numpy())if 10:torch.save(net, 'net.pth')   # 不推荐,因为这种保存方式依赖模型定义及文件路径结构等b = torch.load('net.pth')print('b.fc.weight = ', b.fc.weight.detach().numpy())print('b.fc.bias = ', b.fc.bias.detach().numpy())print('*'*30)if 1:torch.save(net.state_dict(), 'net2.pth')net2 = LinearNet()c = net2.load_state_dict(torch.load('net2.pth'))print('net2.fc.weight = ', net2.fc.weight.detach().numpy())print('net2.fc.bias = ', net2.fc.bias.detach().numpy())

输出:

odict_keys(['fc.weight', 'fc.bias'])
随机初始化的参数:
net.fc.weight =  [[0.37508917]]
net.fc.bias =  [0.81224954]
******************************
训练后的参数:
net.fc.weight =  [[2.0011635]]
net.fc.bias =  [2.9967082]
b.fc.weight =  [[2.0011635]]
b.fc.bias =  [2.9967082]
******************************
net2.fc.weight =  [[2.0011635]]
net2.fc.bias =  [2.9967082]

http://www.ppmy.cn/server/118292.html

相关文章

laravel 11 区分多模块的token

数据表:用户表(users)、管理员表(admin_user), 配置bootstrap/app.php guards > [web > [driver > session,provider > admin_users,],home > [driver > sanctum,provider > users,]…

2024ICPC网络赛第一场

A 最终答案与中国队能力值的排名有关&#xff0c;具体每个情况手推一下&#xff0c;用 if else 即可通过。 #include <bits/stdc.h> using namespace std;int main() {ios::sync_with_stdio(false); cin.tie(0);int t, a[40];cin >> t;while (t--) {int num 0;f…

如何快速解决程序中的BUG

前提 获得更多信息 - 搞清楚为什么bug会发生什么情况下会发生、用户到底做了什么操作&#xff0c;才导致这个bug、是每次都会出现bug、还是偶发性、是否可以复现&#xff08;不能复现的bug&#xff0c;还能叫bug&#xff09;&#xff1f;拿到用户详细的报错输出明确边界&#…

3D云渲染农场为何怎么贵?主要消耗成本介绍

随着对高质量3D动画的需求持续增长&#xff0c;云渲染农场对于旨在以高效速度生产高质量视觉效果的工作室来说变得至关重要。然而&#xff0c;用户经常想知道为什么渲染农场的价格如此之高&#xff0c;理解背后的原因可以帮助艺术家做出更好的选择。 什么是云渲染农场&#xff…

《深度学习》PyTorch 手写数字识别 案例解析及实现 <上>

目录 一、了解MINIST数据集 1、什么是MINIST 2、查看MINIST由来 二、实操代码 1、下载训练数据集 2、下载测试数据集 运行结果&#xff1a; 3、展示手写数字图片 运行结果&#xff1a; 4、打包图片 运行结果&#xff1a; 5、判断当前pytorch使用的设备 1&#xff…

c语言--力扣简单题目(回文链表)讲解

题目如下&#xff1a; 给你一个单链表的头节点 head &#xff0c;请你判断该链表是否为 回文链表。 如果是&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,2,1] 输出&#xff1a;true 示例 2&#xff1…

DockerDocker Compose安装(离线+在线)

Docker&Docker Compose安装(离线在线) Docker离线安装 下载想要安装的docker软件版本&#xff1a;https://download.docker.com/linux/static/stable/x86_64/ 如目标机无法从链接下载&#xff0c;可以在本机下载后 scp docker版本压缩包[如docker-20.10.9.tgz] usernameh…

电巢科技携Ecosmos元宇宙产品亮相第25届中国光博会

第25届中国国际光电博览会&#xff08;“CIOE中国光博会”&#xff09;今日在深圳国际会展中心盛大开幕。本届博览会以“光电引领未来&#xff0c;驱动应用创新”为主题&#xff0c;吸引了全球超过3700家优质光电企业参展&#xff0c;展示了光电产业的最新成果和前沿技术。 电…