Pytorch 自动微分注意点讲解

ops/2024/9/23 10:29:56/

backward()

backward()函数是pytorch框架实现自动微分的关键函数,一般通过loss.backward()调用,这里的loss一般是标量张量

python">import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True,device=device)
print(data1)
y = data1.pow(2)+100
print('y =',y)
loss = y.mean()
print('loss =',loss)
loss.backward()
print("data1's grad =",data1.grad)
# mps
# tensor([[1., 9., 0.]], device='mps:0', requires_grad=True)
# y = tensor([[101., 181., 100.]], device='mps:0', grad_fn=<AddBackward0>)
# loss = tensor(127.3333, device='mps:0', grad_fn=<MeanBackward0>)
# data1's grad = tensor([[0.6667, 6.0000, 0.0000]], device='mps:0')

可以看到这里模拟了一个函数计算,mean()模拟了损失计算,目的是将食粮张量转为标量张量

在对损失loss进行了反向传播后,叶子节点data1便有了grad属性.也就是2*data1

设备迁移注意点

python">import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device)
print(data1)
y = data1.pow(2)+100
print('y =',y)
loss = y.mean()
print('loss =',loss)
loss.backward()
print("data1's grad =",data1.grad)
# mps
# tensor([[0., 9., 0.]], device='mps:0', grad_fn=<ToCopyBackward0>)
# y = tensor([[100., 181., 100.]], device='mps:0', grad_fn=<AddBackward0>)
# loss = tensor(127., device='mps:0', grad_fn=<MeanBackward0>)
# /Users/jinhouji/PycharmProjects/pythonProject/lesson01.py:13: UserWarning:
# 
# The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/core/TensorBody.h:494.)
# 
# data1's grad = None

这里在创建完张量进行to(device)后,会出现打印梯度为None的情况,这是由于在张量创立后进行设备转移操作会导致grad_fn(也就是函数操作记录)的丢失,所以这个时候可以通过detach()+requires_grad_(True)函数来重新建立叶子节点

detach()

detach()函数可以建立一个与原张量共享内存但不进行梯度计算的全新张量

clone()

clone()函数可以拷贝一个和原张量具有相同计算图和张量值的全新张量

requires_grad_()

requires_grad_()函数可以将张量的梯度计算权限打开

is_leaf()

is_leaf()函数用于判断张量是否为叶子节点,返回布尔值

综上,我们可以尝试去重建叶子节点

python">import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device).detach().requires_grad_(True)
print(data1)
print(data1.is_leaf)
y = data1.pow(2)+100
print('y =',y)
loss = y.mean()
print('loss =',loss)
loss.backward()
print("data1's grad =",data1.grad)
# mps
# tensor([[7., 8., 8.]], device='mps:0', requires_grad=True)
# True
# y = tensor([[149., 164., 164.]], device='mps:0', grad_fn=<AddBackward0>)
# loss = tensor(159., device='mps:0', grad_fn=<MeanBackward0>)
# data1's grad = tensor([[4.6667, 5.3333, 5.3333]], device='mps:0')

以上为使用了detach()和requires_grad()的方法,所以注意要将非叶子节点拆分为叶子节点的方法就是以上所示的流程

暂停梯度计算方法

一般在做模型推理和评估的时候需要暂停梯度计算,以下列举三种停止梯度计算的方法

with torch.no_grad():
python">import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device).clone()
print(data1)
print(data1.is_leaf)
with torch.no_grad():y = data1.pow(2)+100
print(y.requires_grad)
# mps
# tensor([[0., 5., 5.]], device='mps:0', grad_fn=<CloneBackward0>)
# False
# False

with代码块下的语句涉及张量生成的代码,都不会进行梯度计算

@torch.no_grad()

@torch.no_grad()装饰器装饰的函数内部若进行张量生成则不会进行梯度计算

python">import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device).clone()
print(data1)
print(data1.is_leaf)
@torch.no_grad()
def func11(data1):return data1.pow(2)+100
print(func11(data1).requires_grad)
# mps
# tensor([[3., 4., 4.]], device='mps:0', grad_fn=<CloneBackward0>)
# False
# False
torch.set_grad_enabled()

