使用Pytorch构建神经网络

news/2024/12/22 16:12:31/

构建神经网络的典型流程

  • 定义一个拥有可学习参数的神经网络
  • 遍历训练数据集
  • 处理输入数据使其流经神经网络
  • 计算损失值
  • 将网络参数的梯度进行反向传播
  • 以一定的规则更新网络的权重

我们首先定义一个Pytorch实现的神经网络:

# 导入若干工具包
import torch
import torch.nn as nn
import torch.nn.functional as F# 定义一个简单的网络类
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 定义第一层卷积神经网络, 输入通道维度=1, 输出通道维度=6, 卷积核大小3*3self.conv1 = nn.Conv2d(1, 6, 3)# 定义第二层卷积神经网络, 输入通道维度=6, 输出通道维度=16, 卷积核大小3*3self.conv2 = nn.Conv2d(6, 16, 3)# 定义三层全连接网络self.fc1 = nn.Linear(16 * 6 * 6, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# (2, 2)的池化窗口下执行最大池化操作x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):# 计算size, 除了第0个维度上的batch_sizesize = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_featuresnet = Net()
print(net)

运行结果
在这里插入图片描述
注意:
模型中所有的可训练参数, 可以通过net.parameters()来获得.

params = list(net.parameters())
print(len(params))
print(params[0].size())

运行结果:
在这里插入图片描述

  • 假设图像的输入尺寸为32 * 32:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

运行结果
在这里插入图片描述

  • 有了输出张量后, 就可以执行梯度归零和反向传播的操作了.
net.zero_grad()
out.backward(torch.randn(1, 10))
  • 注意
    - torch.nn构建的神经网络只支持mini-batches的输入, 不支持单一样本的输入.
    - 比如: nn.Conv2d 需要一个4D Tensor, 形状为(nSamples, nChannels, Height, Width). 如果你的输入只有单一样本形式, 则需要执行input.unsqueeze(0), 主动将3D Tensor扩充成4D Tensor.

损失函数

  • 损失函数的输入是一个输入的pair: (output, target), 然后计算出一个数值来评估output和target之间的差距大小.
  • 在torch.nn中有若干不同的损失函数可供使用, 比如nn.MSELoss就是通过计算均方差损失来评估输入和目标值之间的差距
  • 应用nn.MSELoss计算损失的一个例子:
output = net(input)
target = torch.randn(10)# 改变target的形状为二维张量, 为了和output匹配
target = target.view(1, -1)
criterion = nn.MSELoss()loss = criterion(output, target)
print(loss)

运行结果:
在这里插入图片描述

  • 关于方向传播的链条: 如果我们跟踪loss反向传播的方向, 使用.grad_fn属性打印, 将可以看到一张完整的计算图如下:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear-> MSELoss-> loss
  • 当调用loss.backward()时, 整张计算图将对loss进行自动求导, 所有属性requires_grad=True的Tensors都将参与梯度求导的运算, 并将梯度累加到Tensors中的.grad属性中.
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

运行结果:
在这里插入图片描述
反向传播(backpropagation)

  • 在Pytorch中执行反向传播非常简便, 全部的操作就是loss.backward().
  • 在执行反向传播之前, 要先将梯度清零,否则梯度会在不同的批次数据之间被累加.
    执行一个反向传播的小例子:
# Pytorch中执行梯度清零的代码
net.zero_grad()print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)# Pytorch中执行反向传播的代码
loss.backward()print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

运行结果:
在这里插入图片描述
更新网络参数

  • 更新参数最简单的算法就是SGD(随机梯度下降).
  • 具体的算法公式表达式为: weight = weight - learning_rate
    gradient 首先用传统的Python代码来实现SGD如下:
learning_rate = 0.01
for f in net.parameters():f.data.sub_(f.grad.data * learning_rate)

然后使用Pytorch官方推荐的标准代码如下:

# 首先导入优化器的包, optim中包含若干常用的优化算法, 比如SGD, Adam等
import torch.optim as optim# 通过optim创建优化器对象
optimizer = optim.SGD(net.parameters(), lr=0.01)# 将优化器执行梯度清零的操作
optimizer.zero_grad()output = net(input)
loss = criterion(output, target)# 对损失值执行反向传播的操作
loss.backward()
# 参数的更新通过一行标准代码来执行
optimizer.step()

小节总结
学习了构建一个神经网络的典型流程:

  • 定义一个拥有可学习参数的神经网络
  • 遍历训练数据集
  • 处理输入数据使其流经神经网络
  • 计算损失值
  • 将网络参数的梯度进行反向传播
  • 以一定的规则更新网络的权重

学习了损失函数的定义:

  • 采用torch.nn.MSELoss()计算均方误差.
  • 通过loss.backward()进行反向传播计算时, 整张计算图将对loss进行自动求导,
    所有属性requires_grad=True的Tensors都将参与梯度求导的运算, 并将梯度累加到Tensors中的.grad属性中.

学习了反向传播的计算方法:

  • 在Pytorch中执行反向传播非常简便, 全部的操作就是loss.backward().
  • 在执行反向传播之前, 要先将梯度清零, 否则梯度会在不同的批次数据之间被累加.
  • net.zero_grad()
  • loss.backward()

学习了参数的更新方法:

  • 定义优化器来执行参数的优化与更新.

    optimizer = optim.SGD(net.parameters(), lr=0.01)

  • 通过优化器来执行具体的参数更新.

    optimizer.step()


http://www.ppmy.cn/news/1136638.html

相关文章

MySQL数据库索引练习

1.学生表:Student (Sno, Sname, Ssex , Sage, Sdept) 学号,姓名,性别,年龄,所在系 Sno为主键 课程表:Course (Cno, Cname,) 课程号,课程名 Cno为主键 学生选课表:SC (Sno, Cno, Scor…

京东工业商品详情数据接口

京东工业品平台提供了商品详情API接口,通过该接口,可以获取商品详情信息,包括商品ID、名称、副标题、价格、库存、图片等相关信息。 获取京东工业品平台商品详情数据接口的具体步骤如下: 注册并登录京东工业品平台的API控制台&a…

文本分词排序

文本分词 在这个代码的基础上 把英语单词作为一类汉语,作为一类然后列出选项 1. 大小排序 2. 小大排序 3. 不排序打印保存代码 import jieba# 输入文本,让我陪你聊天吧~ lines [] print("请输入多行文本,以\"2333.3\"结束&am…

阿里云轻量应用服务器和ECS服务器有什么区别?超详细对比

阿里云服务器ECS和轻量应用服务器有什么区别?轻量和ECS优缺点对比,云服务器ECS是明星级云产品,适合企业专业级的使用场景,轻量应用服务器是在ECS的基础上推出的轻量级云服务器,适合个人开发者单机应用访问量不高的网站…

计算机毕业设计 基于SSM的民宿推荐系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

设计加速!11个Adobe XD插件推荐!

你是否一直在寻找可以提升 Adobe XD 工作流程和体验的方法?如果是,一定要试试这些 Adobe XD 插件!本文将介绍 11 款好用的 Adobe XD 插件,这些插件可以为 UI/UX 设计添加很酷的新功能,极大提升你的工作效率和产出。让我…

springmvc-controller视图层配置SpringMVC处理请求的流程

目录 1. 什么是springmvc 2.项目中加入springmvc支持 2.1 导入依赖 2.2 springMVC配置文件 2.3 web.xml配置 2.4 中文编码处理 3. 编写一个简单的controller 4. 视图层配置 4.1 视图解析器配 4.2 静态资源配置 4.2 编写页面 4.3 页面跳转方式 5. SpringMVC处理请求…

qt 5.15.2 安卓 macos

macos环境安卓配置 我的系统是monterey12.5.1 打开qt的配置界面 这里版本是java1.8,注意修改这个json文件,显示包内容 {"common": {"sdk_tools_url": {"linux": "https://dl.google.com/android/repository/comm…