《Python实战进阶》第33集:PyTorch 入门-动态计算图的优势

ops/2025/3/29 4:55:02/

第33集:PyTorch 入门-动态计算图的优势


摘要

PyTorch 是一个灵活且强大的深度学习框架,其核心特性是动态计算图机制。本集将带您探索 PyTorch 的张量操作、自动求导系统以及动态计算图的特点与优势,并通过实战案例演示如何使用 PyTorch 实现线性回归和构建简单的图像分类模型。我们将重点突出 PyTorch 在研究与开发中的灵活性及其在 AI 大模型训练中的应用。
在这里插入图片描述


核心概念和知识点

1. 张量操作与自动求导

  • 张量(Tensor):类似于 NumPy 数组,但支持 GPU 加速。
  • 自动求导(Autograd):PyTorch 提供了自动微分功能,能够高效计算梯度,用于优化模型参数。

2. 动态计算图的特点与优势

  • 动态计算图:PyTorch 的计算图是在运行时动态构建的,支持即时调试和修改。
  • 灵活性:适合实验性研究,便于实现复杂的模型架构。
  • 直观性:代码执行过程清晰可见,易于理解。

3. 自定义模型与训练循环

  • 模型定义:通过继承 torch.nn.Module 自定义模型结构。
  • 训练循环:手动实现前向传播、损失计算和反向传播,提供更细粒度的控制。

4. AI 大模型相关性分析

PyTorch 是目前主流的 AI 大模型框架之一,广泛应用于 GPT、BERT 等模型的训练:

  • 分布式训练支持:通过 torch.distributed 模块实现多 GPU 和多节点训练。
  • 生态系统丰富:结合 Hugging Face Transformers 等库,可快速搭建和训练大模型。

实战案例

案例 1:使用 PyTorch 实现线性回归

背景

线性回归是最基础的机器学习任务之一,我们使用 PyTorch 实现一个简单的线性回归模型。

代码实现
python">import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 数据生成
torch.manual_seed(42)
x = torch.linspace(-1, 1, 100).reshape(-1, 1)  # 输入特征
y = 3 * x + 2 + 0.2 * torch.randn(x.size())   # 带噪声的目标值# 定义模型
class LinearRegressionModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(1, 1)  # 单输入单输出的线性层def forward(self, x):return self.linear(x)model = LinearRegressionModel()# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)# 训练模型
epochs = 100
for epoch in range(epochs):# 前向传播y_pred = model(x)loss = criterion(y_pred, y)# 反向传播与优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")# 可视化结果
predicted = model(x).detach().numpy()
plt.scatter(x.numpy(), y.numpy(), label="Original Data", alpha=0.6)
plt.plot(x.numpy(), predicted, 'r', label="Fitted Line")
plt.legend()
plt.title("Linear Regression with PyTorch")
plt.show()
输出结果
Epoch 10/100, Loss: 0.0431
...
Epoch 100/100, Loss: 0.0012
可视化

案例 2:构建一个简单的图像分类模型

背景

我们使用 CIFAR-10 数据集,构建一个简单的卷积神经网络(CNN)进行图像分类。

代码实现
python">import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)# 定义 CNN 模型
class SimpleCNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16 * 16 * 16, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = x.view(-1, 16 * 16 * 16)x = self.fc1(x)return xmodel = SimpleCNN()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
for epoch in range(5):  # 仅训练 5 个 epochrunning_loss = 0.0for i, data in enumerate(trainloader):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:  # 每 100 个 batch 打印一次损失print(f"[{epoch+1}, {i+1}] Loss: {running_loss / 100:.3f}")running_loss = 0.0print("Finished Training")# 测试模型
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy on Test Set: {100 * correct / total:.2f}%")
输出结果
[1, 100] Loss: 2.123
...
Accuracy on Test Set: 55.25%

总结

PyTorch 的动态计算图机制使其成为深度学习研究与开发的理想工具。通过本集的学习,我们掌握了如何使用 PyTorch 实现线性回归和构建简单的图像分类模型,并了解了其在灵活性和实验性方面的优势。


扩展思考

1. PyTorch 在 AI 大模型训练中的应用

PyTorch 是训练 GPT、BERT 等大模型的核心工具之一。其动态计算图机制使得研究人员能够快速迭代模型架构,而分布式训练支持则确保了大模型的高效训练。

2. PyTorch Lightning 的简化功能

PyTorch Lightning 是一个高级接口,旨在简化 PyTorch 的使用。它隐藏了训练循环的复杂性,同时保留了底层灵活性,特别适合大规模实验和生产环境。


专栏链接:Python实战进阶
下期预告:No34 - 使用 Pandas 高效处理时间序列数据


http://www.ppmy.cn/ops/169937.html

相关文章

鸿蒙开发:父组件如何调用子组件中的方法?

前言 本文基于Api13 很多的场景下,父组件需要触发子组件中的某个方法,来实现一些特定的逻辑,但是ArkUI是声明式UI,不能直接调用子组件中的方法,那么怎么去实现这个功能呢? 举一个很常见的案例,通…

vscode 插件推荐

1、中文化插件 Chinese (Simplified) (简体中文) 2、中文标点符号转英文 中文标点符号转英文 3、标签补全 Auto Close Tag 4、git仓库信息查看 GitLens — Git supercharged 5、随机/顺序数据生成 Insert Sequences 6、html项目本地运行 Live Server 7、代码格式化 7.1、…

VMware打开ubuntu正在使用中怎么解决

1.如图1所示,打开ubuntu,出现该虚拟机正在使用中的情况; 图1 2.如图2所示,找到ubuntu文件夹下.lck的文件夹,删除它们即可; 图2 3.如图3所示,打开虚拟机正常,可以启动。 图3

物联网为什么用MQTT不用 HTTP 或 UDP?

先来两个代码对比,上传温度数据给服务器。 MQTT代码示例 // MQTT 客户端连接到 MQTT 服务器 mqttClient.connect("mqtt://broker.server.com:8883", clientId) // 订阅特定主题 mqttClient.subscribe("sensor/data", qos1) // …

万字C++STL——vector模拟实现

模拟实现总览 namespace wlw {//命名空间为了让其隔离//模拟实现vectortemplate<class T>class vector{public:typedef T* iterator;typedef const T* const_iterator;//默认成员函数vector(); //构造函数vector(size_t n, c…

【从零实现Json-Rpc框架】- 第三方库介绍 - Muduo篇

&#x1f4e2;博客主页&#xff1a;https://blog.csdn.net/2301_779549673 &#x1f4e2;博客仓库&#xff1a;https://gitee.com/JohnKingW/linux_test/tree/master/lesson &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01; &…

贪心算法——思路与例题

贪心算法&#xff1a;当我们分析一个问题时&#xff0c;我们往往先以最优的方式来解决问题&#xff0c;所以顾名思义为贪心。 例题1 题目分析&#xff1a;这题利用贪心算法来分析&#xff0c;最优解&#xff08;可容纳人数最多时&#xff09;一定是先考虑六人桌&#xff0c;然…

WELL健康建筑认证是什么?

**WELL健康建筑认证&#xff1a;全方位呵护居住者福祉的权威标准** WELL健康建筑认证&#xff0c;这一源自美国的全球性健康建筑标准&#xff0c;宛如建筑界的璀璨明珠&#xff0c;以其独特的光芒照亮了健康建筑的发展之路。它不仅是全球首部专门针对室内环境提升人体健康与福…