深度学习 自动求梯度

devtools/2024/10/25 9:12:01/

 代码示例:

import torch# 创建一个标量张量 x,并启用梯度计算
x = torch.tensor(3.0, requires_grad=True)# 计算 y = x^2
y = torch.pow(x, 2)# 判断 x 和 y 是否需要梯度计算
print(x.requires_grad)  # 输出 x 的 requires_grad 属性
print(y.requires_grad)  # 输出 y 的 requires_grad 属性# 反向传播,计算 y 对 x 的导数
y.backward()# 查看 x 的梯度
print(x.grad)  # 输出 x 的梯度

 代码详解:

代码一:
x = torch.tensor(3.0, requires_grad=True)

在这行代码中,创建了一个 PyTorch 张量 x。

解释

  1. torch.tensor(3.0):

    • 这部分创建了一个张量,其值为 3.0。这是一个标量张量,数据类型为浮点数(float)。
  2. requires_grad=True:

    • 这个参数指定 PyTorch 需要跟踪该张量的所有操作,以便在后续的反向传播过程中计算梯度。换句话说,设置 requires_grad=True 使得这个张量在执行任何操作后,能够计算其梯度。
 代码二:
y = torch.pow(x, 2)

解释

  1. torch.pow(x, 2):

    • 这是 PyTorch 中的一个函数,用于计算 x 的幂。在这里,x 被提高到 2 的幂,即计算 (x^2)。
    • 由于之前我们已经定义了 x = torch.tensor(3.0, requires_grad=True),所以 torch.pow(x, 2) 实际上会计算 (3.0^2),得到的结果是 9.0
  2. y = ...:

    • 将计算得到的结果 9.0 存储在张量 y 中。由于 x 的 requires_grad 属性为 True,PyTorch 会自动设置 y 的 requires_grad 属性为 True,使得 y 也可以用于梯度计算。
代码三: 
y.backward()

 y.backward() 是 PyTorch 中用于计算梯度的重要方法,它在反向传播过程中发挥着关键作用。

解释

  1. y.backward():
    • 这行代码触发反向传播,以计算损失函数 y 相对于输入张量 x 的梯度。
    • 在调用 backward() 方法之前,计算图已经构建完毕,y 是通过某些操作(例如 torch.pow(x, 2))生成的张量。
    • 当 backward() 被调用时,PyTorch 会从 y 开始,沿着计算图向后传播,计算所有需要计算梯度的张量的梯度。

自动微分

  • 在 PyTorch 中,backward() 使用自动微分(automatic differentiation)来计算梯度。这意味着系统会自动根据张量间的运算关系,利用链式法则来计算每个张量的梯度。

计算过程

  • 在之前的例子中,我们定义了 ( y = x^2 ) 并且 ( x = 3.0 )。
  • 在反向传播过程中,PyTorch 计算 ( \frac{dy}{dx} ) 的值:
    • 根据导数公式,( \frac{dy}{dx} = 2x )。
    • 对于 ( x = 3.0 ),因此 ( \frac{dy}{dx} = 2 \times 3.0 = 6.0 )。
  • 这个值会被存储在 x.grad 中,方便后续使用。 

print(x.grad) 语句用于输出张量 x 的梯度值。 

为了更好的理解什么是梯度,看下面示例代码: 

示例二:

import torch
x=torch.tensor(3.0,requires_grad=True)
y=torch.tensor(4.0,requires_grad=False)
z=torch.pow(x,2)+torch.pow(y,2)
print("x.requires_grad:",x.requires_grad)
print("y.requires_grad:",y.requires_grad)
print("z.requires_grad:",z.requires_grad)
z.backward()
print("x.grad:",x.grad)
print("y.grad:",y.grad)
print("z.grad:",z.grad)
print(z)

输出:

x.requires_grad: True
y.requires_grad: False
z.requires_grad: True
x.grad: tensor(6.)
y.grad: None
z.grad: None
tensor(25., grad_fn=<AddBackward0>)

输出解释

  1. 可求导性检查:

    • x.requires_grad: True 表示 x 是一个可求导的张量。
    • y.requires_grad: False 表示 y 不是可求导的张量。
    • z.requires_grad: True 表示 z 是可求导的,因为它是由可求导的张量 x 计算得出的。
  2. 梯度计算:

    • 调用 z.backward() 时,计算了 z 关于 x 的梯度。
    • y.grad 输出 None,因为 y 不可求导。
  3. 关于 z 的梯度:

    • z.grad 输出 None,这是因为 z 不是叶子节点。只有叶子节点的 grad 属性会被自动设置。
我们在运行此段代码时会遇到一个警告:

 大致意思是:

你在访问 z.grad 时遇到的警告提示你正在访问一个非叶子张量的梯度属性。此警告说明 z 不是一个叶子张量,因此其 .grad 属性在执行 backward() 时不会被填充。

叶子张量与非叶子张量

