pytorch retain_grad vs requires_grad

news/2025/3/10 18:17:51/

requires_grad大家都挺熟悉的,因此穿插在retain_grad的例子里进行捎带讲解就行。下面看一个代码片段:

python">import torch# 创建一个标量 tensor,并开启梯度计算
x = torch.tensor(2.0, requires_grad=True)# 中间计算:y 依赖于 x,是非叶子节点
y = x * 3# 继续计算,得到 z
z = y * 4# 反向传播
z.backward()# 查看梯度
print("x.grad:", x.grad)  
print("y.grad:", y.grad)  

输出结果为:

python">x.grad: tensor(12.)
y.grad: None
/tmp/ipykernel_219007/1060175670.py:17: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:489.)print("y.grad:", y.grad)

警告的大致意思是:访问了非叶子节点的.grad属性,但非叶子节点的.grad属性并不会在反向传播的过程中被自动保存下来(这是为了节省内存,毕竟我们只需要计算那些手动设置.requires_gradTrue的张量的梯度,并进行梯度更新,对吧?)

因此,我们只需要添加一行代码y.retain_grad(),修改后的代码如下:

python">import torch# 创建一个标量 tensor,并开启梯度计算
x = torch.tensor(2.0, requires_grad=True)# 中间计算:y 依赖于 x,是非叶子节点
y = x * 3
y.retain_grad()# 继续计算,得到 z
z = y * 4# 反向传播
z.backward()# 查看梯度
print("x.grad:", x.grad)  
print("y.grad:", y.grad)  

输出结果为:

python">x.grad: tensor(12.)
y.grad: tensor(4.)

可以看到,现在非叶子节点y的梯度也在反向传播以后被正确保存了!


http://www.ppmy.cn/news/1578149.html

相关文章

sysbench手动测试OceanBase v4.2.4集群

环境: 1、ocp(sysbench节点) 192.192.103.128 2、ob集群1-1-1 observer 192.192.103.125、192.192.103.126、192.192.103.127,primary_zone:random haproxy 192.192.103.125、192.192.103.126、192.192.103.127 一、安装sysben…

uniapp版本加密货币行情应用

uniapp版本加密货币行情应用 项目概述 这是一个使用uniapp开发的鸿蒙原生应用,提供加密货币的实时行情查询功能。本应用旨在为用户提供便捷、实时的加密货币市场信息,帮助用户随时了解市场动态,做出明智的投资决策。 应用采用轻量级设计&a…

使用QT + 文件IO + 鼠标拖拽事件 + 线程 ,实现大文件的传输

第一题、使用qss&#xff0c;通过线程&#xff0c;使进度条自己动起来 mythread.h #ifndef MYTHREAD_H #define MYTHREAD_H#include <QObject> #include <QThread> #include <QDebug>class mythread : public QThread {Q_OBJECT public:mythread(QObject* …

Xenium数据分析 | 下机数据读取

今天我们将下载10x官方人肺癌FFPE样本Xenium5k下机数据&#xff0c;使用python的spatialdata库&#xff0c;演示如何进行Xenium单个样本/多样本数据读取&#xff0c;以及简单绘图功能展示。 1. 示例数据下载&#xff1a; 数据下载地址: https://www.10xgenomics.com/datasets…

关于Springboot 应配置外移和Maven个性化打包一些做法

期望达到的效果是每次更新服务器端应用只需要更新主程序jar 依赖jar单独分离。配置文件独立存放于文件夹内&#xff0c;更新程序并不会覆盖已有的配置信息。 一、配置外移 1、开发环境外移 做法&#xff1a;在项目同级或者上级创建config文件夹放置配置文件&#xff0c;具体m…

Windows控制台函数:控制台读取输入函数ReadConsoleA()

目录 什么是 ReadConsoleA&#xff1f; 它长什么样&#xff1f; 怎么用它&#xff1f; 它跟 std::cin 有什么不一样&#xff1f; 注意事项 什么是 ReadConsoleA&#xff1f; ReadConsoleA 是一个 Windows API 函数&#xff0c;用来从控制台读取用户输入。想象一下&#…

Java进阶:Docker

1. Docker概述 1.1. Docker简介 Docker 是一个开源的应用容器引擎&#xff0c;基于 Go 语言开发。Docker 可以让开发者打包他们的应用以及依赖包到一个轻量级、可移植的容器中&#xff0c;然后发布到任何流行的 Linux 机器上&#xff0c;也可以实现虚拟化。容器是完全使用沙箱…

DR和BDR的选举规则

在 OSPF&#xff08;开放最短路径优先&#xff09;协议中&#xff0c;DR&#xff08;Designated Router&#xff0c;指定路由器&#xff09; 和 BDR&#xff08;Backup Designated Router&#xff0c;备份指定路由器&#xff09; 的选举是为了在广播型网络&#xff08;如以太网…