torch.utils.checkpoint.checkpoint
是 PyTorch 提供的一种内存优化工具,用于在计算图的反向传播过程中节省显存。它通过重新计算某些前向传播的部分,减少了保存中间激活值所需的显存,特别适用于深度模型,如 Transformer 等层数较多的网络。
主要原理
在标准的反向传播中,前向传播过程中每一层的中间激活值(activation)会被保留,供后续反向传播使用。但在使用 checkpoint
时,某些层的激活值不会在前向传播时保存,而是在反向传播时通过重新计算这些层的前向结果来获得。
这样可以节省大量内存,但代价是增加了计算量,因为反向传播时需要重新计算部分前向传播。
使用方法
示例代码1
import torch
from torch.utils.checkpoint import checkpointdef forward_pass(x):# 假设是模型的一部分return x * x + 2 * x + 1# 在调用时使用 checkpoint 包裹住前向传播的函数
input_tensor = torch.randn(3, requires_grad=True)
output = checkpoint(forward_pass, input_tensor)
output.backward()print(input_tensor.grad) # 查看梯度
示例代码2
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint# 假设我们有一个网络的部分前向传播如下
class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.linear1 = nn.Linear(100, 100)self.linear2 = nn.Linear(100, 100)def forward(self, x):# 使用 checkpoint 来节省中间激活值的显存x = checkpoint(self.linear1, x) # 不保存 linear1 的激活值x = self.linear2(x) # 保存 linear2 的激活值return xmodel = MyModule()
input_tensor = torch.randn(10, 100, requires_grad=True)
output = model(input_tensor)
output.sum().backward()
示例代码3
#EncLayer 为一个类 class EncLayer(nn.Module)# 下面是另一个类中的部分代码
self.encoder_layers = nn.ModuleList([EncLayer(hidden_dim, hidden_dim*2, dropout=dropout)for _ in range(num_encoder_layers)])for layer in self.encoder_layers:h_V, h_E = torch.utils.checkpoint.checkpoint(layer, h_V, h_E, E_idx, mask, mask_attend)
torch.utils.checkpoint.checkpoint
参数
function
: 被 checkpoint 包裹的前向传播函数。*args
: 前向传播函数的参数。这些参数必须支持requires_grad=True
,因为checkpoint
需要在反向传播时重新计算前向传播。
优点
- 显存节省:对于较大的模型(如超过数百层的网络),可以通过这种方式大幅减少显存占用,特别适合在有限的 GPU 内存下训练深度模型。
缺点
- 计算开销增加:因为在反向传播时重新计算部分前向传播,增加了计算时间。
适用场景
torch.utils.checkpoint.checkpoint
特别适用于:
- 深层网络,如 Transformer 或 ResNet 这样有很多堆叠层的模型。
- 需要在有限显存的设备上训练大型模型时。