在 PyTorch 中,叶子张量(leaf tensors)是指那些没有任何历史计算的张量,通常是由用户直接创建的张量(例如通过 torch.tensor() 创建)。而 非叶子张量 是由其他张量经过操作计算得出的张量(例如加法、乘法等操作生成的结果)。

为了使非叶子张量的 .grad 属性被填充,你可以在计算图中使用 .retain_grad() 方法。这将允许你在调用 backward() 后访问非叶子张量的梯度。

请看修改后的示例三:

 示例三:

import torch# 创建一个可求导的张量 x 和一个不可求导的张量 y
x = torch.tensor(3.0, requires_grad=True)  # x 可求导
y = torch.tensor(4.0, requires_grad=False) # y 不可求导# 定义函数 z = f(x, y) = x^2 + y^2
z = torch.pow(x, 2) + torch.pow(y, 2)# 让 z 保留梯度
z.retain_grad()# 打印每个张量的 requires_grad 属性
print("x.requires_grad:", x.requires_grad)  # 输出: True
print("y.requires_grad:", y.requires_grad)  # 输出: False
print("z.requires_grad:", z.requires_grad)  # 输出: True# 反向传播以计算梯度
z.backward()# 打印 x 和 y 的梯度
print("x.grad:", x.grad)  # 输出: tensor(6.)
print("y.grad:", y.grad)  # 输出: None
print("z.grad:", z.grad)  # 输出: tensor(1.)

为什么z的梯度为1或者z的导为1?

z 对自身的导数为1

举个例子:

y=x**2;

y对于x的导为2*x;

y对于自身的导为1。


http://www.ppmy.cn/devtools/128646.html

相关文章

JMeter快速入门示例

JMeter是一款开源的性能测试工具&#xff0c;常用于对Web服务和接口进行性能测试。 下载安装 官方下载网址&#xff1a; https://jmeter.apache.org/download_jmeter.cgi也可以到如下地址下载&#xff1a;https://download.csdn.net/download/oscar999/89910834 这里下载Wi…

CMake中的List关键词:详细指南

CMake中的List关键词&#xff1a;详细指南 一、List的基本概念二、List的常用命令1. 获取List的长度2. 获取List中指定索引的元素3. 将元素追加到List中4. 在List中指定位置插入元素5. 在List的开头插入元素6. 从List中移除元素7. 移除List中的重复元素8. 对List进行排序9. 将L…

【TIMM库】是一个专门为PyTorch用户设计的图像模型库 python库

TIMM库 1、引言&#xff1a;遇见TIMM2、初识TIMM&#xff1a;安装与基本结构3、实战案例一&#xff1a;图像分类4、实战案例二&#xff1a;迁移学习5、实战案例三&#xff1a;模型可视化6、结语&#xff1a;TIMM的无限可能 1、引言&#xff1a;遇见TIMM 大家好&#xff0c;我是…

基于 Datawhale 开源量化投资学习指南(8):量化调仓策略

1. 引言 在前面的章节中&#xff0c;我们学习了如何通过多因子模型和量化择时策略对股票的未来收益进行预测。我们探讨了如何根据这些预测信号进行投资决策。量化投资的一个核心挑战是如何在有限的资金约束下&#xff0c;合理分配资金到多个标的上&#xff0c;从而构建一个优化…

网站被浏览器提示“不安全”,如何快速解决

当网站被浏览器提示“不安全”时&#xff0c;这通常意味着网站存在某些安全隐患&#xff0c;需要立即采取措施进行解决。 一、具体原因如下&#xff1a; 1.如果网站使用的是HTTP协议&#xff0c;应立即升级HTTPS。HTTPS通过使用SSL证书加密来保护数据传输&#xff0c;提高了网…

centos7.x安装openCV 4.6.0版本

## 从源代码编译安装 1.更新系统 sudo yum update -y 2.安装依赖项 sudo yum groupinstall "Development Tools" sudo yum install cmake gcc-c git libjpeg-turbo-devel libpng-devel libtiff-devel libwebp-devel openexr-devel gstreamer1-plugins-base-devel…

百科知识|选购指南

百科知识||选购指南 百科知识选购指南茶叶分类茶叶的味道来源茶叶制作步骤名茶其他一些茶叶的知识 百科知识 选购指南 茶叶 分类 茶叶种类: 六大茶类完美分析介绍&#xff01;茶友推荐收藏 (aboxtik.com) 1.绿茶&#xff08;发酵率0%&#xff09; 2.白茶&#xff08;发酵率…

算法汇总整理篇——回溯与图论的千丝万缕及问题的抽象思考

回溯算法(重中之重) 回溯法解决的问题都可以抽象为树形结构&#xff0c;集合的大小就构成了树的广度&#xff0c;递归的深度就构成了树的深度。 (回溯的核心&#xff1a;分清楚什么数据作为广度&#xff0c;什么数据作为深度&#xff01;&#xff01;&#xff01;&#xff01…