torch.utils.checkpoint.checkpoint介绍

devtools/2024/10/24 20:05:50/

torch.utils.checkpoint.checkpoint 是 PyTorch 提供的一种内存优化工具,用于在计算图的反向传播过程中节省显存。它通过重新计算某些前向传播的部分,减少了保存中间激活值所需的显存,特别适用于深度模型,如 Transformer 等层数较多的网络。

主要原理

在标准的反向传播中,前向传播过程中每一层的中间激活值(activation)会被保留,供后续反向传播使用。但在使用 checkpoint 时,某些层的激活值不会在前向传播时保存,而是在反向传播时通过重新计算这些层的前向结果来获得。

这样可以节省大量内存,但代价是增加了计算量,因为反向传播时需要重新计算部分前向传播。

使用方法

示例代码1 

import torch
from torch.utils.checkpoint import checkpointdef forward_pass(x):# 假设是模型的一部分return x * x + 2 * x + 1# 在调用时使用 checkpoint 包裹住前向传播的函数
input_tensor = torch.randn(3, requires_grad=True)
output = checkpoint(forward_pass, input_tensor)
output.backward()print(input_tensor.grad)  # 查看梯度

示例代码2

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint# 假设我们有一个网络的部分前向传播如下
class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.linear1 = nn.Linear(100, 100)self.linear2 = nn.Linear(100, 100)def forward(self, x):# 使用 checkpoint 来节省中间激活值的显存x = checkpoint(self.linear1, x)  # 不保存 linear1 的激活值x = self.linear2(x)  # 保存 linear2 的激活值return xmodel = MyModule()
input_tensor = torch.randn(10, 100, requires_grad=True)
output = model(input_tensor)
output.sum().backward()

示例代码3

#EncLayer 为一个类 class EncLayer(nn.Module)# 下面是另一个类中的部分代码      
self.encoder_layers = nn.ModuleList([EncLayer(hidden_dim, hidden_dim*2, dropout=dropout)for _ in range(num_encoder_layers)])for layer in self.encoder_layers:h_V, h_E = torch.utils.checkpoint.checkpoint(layer, h_V, h_E, E_idx, mask, mask_attend)

torch.utils.checkpoint.checkpoint 参数

  • function: 被 checkpoint 包裹的前向传播函数。
  • *args: 前向传播函数的参数。这些参数必须支持 requires_grad=True,因为 checkpoint 需要在反向传播时重新计算前向传播。

优点

  • 显存节省:对于较大的模型(如超过数百层的网络),可以通过这种方式大幅减少显存占用,特别适合在有限的 GPU 内存下训练深度模型。

缺点

  • 计算开销增加:因为在反向传播时重新计算部分前向传播,增加了计算时间。

适用场景

torch.utils.checkpoint.checkpoint 特别适用于:

  • 深层网络,如 Transformer 或 ResNet 这样有很多堆叠层的模型。
  • 需要在有限显存的设备上训练大型模型时。

http://www.ppmy.cn/devtools/128504.html

相关文章

(A-D)AtCoder Beginner Contest 376

目录 比赛链接: A - Candy Button 题目链接: 题目描述: 数据范围: 输入样例: 输出样例: 样例解释: 分析: 代码: B - Hands on Ring (Easy) 题目链接&#xff1…

代码随想录算法训练营第三十七天|509. 斐波那契数,70. 爬楼梯,746. 使用最小花费爬楼梯

509. 斐波那契数,70. 爬楼梯,746. 使用最小花费爬楼梯 509. 斐波那契数70. 爬楼梯746. 使用最小花费爬楼梯 509. 斐波那契数 斐波那契数 (通常用 F(n) 表示)形成的序列称为 斐波那契数列 。该数列由 0 和 1 开始,后面…

智慧楼宇平台,构筑未来智慧城市的基石

随着城市化进程的加速,城市面临着前所未有的挑战。人口密度的增加、资源的紧张、环境的恶化以及对高效能源管理的需求,都在推动着我们寻找更加智能、可持续的城市解决方案。智慧楼宇作为智慧城市建设的重要组成部分,正逐渐成为推动城市可持续…

「AIGC」AI设计工具 v0.dev

https://v0.dev/ 1.1 简介 $20 学习前端代码本身似乎并不复杂,但是有平台能够直接生成代码、预览效果,似乎更有性价比。可能对于有前端和开发经验的同学而言,直接实现某个页面效果并不算是太复杂的事情,但是对于没有代码经验的同学而言,直接使用 AI 跑出代码甚至直接落地…

Chainlit集成LlamaIndex和Chromadb实现RAG增强生成对话AI应用

前言 本文主要讲解如何使用LlamaIndex和Chromadb向量数据库实现RAG应用,并使用Chainlit快速搭建一个前端对话网页,实现RAG聊天问答增强的应用。文章中还讲解了LlamaIndex 的CallbackManager回调,实现案例是使用TokenCountingHandler&#xf…

云渲染分布式渲染什么意思?一文详解

渲染和分布式渲染是现代计算机图形学中的重要技术,它们通过将渲染任务分散到多个服务器或计算节点上,显著提高了渲染效率和处理大规模数据的能力。这项技术在动画制作、游戏开发和电影特效等领域发挥着关键作用,为创作者提供了更快速、更灵活…

SolarWinds Web Help Desk曝出严重漏洞,已遭攻击者利用

近日,CISA 在其 “已知漏洞”(KEV)目录中增加了三个漏洞,其中一个是 SolarWinds Web Help Desk (WHD) 中的关键硬编码凭据漏洞,供应商已于 2024 年 8 月底修复了该漏洞。 SolarWinds Web Help Desk 是一款 IT 服务台套…

React前端框架高级技巧

在当今快速发展的前端开发世界中,React依然保持着强大的生命力和广泛的应用。无论你是React新手还是经验丰富的开发者,掌握一些高级技巧都能极大地提升你的开发效率。本文将为你揭示5个鲜为人知但非常实用的React技巧,让你的代码更加简洁、高效、易维护。 1. 使用React.memo()…