PyTorch 梯度计算详解:以 detach 示例为例

news/2024/12/15 15:54:44/

PyTorch 梯度计算详解:以 detach 示例为例

在深度学习中,梯度计算是训练模型的核心步骤,而 PyTorch 通过自动微分(autograd)模块实现了高效的梯度求解。本文将通过一个实际代码示例,详细讲解 PyTorch 的梯度计算过程,包括 backward() 函数的作用及工作原理、grad 属性的含义,以及如何分离计算图避免梯度传播。


示例代码

python">import torch# 定义张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义计算
y = x * 2
z = y.detach()  # 分离 z,z 不会参与反向传播
w = z ** 2# 反向传播
w.sum().backward()# 打印梯度
print("x 的梯度:", x.grad)  
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

1. 梯度计算的核心概念

1.1 什么是梯度?

梯度是一个标量或张量的导数,表示函数的变化率。在机器学习中,梯度用来衡量损失函数相对于参数的变化方向和大小,以指导参数更新。

1.2 计算图

PyTorch 会在执行张量操作时构建一棵动态的计算图,每个张量都是一个节点,操作(如加法、乘法)是连接这些节点的边。计算图的作用是记录操作的顺序和依赖关系,从而实现反向传播。

1.3 requires_grad 属性
  • requires_grad=True 时,PyTorch 会为该张量记录梯度。
  • 默认情况下,requires_grad=False,意味着该张量不会参与梯度计算。

2. 示例中梯度的具体计算过程

2.1 代码解析
  1. 定义张量 x

    python">x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    

    x 是一个可求梯度的张量,构成了计算图的起点。

  2. 定义计算 y

    python">y = x * 2
    

    此时,y 的计算被记录在计算图中,且与 x 有依赖关系。公式为:
    y = [ 2.0 , 4.0 , 6.0 ] y = [2.0, 4.0, 6.0] y=[2.0,4.0,6.0]

  3. 分离 z

    python">z = y.detach()
    

    使用 detach() 分离 z 后,z 不再记录计算图。虽然 z 的值等于 y,但它和 y 的计算历史已断开。

  4. 计算 w

    python">w = z ** 2
    

    w 的计算与 z 相关,因为 z 已被分离,它的计算不会影响原始计算图。

  5. 反向传播:

    python">w.sum().backward()
    

    由于 w 是通过分离后的 z 计算而来,反向传播不会更新 x 的梯度。


2.2 梯度的实际计算

让我们一步步分析梯度是如何通过计算图传播的。由于 z 被分离,原始计算图如下:

x (requires_grad=True) → y = x * 2
  • y 是通过 x * 2 得到的,因此其梯度可以表示为:
    ∂ y ∂ x = 2 \frac{\partial y}{\partial x} = 2 xy=2

但是,由于 z = y.detach() 分离了计算图,zx 没有任何依赖关系,因此 梯度不会计算到 x

最终 x.grad 输出为 None


3. backward()grad 的介绍

3.1 backward() 函数

backward() 是 PyTorch 用于计算梯度的核心函数。它从计算图的末端开始,沿着图的依赖关系,逐层向前计算每个张量的梯度。

使用方法:

python">loss.backward()
  • 工作原理:

    • 首先计算损失对每个变量的梯度。
    • 根据链式法则,逐层累积梯度。
    • 将计算结果存储在相关张量的 grad 属性中。
  • 参数说明:

    • retain_graph:是否保留计算图。默认为 False,计算完梯度后会释放计算图。
    • create_graph:是否创建计算图,允许对梯度再次求导。
3.2 grad 属性
  • grad 存储了张量的梯度,是反向传播的结果。
  • 只有 requires_grad=True 的张量才会有 grad 属性。
  • 如果张量未参与任何梯度计算,其 grad 属性为 None

4. 示例代码的修改与改进

为了让梯度正确传递,我们可以移除 detach() 操作:

