pytorch的动态计算图机制

devtools/2024/9/22 22:08:20/

pytorch_0">pytorch的动态计算图机制

一,动态计算图简介

在这里插入图片描述

Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系。

Pytorch中的计算图是动态图。这里的动态主要有两重含义。

第一层含义是:计算图的正向传播是立即执行的。无需等待完整的计算图创建完毕,每条语句都会在计算图中动态添加节点和边,并立即执行正向传播得到计算结果。

第二层含义是:计算图在反向传播后立即销毁。下次调用需要重新构建计算图。如果在程序中使用了backward方法执行了反向传播,或者利用torch.autograd.grad方法计算了梯度,那么创建的计算图会被立即销毁,释放存储空间,下次调用需要重新创建。

1,计算图的正向传播是立即执行的。

python">import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))print(loss.data)
print(Y_hat.data)
tensor(17.8969)
tensor([[3.2613],[4.7322],[4.5037],[7.5899],[7.0973],[1.3287],[6.1473],[1.3492],[1.3911],[1.2150]])

2,计算图在反向传播后立即销毁。

python">import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))#计算图在反向传播后立即销毁,如果需要保留计算图, 需要设置retain_graph = True
loss.backward()  #loss.backward(retain_graph = True) #loss.backward() #如果再次执行反向传播将报错

二,计算图中的Function

计算图中的另外一种节点是Function, 实际上就是 Pytorch中各种对张量操作的函数。

这些Function和我们Python中的函数有一个较大的区别,那就是它同时包括正向计算逻辑和反向传播的逻辑。

我们可以通过继承torch.autograd.Function来创建这种支持反向传播的Function

python">class MyReLU(torch.autograd.Function):#正向传播逻辑,可以用ctx存储一些值,供反向传播使用。@staticmethoddef forward(ctx, input):ctx.save_for_backward(input)return input.clamp(min=0)#反向传播逻辑@staticmethoddef backward(ctx, grad_output):input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.tensor([[-1.0,-1.0],[1.0,1.0]])
Y = torch.tensor([[2.0,3.0]])relu = MyReLU.apply # relu现在也可以具有正向传播和反向传播功能
Y_hat = relu(X@w.t() + b)
loss = torch.mean(torch.pow(Y_hat-Y,2))loss.backward()print(w.grad)
print(b.grad)
tensor([[4.5000, 4.5000]])
tensor([[4.5000]])
# Y_hat的梯度函数即是我们自己所定义的 MyReLU.backwardprint(Y_hat.grad_fn)
<torch.autograd.function.MyReLUBackward object at 0x1205a46c8>

三,计算图与反向传播

了解了Function的功能,我们可以简单地理解一下反向传播的原理和过程。理解该部分原理需要一些高等数学中求导链式法则的基础知识。

python">import torch x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2loss.backward()

loss.backward()语句调用后,依次发生以下计算过程。

1,loss自己的grad梯度赋值为1,即对自身的梯度为1。

2,loss根据其自身梯度以及关联的backward方法,计算出其对应的自变量即y1和y2的梯度,将该值赋值到y1.grad和y2.grad。

3,y2和y1根据其自身梯度以及关联的backward方法, 分别计算出其对应的自变量x的梯度,x.grad将其收到的多个梯度值累加。

(注意,1,2,3步骤的求梯度顺序和对多个梯度值的累加规则恰好是求导链式法则的程序表述)

正因为求导链式法则衍生的梯度累加规则,张量的grad梯度不会自动清零,在需要的时候需要手动置零。

四,叶子节点和非叶子节点

执行下面代码,我们会发现 loss.grad并不是我们期望的1,而是 None。

类似地 y1.grad 以及 y2.grad也是 None.

这是为什么呢?这是由于它们不是叶子节点张量。

在反向传播过程中,只有 is_leaf=True 的叶子节点,需要求导的张量的导数结果才会被最后保留下来。

那么什么是叶子节点张量呢?叶子节点张量需要满足两个条件。

1,叶子节点张量是由用户直接创建的张量,而非由某个Function通过计算得到的张量。

2,叶子节点张量的 requires_grad属性必须为True.

Pytorch设计这样的规则主要是为了节约内存或者显存空间,因为几乎所有的时候,用户只会关心他自己直接创建的张量的梯度。

所有依赖于叶子节点张量的张量, 其requires_grad 属性必定是True的,但其梯度值只在计算过程中被用到,不会最终存储到grad属性中。

如果需要保留中间计算结果的梯度到grad属性中,可以使用 retain_grad方法。
如果仅仅是为了调试代码查看梯度值,可以利用register_hook打印日志。

python">import torch x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2loss.backward()
print("loss.grad:", loss.grad)
print("y1.grad:", y1.grad)
print("y2.grad:", y2.grad)
print(x.grad)
loss.grad: None
y1.grad: None
y2.grad: None
tensor(4.)
print(x.is_leaf)
print(y1.is_leaf)
print(y2.is_leaf)
print(loss.is_leaf)
True
False
False
False

