PyTorch 2.0: 新特性与升级指南

ops/2024/11/14 2:54:52/

什么是 PyTorch 2.0?

PyTorch 2.0 是 PyTorch 的最新版本,它保留了之前版本的即时执行模式(eager mode),同时引入了一个全新的编译模式。这个编译模式通过 torch.compile 函数实现,有潜力显著提升模型的训练和推理速度。

为什么是 2.0 而不是 1.14?

PyTorch 团队认为这个版本引入的新特性足以改变用户使用 PyTorch 的方式,因此决定将其命名为 2.0 而不是 1.14。

如何安装 PyTorch 2.0?

你可以通过 pip 安装最新的 nightly 版本。根据你的 CUDA 版本或是否使用 CPU,选择相应的安装命令:

# CUDA 11.8
pip3 install numpy --pre torch torchvision torchaudio --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu118# CUDA 11.7
pip3 install numpy --pre torch torchvision torchaudio --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu117# CPU
pip3 install numpy --pre torch torchvision torchaudio --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cpu

2.0 版本的兼容性如何?

PyTorch 2.0 完全向后兼容 1.x 版本。你无需修改现有的 PyTorch 工作流程。只需添加一行代码 model = torch.compile(model) 就可以优化你的模型以使用 2.0 的新特性。

如何迁移到 PyTorch 2.0?

大多数情况下,你的代码无需任何改动就可以在 PyTorch 2.0 上运行。如果你想使用新的编译模式特性,只需要在你的模型上调用 torch.compile

python">import torchdef train(model, dataloader):model = torch.compile(model)for batch in dataloader:run_epoch(model, batch)def infer(model, input):model = torch.compile(model)return model(**input)

PyTorch 2.0 的工作原理

当你使用 torch.compile(model) 包装你的模型时,模型会经历以下三个步骤:

  1. 图获取:模型被重写为子图块。
  2. 图降低:PyTorch 操作被分解为特定后端的核心操作。
  3. 图编译:核心操作调用相应的低级设备特定操作。

PyTorch 2.0 的新组件

  1. TorchDynamo:从 Python 字节码生成 FX 图。
  2. AOTAutograd:为 TorchDynamo 捕获的前向图生成对应的反向图。
  3. PrimTorch:将复杂的 PyTorch 操作分解为更简单和基本的操作。
  4. 后端:与 TorchDynamo 集成,将图编译为可在加速器上运行的 IR。

分布式训练

在编译模式下,DDP 和 FSDP 可以比即时执行模式快 15%(FP32)到 80%(AMP 精度)。使用 DDP 时,请确保设置 static_graph=False

遇到问题怎么办?

如果你的代码在编译模式下运行变慢或崩溃,很可能是由于图断裂(graph breaks)导致的。你可以参考 PyTorch 官方文档 来诊断和解决这些问题。

PyTorch 2.0 带来了显著的性能提升和新特性,同时保持了与旧版本的兼容性。通过简单的一行代码,你就可以享受到这些优化带来的好处.


http://www.ppmy.cn/ops/133447.html

相关文章

【掌握未来办公:OnlyOffice 8.2深度使用指南与新功能揭秘】

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​💫个人格言:“没有罗马,那就自己创造罗马~” 文章目录 一、功能亮点1. PDF文档的协同编辑2. PDF表单的电子签名3. 界面的现代化改造4. 性能的显著提升5. 文…

每天五分钟深度学习PyTorch:基于全连接神经网络完成手写字体识别

本文重点 上一节我们学习了搭建普通的全连接神经网络,我们现在用它来解决一个实际问题,我们用它跑一下手写字体识别的数据,然后看看它的效果如何。 网络模型 class ThreeNet(nn.Module) : def __init__ (self,in_dim,n_hidden_1,n_hidden_2,out_dim): super(ThreeNet, self…

重构代码之参数化方法

在代码重构中,参数化方法 通过将方法内部的硬编码值替换为参数,使方法的适用性更广。这不仅可以减少重复代码,还能提高代码的灵活性和可维护性。让我们来深入探讨这种技术的应用场景、步骤以及一些例子。 一、适用场景 参数化方法通常适用于…

Spring Security 认证流程,长话简说

一、代码先行 1、设计模式 SpringSecurity 采用的是 责任链 的设计模式,是一堆过滤器链的组合,它有一条很长的过滤器链。 不过我们不需要去仔细了解每一个过滤器的含义和用法,只需要搞定以下几个问题即可:怎么登录、怎么校验账户、认证失败…

泷羽sec学习打卡-Linux基础2

声明 学习视频来自B站UP主 泷羽sec,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都与本人无关,切莫逾越法律红线,否则后果自负 关于Linux的那些事儿-Base2 一、Linux-Base2linux有哪些目录呢?不同目录下有哪些具体的文件呢…

C++ 中的异常处理机制是怎样的?

异常处理的基本概念: 异常: 程序在运行时发生的错误或意外情况。 抛出异常: 使用 throw 关键字将异常传递给调用堆栈。 捕获异常: 使用 try-catch 块捕获和处理异常。 异常类型: 表示异常类别的标识符。 异常处理流程: 抛出异常: 当检测到错误或意…

python制作一个简单的端口扫描器,用于检测目标主机上指定端口的开放状态

import argparse # 用于解析命令行参数 from socket import * # 导入 socket 库的所有内容,用于网络通信 from threading import * # 导入 threading 库的所有内容,用于多线程操作 # 创建一个信号量,初始值为 1,用于线程同步&…

NumPy与TensorFlow-tf.tensor异同点

NumPy数组与TenosrFlow中的张量(即tf.tensor)有很多相似地方,而且可以互相转换。下表总结了NumPy与tf.tensor的异同点。 NumPy与tf.tensor的异同点 操作类别NumPyTensorFlow 2数据类型np.ndarraytf.Tensornp.float32tf.float32np.float64tf…