torch.set_grad_enabled()函数可以通过设置参数为False来关闭梯度计算,如需开启梯度计算则需要重新调用函数设置参数为True

python">import numpy as np
import torch
from torch import set_grad_enableddevice = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device).clone()
print(data1)
print(data1.is_leaf)
torch.set_grad_enabled(False)
y=data1.pow(2)+100
print(y.requires_grad)
set_grad_enabled(True)
y=data1.pow(2)+100
print(y.requires_grad)
# mps
# tensor([[7., 8., 6.]], device='mps:0', grad_fn=<CloneBackward0>)
# False
# False
# True


http://www.ppmy.cn/ops/99306.html

相关文章

Web3链上聚合器声呐已全球上线,开启区块链数据洞察新时代

在全球区块链技术高速发展的浪潮中&#xff0c;在创新发展理念的驱动下&#xff0c;区块链领域的工具类应用备受资本青睐。 2024年8月20日&#xff0c;由生纳&#xff08;香港&#xff09;国际集团倾力打造的一款链上应用工具——“声呐链上聚合器”&#xff0c;即“声呐链上数…

Prompt-Tuning 和 LoRA大模型微调方法区别

Prompt-Tuning 和 LoRA&#xff08;Low-Rank Adaptation&#xff09;都是在预训练语言模型基础上进行微调的方法&#xff0c;它们有以下一些区别&#xff1a; 一、调整方式 Prompt-Tuning&#xff1a; 主要是通过优化特定任务的提示&#xff08;prompt&#xff09;来实现微调。…

Eureka Server高可用模式详解:实现无缝的故障转移与容灾

目录 引言 Eureka Server背景与重要性高可用模式的必要性故障转移与容灾的核心概念 Eureka Server概述 Eureka架构简介Eureka Server与Eureka Client的工作机制Eureka在微服务架构中的角色与功能 Eureka Server的单节点架构及其局限性 单节点部署的特点单点故障的影响面临的挑…

Adobe Animate (AN)软件安装,硬件配置(附安装包)

目录 一、Adobe An 软件简介 Adobe An 软件的特点 Adobe An 软件的优势 下载 二、Adobe An 软件安装 安装前的准备工作 安装过程中的注意事项 安装后的设置 三、Adobe An 软件使用 高级动画技巧 交互设计 优化与性能提升 四、Adobe An 软件快捷键 选择工具快捷键…

设计模式 - 行为型模式(第六章)

目录 6、行为型模式 6.1 模板方法模式 6.1.1 概述 6.1.2 结构 6.1.3 案例实现 6.1.3 优缺点 6.1.4 适用场景 6.1.5 JDK源码解析 6.2 策略模式 6.2.1 概述 6.2.2 结构 6.2.3 案例实现 6.2.4 优缺点 6.2.5 使用场景 6.2.6 JDK源码解析 6.3 命令模式 6.3.1 概述 …

C程序设计——运算符1

条件运算符 这是一个三目运算符&#xff0c;用于条件求值(?:)。 来源&#xff1a;百度百科 这是C语言里&#xff0c;唯一三目&#xff08;即三个表达式&#xff09;运算符。具体格式如下&#xff1a; (表达式1) ? (表达式2) : (表达式3) ; 翻译成人话&#xff0c;就是&…

rust api接口开发(以登陆和中间件鉴权为例)

rust rest api接口开发 所需依赖 axumtokioredis cargo add axum redis cargo add tokio --featuresfull路由服务创建和运行 //子路由 let v1router axum::Router::new(); //主路由,并将子路由绑定到主路由 let routeraxum::Router::new().nest("/v1",v1router)…

MySQL分页查询--LIMIT、OFFSET

limit简介 LIMIT 是 MySQL 中的一个特殊关键字&#xff0c;用于指定查询结果从哪条记录开始显示&#xff0c;一共显示多少条记录。 LIMIT 关键字有两种使用方式&#xff0c;即指定初始位置、与 OFFSET 组合使用 用法一&#xff1a;获取结果集的特定部分 limit中文可翻译为限…