改进代码:
python">import torch# 定义张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义计算
y = x * 2
z = y  # 不使用 detach
w = z ** 2# 反向传播
w.sum().backward()# 打印梯度
print("x 的梯度:", x.grad)  # 输出:x 的梯度: tensor([4., 8., 12.])
计算过程:
  1. y = x * 2
    ∂ y ∂ x = 2 \frac{\partial y}{\partial x} = 2 xy=2

  2. w = z^2 = y^2
    ∂ w ∂ y = 2 y \frac{\partial w}{\partial y} = 2y yw=2y

  3. 利用链式法则:
    ∂ w ∂ x = ∂ w ∂ y ⋅ ∂ y ∂ x = 2 y ⋅ 2 = 4 x \frac{\partial w}{\partial x} = \frac{\partial w}{\partial y} \cdot \frac{\partial y}{\partial x} = 2y \cdot 2 = 4x xw=ywxy=2y2=4x

    最终,x.grad = [4.0, 8.0, 12.0]


5. 注意事项

  1. detach 和 no_grad 的区别

    • detach 是对单个张量操作,将其从计算图中分离。
    • torch.no_grad() 是一个上下文管理器,用于禁止其内所有计算图的创建,常用于推理阶段。
  2. 梯度累积
    默认情况下,PyTorch 会累积梯度(即多次调用 backward() 会叠加梯度)。如果不需要累积,可在每次计算前手动清零:

    python">optimizer.zero_grad()
    
  3. 链式法则与梯度传播
    PyTorch 的自动微分基于链式法则,因此每一步的梯度都会被准确传播。


总结

本文通过一个实际示例,详细解析了 PyTorch 中梯度的计算过程,重点介绍了计算图的概念以及 backward()grad 的使用。理解这些核心机制对于深度学习模型的开发与调试至关重要,希望这篇文章能帮助你更深入地掌握 PyTorch 的梯度计算。

后记

2024年12月13日10点24分于上海,在GPT4o大模型辅助下完成。


http://www.ppmy.cn/news/1555336.html

相关文章

测试脚本并发多进程:pytest-xdist用法

参考:https://www.cnblogs.com/poloyy/p/12694861.html pytest-xdist详解: https://www.cnblogs.com/poloyy/p/14708825.html 总 https://www.cnblogs.com/poloyy/category/1690628.html

`BertModel` 和 `BertForMaskedLM

是的,BertModel 和 BertForMaskedLM 是两个不同的类,它们的功能和应用场景有所区别。以下是两者的详细对比: 1. BertModel 功能 BertModel 是基础的 BERT 模型,输出的是编码器的隐层表示(hidden states)&…

MySQL八股文

MySQL 自己学习过程中的MySQL八股笔记。 主要来源于 小林coding 牛客MySQL面试八股文背诵版 以及b站和其他的网上资料。 MySQL是一种开放源代码的关系型数据库管理系统(RDBMS),使用最常用的数据库管理语言–结构化查询语言(SQL&…

3D扫描和3D打印的结合应用

3D扫描和3D打印是两种紧密相连的增材制造技术,它们在多个领域中都发挥着重要作用。以下是对3D扫描和3D打印的详细解释: 一、3D扫描 3D扫描是运用软件对物体结构进行多方位扫描,从而建立物体的三维数字模型的技术。积木易搭在三维扫描设备&a…

Apache APISIX快速入门

本文将介绍Apache APISIX,这是一个开源API网关,可以处理速率限制选项,并且可以轻松地完全控制外部流量对内部后端API服务的访问。我们将看看是什么使它从其他网关服务中脱颖而出。我们还将详细讨论如何开始使用Apache APISIX网关。 在深入讨…

Python学习通移动端自动化刷课脚本,雷电模拟器近期可能出现的问题,与解决方案

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 这个文章是专门处理学习通脚本最近出现的问题的 我可是开源的好博主啊,不给我俩赞??? 本帅哥开源的刷课脚本导航 python安卓自动化pyaibote实践------学…

scala基础学习_变量

文章目录 scala中的变量常量 val(不可变变量)变量 var变量声明多变量声明匿名变量 _ 声明 变量类型声明变量命名规范 scala中的变量 常量 val(不可变变量) 使用val关键字声明变量是不可变的,一旦赋值后不能被修改 对…

林曦词典|无聊

“林曦词典”是在水墨画家林曦的课堂与访谈里,频频邂逅的话语,总能生发出无尽的思考。那些悠然轻快的、微妙纷繁的,亦或耳熟能详的词,经由林曦老师的独到解析,意蕴无穷,让人受益。于是,我们将诸…