利用retain_grad可以保留非叶子节点的梯度值,利用register_hook可以查看非叶子节点的梯度值。

python">import torch #正向传播
x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2#非叶子节点梯度显示控制
y1.register_hook(lambda grad: print('y1 grad: ', grad))
y2.register_hook(lambda grad: print('y2 grad: ', grad))
loss.retain_grad()#反向传播
loss.backward()
print("loss.grad:", loss.grad)
print("x.grad:", x.grad)
y2 grad:  tensor(4.)
y1 grad:  tensor(-4.)
loss.grad: tensor(1.)
x.grad: tensor(4.)

五,计算图在TensorBoard中的可视化

可以利用 torch.utils.tensorboard 将计算图导出到 TensorBoard进行可视化。

python">from torch import nn 
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.w = nn.Parameter(torch.randn(2,1))self.b = nn.Parameter(torch.zeros(1,1))def forward(self, x):y = x@self.w + self.breturn ynet = Net()
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('../data/tensorboard')
writer.add_graph(net,input_to_model = torch.rand(10,2))
writer.close()
%load_ext tensorboard
#%tensorboard --logdir ../data/tensorboard
from tensorboard import notebook
notebook.list() 
#在tensorboard中查看模型
notebook.start("--logdir ../data/tensorboard")

在这里插入图片描述


Reference:

https://jackiexiao.github.io/eat_pytorch_in_20_days/2.%E6%A0%B8%E5%BF%83%E6%A6%82%E5%BF%B5/2-3%2C%E5%8A%A8%E6%80%81%E8%AE%A1%E7%AE%97%E5%9B%BE/


http://www.ppmy.cn/devtools/115665.html

相关文章

C# System.BadImageFormatException问题及解决

C# System.BadImageFormatException问题 出现System.BadImageFormatException 异常有两种情况&#xff1a;程序目标平台不一致&引用dll文件的系统平台不一致。 异常参考 BadImageFormatException 程序目标平台不一致&#xff1a; 项目>属性>生成&#xff1a;x86 …

CentOS 上配置多服务器 SSH 免密登录

以下是在 CentOS 上配置多服务器 SSH 免密登录的步骤&#xff1a; 一、准备工作 假设有服务器 A 和服务器 B&#xff0c;需要从服务器 A 免密登录到服务器 B。 二、在服务器 A 上生成密钥对 打开终端&#xff0c;执行以下命令生成 SSH 密钥对&#xff1a; ssh-keygen -t rsa一路…

系统架构设计师 大数据架构篇二

大数据架构 &#x1f310; 大数据处理系统分析 &#x1f50d; 大数据处理系统三大挑战 &#x1f680; 非结构化数据处理&#xff1a;如何处理非结构化和半结构化数据。复杂性与不确定性&#xff1a;大数据复杂性、不确定性特征描述的刻画方法和大数据的系统建模。异构性影响…

RTMP协议在无人机巡检中的应用场景

为什么要用无人机巡检 好多开发者对无人机巡检技术方案&#xff0c;相对陌生&#xff0c;实际上&#xff0c;无人机巡检就是利用无人机对特定区域或设施进行定期或不定期的检查。这种巡检方式相比传统的人工巡检具有显著的优势&#xff0c;包括速度快、覆盖广、风险低、准确性…

git学习【完结】

git学习【完结】 文章目录 git学习【完结】一、Git基本操作1.创建本地仓库2.配置本地仓库1.局部配置2.全局配置 3.认识工作区、暂存区、版本库4.添加文件5.修改文件6.版本回退7.撤销修改8.删除文件 二、Git分支管理1.理解分支2.创建、切换、合并分支3.删除分支4.合并冲突5.合并…

【排序算法】插入排序_直接插入排序、希尔排序

文章目录 直接插入排序直接插入排序的基本思想直接插入排序的过程插入排序算法的C代码举例分析插入排序的复杂度分析插入排序的优点 希尔排序希尔排序&#xff08;Shell Sort&#xff09;详解希尔排序的步骤&#xff1a;希尔排序的过程示例&#xff1a;希尔排序的C语言实现举例…

Vue介绍、窗体内操作、窗体间操作学习

系列文章目录 第一章 基础知识、数据类型学习 第二章 万年历项目 第三章 代码逻辑训练习题 第四章 方法、数组学习 第五章 图书管理系统项目 第六章 面向对象编程&#xff1a;封装、继承、多态学习 第七章 封装继承多态习题 第八章 常用类、包装类、异常处理机制学习 第九章 集…

vue 中属性值上变量和字符串怎么拼接

在Vue 3中&#xff0c;可以使用模板字面量&#xff08;template literals&#xff09;或者表达式绑定&#xff08;directives&#xff09;来实现属性值上变量和字符串的拼接。 例如&#xff0c;假设你有一个变量text和一个字符串hello&#xff0c;你可以这样拼接它们&#xff…