PyTorch 中的 torch.Tensor
梯度详解
PyTorch 是一个广泛使用的深度学习框架,它以其动态计算图和强大的自动微分(Autograd)机制而闻名。在训练神经网络时,梯度计算是反向传播算法的核心。
目录
- Tensor 与
requires_grad
属性 - 动态计算图(Computational Graph)
- 反向传播与梯度计算
- 梯度的计算过程
- 非标量输出的梯度计算
- 控制梯度计算
- 梯度的累积
- 梯度的清零
- 高阶求导
- 总结
1. Tensor 与 requires_grad
属性
在 PyTorch 中,torch.Tensor
是存储和操作数据的基本单元。每个张量都有一个属性 requires_grad
,它决定了在对该张量进行操作时,PyTorch 是否需要为其计算梯度。
- 默认情况下,
requires_grad=False
,意味着对该张量的操作不被跟踪,不会计算梯度。 - 如果需要对张量进行梯度计算,需要将其设为
True
。
示例:
python">import torch# 创建一个 requires_grad 为 True 的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
2. 动态计算图(Computational Graph)
当对 requires_grad=True
的张量进行操作时,PyTorch 会构建一个 动态计算图,记录每个操作,从而能够在反向传播时自动计算梯度。
- 节点(Node):表示张量(Tensor)。
- 边(Edge):表示从一个张量到另一个张量的操作(Function)。
- 叶子节点(Leaf Tensor):
requires_grad=True
且不是通过其他运算得到的张量,通常是模型的参数或输入数据。
示例代码:
python">import torch# 创建一个叶子张量 x,启用梯度计算
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义计算过程
y = x * 2 # 对 x 进行乘法操作
z = y.mean() # 对 y 求均值,得到标量 z
在上述代码中:
x
是叶子张量,因为它是用户直接创建的,且requires_grad=True
。y
和z
则是通过对张量的操作得到的中间结果。
计算图结构:
为了更直观地展示计算过程,下面通过 Mermaid 图绘制出计算图的结构:
解释:
-
节点:
x
:叶子节点,代表输入张量。y
:中间节点,y = x * 2
的结果。z
:输出节点,z = y.mean()
的结果,标量值。
-
边:
x
到y
的边,标注为 乘以 2,表示对x
进行乘法操作得到y
。y
到z
的边,标注为 求均值,表示对y
求均值得到z
。
计算图详解:
-
前向传播(Forward Pass):
- Step 1:对
x
进行乘法运算,得到y
。y = x * 2
- Step 2:对
y
求均值,得到标量z
。z = y.mean()
- Step 1:对
-
反向传播(Backward Pass):
- 当调用
z.backward()
时,PyTorch 会自动计算梯度,步骤如下:- 计算 ∂z/∂y:
z
关于y
的梯度。 - 计算 ∂y/∂x:
y
关于x
的梯度。 - 计算 ∂z/∂x = ∂z/∂y * ∂y/∂x:链式法则。
- 计算 ∂z/∂y:
- 当调用
-
梯度存储:
- 计算得到的梯度会存储在叶子张量
x
的.grad
属性中,即x.grad
。
- 计算得到的梯度会存储在叶子张量
计算图的动态性:
- 动态构建:PyTorch 的计算图是动态的,每次前向计算时都会根据代码执行情况实时构建。这意味着代码中的控制流(如循环、条件判断)都能被计算图正确地捕获和处理。
更复杂的计算图示例:
假设我们有如下的计算过程:
python"># 创建叶子张量
a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([6.0, 4.0], requires_grad=True)# 定义计算
c = a * b
d = c + a
e = torch.sum(d)
对应的计算图可视化为:
解释:
a
和b
都是叶子张量,参与计算操作。c
是a
和b
相乘的结果。d
是c
和a
相加的结果。e
是对d
所有元素求和得到的标量。
计算图的特点:
- 有向无环图(DAG):计算图是一个有向无环图,避免了循环依赖。
- 灵活性:支持动态调整,适用于各种复杂的神经网络结构。
- 自动求导:在反向传播时,根据计算图自动计算梯度,简化了手动求导的过程。
3. 反向传播与梯度计算
当计算完成后,我们可以对结果标量调用 .backward()
方法,PyTorch 将自动计算梯度,并将结果存储在每个叶子张量的 .grad
属性中。
示例:
python"># 对 z 进行反向传播,计算梯度
z.backward()# 查看 x 的梯度
print(x.grad) # 输出:tensor([0.6667, 0.6667, 0.6667])
解释:
- 对于标量
z
,z.backward()
将计算z
对所有叶子张量的梯度。 x.grad
存储了∂z/∂x
的值。
4. 梯度的计算过程
让我们深入了解梯度是如何计算的。
假设:
x = [1.0, 2.0, 3.0]
y = x * 2
(元素级别的乘法)z = y.mean()
(所有元素的平均值)
计算过程:
-
前向传播:
y = [2.0, 4.0, 6.0]
z = (2.0 + 4.0 + 6.0) / 3 = 4.0
-
反向传播:
-
计算 ∂z/∂y_i:
z = (y_1 + y_2 + y_3) / 3
- 因此,
∂z/∂y_i = 1/3
-
计算 ∂y_i/∂x_i:
y_i = 2 * x_i
- 因此,
∂y_i/∂x_i = 2
-
链式法则:
∂z/∂x_i = ∂z/∂y_i * ∂y_i/∂x_i = (1/3) * 2 = 2/3
-
所以,
x.grad = [2/3, 2/3, 2/3] ≈ [0.6667, 0.6667, 0.6667]
-
5. 非标量输出的梯度计算
当输出不是标量时,需要在调用 .backward()
时传入一个和输出同形状的张量,指定各个元素的梯度权重。
示例:
python">out = x * 2 # out 是一个张量,而非标量# 定义梯度权重
gradients = torch.tensor([1.0, 1.0, 1.0])# 计算梯度
out.backward(gradients)# 查看 x 的梯度
print(x.grad) # 输出:tensor([2., 2., 2.])
解释:
- 由于
out
不是标量,直接调用out.backward()
会报错。 - 我们需要传入一个与
out
形状相同的张量gradients
,代表各个元素对标量的梯度。
6. 控制梯度计算
在某些情况下,我们可能不希望自动求导机制对计算进行跟踪,可以使用以下方法:
a. with torch.no_grad()
禁用自动求导,用于评估模型或在推理过程中防止计算图的构建。
示例:
python">with torch.no_grad():y = x * 2 # 不会被追踪
b. detach()
创建一个新的张量,从计算图中分离出来。
示例:
python">y = x.detach() # y 不会有梯度信息
7. 梯度的累积
在 PyTorch 中,张量的 .grad
属性在每次调用 .backward()
时,会将计算得到的梯度 累积 到现有的梯度上,而不会自动清零。这意味着,如果对同一个张量多次调用 .backward()
,其 .grad
中的值会不断增加。
7.1 梯度累积的原理
- 累积机制:每次调用
.backward()
,计算得到的梯度都会加到.grad
属性中,而不是替换掉原有的值。 - 用途:梯度累积在需要模拟更大批量(Batch Size)的训练,或者在分布式训练中,经常被用到。
示例:
python">import torch# 定义一个启用梯度的张量
x = torch.tensor(1.0, requires_grad=True)# 第一次前向和反向传播
y1 = x * 2
y1.backward()
print("第一次梯度计算,x.grad =", x.grad) # 输出:tensor(2.)# 第二次前向和反向传播
y2 = x * 3
y2.backward()
print("第二次梯度计算,x.grad =", x.grad) # 输出:tensor(5.) ,梯度被累积了(2. + 3.)
解释:
- 第一次反向传播,
x.grad
中存储了∂y1/∂x = 2
。 - 第二次反向传播,
∂y2/∂x = 3
,累积到之前的梯度上,因此x.grad = 2 + 3 = 5
。
7.2 梯度累积的应用
在实际训练中,如果显存不足以支持更大的批量大小,可以通过梯度累积的方法,在多次小批量计算后,手动进行一次参数更新。
示例:
python">import torch
from torch import nn, optimmodel = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()# 假设我们希望累积 4 个小批量的梯度
accumulation_steps = 4for epoch in range(num_epochs):optimizer.zero_grad() # 清零梯度for i, (inputs, targets) in enumerate(dataloader):outputs = model(inputs)loss = criterion(outputs, targets)loss = loss / accumulation_steps # 梯度平均loss.backward()if (i + 1) % accumulation_steps == 0:optimizer.step() # 更新参数optimizer.zero_grad() # 清零梯度
解释:
- 每次循环中,调用
loss.backward()
,梯度会累积到模型参数的.grad
属性中。 - 使用
loss = loss / accumulation_steps
对损失进行缩放,保持梯度的规模一致。 - 当累积了指定次数后,进行一次参数更新,并清零梯度。
- 这种方式等价于使用更大的批量大小进行训练。
8. 梯度的清零
由于梯度在反向传播中会累积,为防止前面的梯度对当前计算产生影响,在每次优化步骤之前,需要手动将梯度清零。
8.1 清零梯度的方法
optimizer.zero_grad()
:针对使用优化器的情况,清零所有待优化参数的梯度。x.grad.zero_()
:直接对张量的.grad
属性进行清零,适用于单个张量。
示例:
python">import torchx = torch.tensor(1.0, requires_grad=True)
optimizer = torch.optim.SGD([x], lr=0.1)for epoch in range(2):optimizer.zero_grad() # 清零梯度y = x * 2y.backward()print(f"Epoch {epoch}, x.grad = {x.grad}")optimizer.step()
输出:
Epoch 0, x.grad = tensor(2.)
Epoch 1, x.grad = tensor(2.)
解释:
- 每次循环开始时,调用
optimizer.zero_grad()
,将x.grad
清零。 - 然后计算梯度并更新参数。由于梯度被清零,每次看到的
x.grad
都是当前计算的梯度,没有累积。
8.2 注意事项
- 何时清零梯度:通常在每个优化步骤之前清零梯度,即在调用
backward()
计算梯度之前。 - 防止梯度干扰:如果不清零梯度,之前计算的梯度会影响当前的梯度更新,导致模型训练出现问题。
9. 高阶求导
PyTorch 支持高阶求导,即计算梯度的梯度(例如 Hessian 矩阵或更高阶导数)。在某些优化算法(如二阶优化、元学习)或需要深入分析梯度的情况中,高阶求导非常有用。
9.1 背景知识
在数学中,高阶导数是对函数的导数再求导。例如:
- 一阶导数(First-order Derivative):
dy/dx
- 二阶导数(Second-order Derivative):
d²y/dx²
- 三阶导数(Third-order Derivative):
d³y/dx³
9.2 如何计算高阶导数
为了在 PyTorch 中计算高阶导数,需要在反向传播时设置 create_graph=True
,以确保在计算梯度时保留计算图,从而允许对梯度再次求导。另外,使用 retain_graph=True
可以在多次反向传播中保留计算图。
9.3 示例代码
python">import torch# 创建张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)# 定义函数 y = x^3
y = x ** 3# 一阶导数 dy/dx
grad_outputs = torch.ones_like(x)
y.backward(gradient=grad_outputs, create_graph=True)# x.grad 此时存储了一阶导数 dy/dx
print("一阶导数 dy/dx:", x.grad) # 输出:tensor([ 3., 12., 27.])# 为了计算高阶导数,保留一阶导数的计算图
first_derivative = x.grad.clone()# 清零 x.grad,以防止累积
x.grad.zero_()# 二阶导数 d²y/dx²
first_derivative.backward(torch.ones_like(x), create_graph=True)# x.grad 此时存储了二阶导数 d²y/dx²
print("二阶导数 d²y/dx²:", x.grad) # 输出:tensor([ 6., 12., 18.])# 为了计算三阶导数,保留二阶导数的计算图
second_derivative = x.grad.clone()# 清零 x.grad
x.grad.zero_()# 三阶导数 d³y/dx³
second_derivative.backward(torch.ones_like(x))# x.grad 此时存储了三阶导数 d³y/dx³
print("三阶导数 d³y/dx³:", x.grad) # 输出:tensor([ 6., 12., 18.])
解释:
-
计算一阶导数:
- 对
y
关于x
求导,dy/dx = 3x^2
。 - 使用
create_graph=True
,以便在计算一阶导数时保留计算图,允许进一步求导。
- 对
-
计算二阶导数:
- 对一阶导数再关于
x
求导,d²y/dx² = 6x
。 - 再次使用
create_graph=True
,保留计算图。
- 对一阶导数再关于
-
计算三阶导数:
- 对二阶导数再关于
x
求导,d³y/dx³ = 6
。 - 由于三阶导数是常数,进一步求导将得到零。
- 对二阶导数再关于
9.4 注意事项
-
create_graph=True
:在调用backward()
或torch.autograd.grad()
时,设置create_graph=True
,以便在计算梯度时构建计算图,允许对梯度再次求导。 -
retain_graph=True
:在需要多次反向传播时,设置retain_graph=True
,以防止计算图被释放。 -
内存管理:由于在高阶求导中需要保留计算图,可能会导致显存占用增加。应注意显存的使用,避免内存泄漏。
9.5 torch.autograd.grad()
的使用
torch.autograd.grad()
函数可以更灵活地计算梯度,适用于高阶求导。
示例:
python">import torchx = torch.tensor(2.0, requires_grad=True)# 定义函数 y = x^3
y = x ** 3# 计算一阶导数 dy/dx
grad_y = torch.autograd.grad(outputs=y, inputs=x, create_graph=True)[0]
print("一阶导数 dy/dx:", grad_y) # 输出:tensor(12., grad_fn=<MulBackward0>)# 计算二阶导数 d²y/dx²
grad2_y = torch.autograd.grad(outputs=grad_y, inputs=x, create_graph=True)[0]
print("二阶导数 d²y/dx²:", grad2_y) # 输出:tensor(12.)# 计算三阶导数 d³y/dx³
grad3_y = torch.autograd.grad(outputs=grad2_y, inputs=x)[0]
print("三阶导数 d³y/dx³:", grad3_y) # 输出:tensor(6.)
解释:
torch.autograd.grad
:返回计算结果而不影响x.grad
,适合在需要手动控制梯度时使用。
9.6 高阶导数的应用场景
- 元学习(Meta-learning):在学习如何学习的过程中,需要计算模型参数对超参数的导数。
- 优化算法:如牛顿法等二阶优化方法,涉及到 Hessian 矩阵或二阶导数。
- 正则化技术:如梯度惩罚,需要计算梯度的范数或进一步求导。
9.7 防止内存泄漏
- 释放计算图:在不需要继续求导时,应确保不再保留计算图,以释放内存。
- 避免不必要的
create_graph
:仅在需要高阶导数时使用create_graph=True
。
示例:
python"># 计算完成后,手动删除变量,释放内存
del grad_y, grad2_y, grad3_y
10. 总结
-
requires_grad
属性:控制张量是否参与梯度计算。 -
动态计算图:PyTorch 自动构建和维护的计算图,用于自动求导。
-
反向传播:通过
.backward()
方法,自动计算梯度。 -
梯度存储:计算的梯度存储在叶子张量的
.grad
属性中。 -
非标量输出梯度:在输出为张量时,需要在
.backward()
中传入grad_tensors
。 -
控制梯度计算:使用
torch.no_grad()
和detach()
控制计算图的跟踪。 -
梯度的累积:默认情况下,梯度会累积,需要特别注意这一点。
-
梯度清零:在多次反向传播过程中,需要手动清零梯度以防止累积影响。
-
高阶求导:PyTorch 支持高阶导数的计算,需注意
retain_graph
和create_graph
参数的使用,并妥善管理计算图以防止内存泄漏。