反向传播和优化 pytorch

news/2024/10/19 7:03:13/

**前置知识:

     优化器:optim=torch.optim.SGD(xigua1.parameters(),lr=0.01) 传入模型的参数、学习速率

     计算损失:result_loss=loss(outputs,targets)

     梯度清零:optim.zero_grad()

     计算梯度并反向传播:result_loss.backward()

     更新参数:optim.step()

  1. optim.zero_grad(): 在每次训练迭代之前清除所有优化器(如SGD、Adam等)维护的梯度信息。在神经网络中,每个参数(如权重和偏置)都有一个与之关联的梯度,这个梯度表示参数对损失函数的贡献程度。随着训练的进行,这些梯度会被累积,如果不加以重置,会导致梯度累加,从而影响模型的学习效果。因此,zero_grad() 函数通过将这些梯度重置为零,确保了每次迭代都是在无偏见的情况下开始。

  2. result_loss.backward(): 执行反向传播算法,计算损失函数相对于模型参数的梯度。在神经网络前向传播过程中,网络输出与实际标签之间的差异被量化为损失函数。backward() 函数通过链式法则自动计算损失函数对每个参数的梯度,这些梯度随后被存储在相应的参数的 .grad 属性中。这一步是优化过程的核心,因为它直接关系到参数如何被调整以最小化损失。

  3. optim.step(): 在计算出损失函数的梯度后,step() 函数根据这些梯度来更新模型参数。优化器使用特定的算法(如梯度下降、Adam等)来决定如何更新每个参数,以便在下一次迭代中减少损失。简而言之,step() 函数实现了从当前参数状态向更优参数状态的“跳跃”。

总的来说,这三个函数协同工作,形成了深度学习中参数优化的基本流程:首先清除旧的梯度信息,然后计算新的梯度,最后根据这些梯度更新参数。这一过程在每次训练迭代中重复进行,直到模型的性能满足要求或达到预设的停止条件。

**代码: 

python">import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter#以CIFAR10的分类检测为例test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(test_set,batch_size=1)class Xigua(nn.Module):def __init__(self):super().__init__()self.model1=Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10),)def forward(self,x):x=self.model1(x)return xxigua1=Xigua()
loss=nn.CrossEntropyLoss()
optim=torch.optim.SGD(xigua1.parameters(),lr=0.01)#为了节省时间,这里能显示出优化的效果即可,就只训练5轮,每轮都只是计算前10个数据
for epoch in range(5): #训练5轮running_loss=0.0 #每轮都计算出一个所有数据损失的总和step=0for data in dataloader:imgs,targets=dataoutputs=xigua1(imgs)result_loss=loss(outputs,targets)optim.zero_grad() #将梯度清零result_loss.backward() #计算损失对应的梯度,并将其反向传播optim.step() #更新模型参数#loss函数在其中只是起到了一个提供梯度的作用,而这个梯度就藏在optim中running_loss+=result_lossstep+=1if step>=10:breakprint(running_loss)


http://www.ppmy.cn/news/1540202.html

相关文章

开发一个微信小程序要多少钱?

在当今数字化时代,微信小程序成为众多企业和个人拓展业务、提供服务的热门选择。那么,开发一个微信小程序究竟需要多少钱呢? 开发成本主要取决于多个因素。首先是功能需求的复杂程度。如果只是一个简单的信息展示小程序,功能仅限…

刷题 排序算法

912. 排序数组 注意这道题目所有 O(n^2) 复杂度的算法都会超过时间限制&#xff0c;只有 O(nlogn) 的可以通过 快速排序空间复杂度为 O(logn)是由于递归的栈的调用归并排序空间复杂度为 O(n) 是由于需要一个临时数组 (当然也需要栈的调用&#xff0c;但是 O(logn) < O(n) 的…

Godot中类和静态类型

目录 类 关键字class_name 除了为类定义方法&#xff0c;我们也可以为类定义属性字段 实例释放前后的打印 Refcounted RefCounted维护了一个引用计数器 get_reference_count 类是引用类型数据 class关键字 静态类型 静态方法 静态方法只能访问静态变量 类 是面向…

通过OpenCV实现 Lucas-Kanade 算法

目录 简介 Lucas-Kanade 光流算法 实现步骤 1. 导入所需库 2. 视频捕捉与初始化 3. 设置特征点参数 4. 创建掩模 5. 光流估计循环 6. 释放资源 结论 简介 在计算机视觉领域&#xff0c;光流估计是一种追踪物体运动的技术。它通过比较连续帧之间的像素强度变化来估计图…

如何在OceanBase中新增系统变量及应用实践

因为系统变量涉及复杂的工程文件&#xff0c;为防止新增变量操作对软件系统的潜在影响&#xff0c;OceanBase为多数开发者设计了一套高效的编程框架。此框架允许开发者在新增及使用系统变量时&#xff0c;仅需专注于变量定义的细节。具体来说&#xff0c;通过运行一个Python脚本…

ISNULL 和 COALESCE 区别

ISNULL 数据库支持&#xff1a;ISNULL 是 SQL Server 特有的函数。 参数数量&#xff1a;ISNULL 接受两个参数。第一个参数是要检查是否为 NULL 的表达式&#xff0c;第二个参数是当第一个参数为 NULL 时要返回的值。 类型转换&#xff1a;如果 ISNULL 的两个参数数据类型不同&…

git 报错 SSL certificate problem: certificate has expired

git小乌龟 报错 SSL certificate problem: certificate has expired 场景复现&#xff1a; 原因&#xff1a; 这个错误表明你在使用Git时尝试通过HTTPS进行通信&#xff0c;但是SSL证书已经过期。这通常发生在使用自签名证书或证书有效期已到期的情况下。 解决方法: 1.如果是…

关于jmeter设置为中文问题之后无法保存设置的若干问题

1、jemeter如何设置中文模式 Options--->Choose Language--->Chinese(Simplifies), 如此设置后就可显示中文模式(缺点&#xff1a;下次打开还是英文)&#xff1b;如下图所示&#xff1a; 操作完成之后&#xff1a; 但是下次重启之后依旧是英文&#xff1b; 2、在jmeter.…