TensorFlow 与 PyTorch 的直观区别

devtools/2025/2/6 2:37:29/

背景

TensorFlow 与 PyTorch 都是比较流行的深度学习框架。tf 由谷歌在 2015 年发布,而 PyTorch 则是 Facecbook AI 研究团队 2016 年在原来 Torch 的基础上发布的。

tf 采用的是静态计算图。这意味着在执行任何计算之前,你需要先定义好整个计算图,之后再执行。这种方式适合大规模生产环境,可以优化计算图以提高效率。tf 的早期版本比较复杂,但在集成 Keras 库之后相当容易上手。

PyTorch 的设计目标是提供一个易于使用、灵活且高效的框架,所以采用的是动态图,特别适合研究人员和开发人员进行快速实验和原型设计。它强调灵活性和易用性,采用了动态图机制,使得代码更接近于 Python 原生风格,便于调试和修改。PyTorch 使用更加像原来的 Python 代码。

总体来说,TensorFlow 更加容易上手,PyTorch 更加灵活且需要自己操作,例如 tf 提供了训练的方法,而 PyTorch 则需要手动训练:

python"># TensorFlow
model.fit(train_images, train_labels, epochs=5, batch_size=128)

而 PyTorch 需要先手动将数据分批,然后自己编写训练和测试函数,函数详细内容后面会写:

python"># PyTorch
epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

示例

MNIST 数据集

对于两者的示例,仍然使用 MNIST 手写数字集来做演示。MNIST 是 28 * 28 大小的单通道(黑白)手写数字图片,每个像素亮度值为 0 ~ 255。

首先加载数据集:

python"># TensorFlow
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

PyTorch 除了加载数据,还需要定义 DataLoader,因为其提供的框架更加底层,需要自己定义加载器,包括数据打包,转换等更加灵活的功能。

python"># PyTorch
to_tensor = transforms.Compose([transforms.ToTensor()])
training_data = datasets.MNIST(root="data", train=True, download=True, transform=to_tensor)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=to_tensor)train_dataloader = DataLoader(training_data, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=128, shuffle=True)

定义神经网络

TensorFlow 集成了 Keras,这里可以看见对神经网络的定义非常简洁明了:

python"># TensorFlow
model = keras.Sequential([layers.Flatten(input_shape=(28, 28)),  # Flatten the input image to a vector of size 784layers.Dense(512, activation="relu"),layers.Dense(10, activation="softmax")
])

在 PyTorch 中,更倾向于将神经网络打包成一个类,这个类由框架提供的网络模型继承。

python"># PyTorch
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(in_features=28 * 28, out_features=512)# Output layer with 10 neurous for classificationself.fc2 = nn.Linear(in_features=512, out_features=10)def forward(self, x):x = self.flatten(x) # Flatten the input tensorx = nn.functional.relu(self.fc1(x)) # ReLU activation after first layerx = self.fc2(x)return xprint(model)

PyTorch 可以检查神经网络模型

NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(fc1): Linear(in_features=784, out_features=512, bias=True)(fc2): Linear(in_features=512, out_features=10, bias=True)
)

训练

TensorFlow 在网络模型定义完成后,指定损失函数和优化器,来使模型训练让参数收敛。

python">model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])model.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
469/469 [==============================] - 3s 5ms/step - loss: 5.4884 - accuracy: 0.8992
Epoch 2/5
469/469 [==============================] - 2s 4ms/step - loss: 0.6828 - accuracy: 0.9538
Epoch 3/5
469/469 [==============================] - 2s 4ms/step - loss: 0.4634 - accuracy: 0.9662
Epoch 4/5
469/469 [==============================] - 2s 4ms/step - loss: 0.3742 - accuracy: 0.9730
Epoch 5/5
469/469 [==============================] - 2s 4ms/step - loss: 0.2930 - accuracy: 0.9774

而在 PyTorch 中则更加复杂,需要自己定义训练函数和测试函数,并不断训练,框架只提供了一些基础的训练所需函数:

python"># Define loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001)def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {100*(correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.319496  [  128/60000]
loss: 0.443893  [12928/60000]
loss: 0.253097  [25728/60000]
loss: 0.106967  [38528/60000]
loss: 0.208099  [51328/60000]
Test Error: Accuracy: 95.9%, Avg loss: 0.136413Epoch 2
-------------------------------
loss: 0.102781  [  128/60000]
loss: 0.089506  [12928/60000]
loss: 0.177988  [25728/60000]
loss: 0.058250  [38528/60000]
loss: 0.131542  [51328/60000]
Test Error: Accuracy: 97.3%, Avg loss: 0.087681Epoch 3
-------------------------------
loss: 0.100185  [  128/60000]
loss: 0.021117  [12928/60000]
loss: 0.058108  [25728/60000]
loss: 0.070415  [38528/60000]
loss: 0.050509  [51328/60000]
Test Error: Accuracy: 97.7%, Avg loss: 0.075040Epoch 4
-------------------------------
loss: 0.051223  [  128/60000]
loss: 0.049627  [12928/60000]
loss: 0.025712  [25728/60000]
loss: 0.090960  [38528/60000]
loss: 0.046523  [51328/60000]
Test Error: Accuracy: 97.9%, Avg loss: 0.066997Epoch 5
-------------------------------
loss: 0.012129  [  128/60000]
loss: 0.019118  [12928/60000]
loss: 0.057839  [25728/60000]
loss: 0.031959  [38528/60000]
loss: 0.020570  [51328/60000]
Test Error: Accuracy: 98.0%, Avg loss: 0.062022Done!

http://www.ppmy.cn/devtools/156426.html

相关文章

CSS 图像、媒体和表单元素的样式化指南

CSS 图像、媒体和表单元素的样式化指南 1. 替换元素:图像和视频1.1 调整图像大小示例代码:调整图像大小 1.2 使用 object-fit 控制图像显示示例代码:使用 object-fit 2. 布局中的替换元素示例代码:Grid 布局中的图像 3. 表单元素的…

Android-音频采集

前言 音视频这块,首先是要先采集音频。今天我们就来深入探讨一下 Android 音频采集的两大类型:Mic 音频采集和系统音频采集。 Mic音频采集 Android提供了两个API用于实现录音功能:android.media.AudioRecord、android.media.MediaRecorder。…

LeetCode --- 434周赛

目录 3432. 统计元素和差值为偶数的分区方案 3433. 统计用户被提及情况 3434. 子数组操作后的最大频率 3435. 最短公共超序列的字母出现频率 一、统计元素和差值为偶数的分区方案 本题可以直接模拟,当然我们也可以来从数学的角度来分析一下这题的本质 设 S S S …

CF 761A.Dasha and Stairs(Java实现)

题目分析 大概意思是输入偶数值奇数值,判断是否能够凑成一连串数字 思路分析 能够连成一串数字的条件考虑:1.偶数与奇数差为1;2.偶数与奇数相等,且不为0 代码 import java.util.*;public class Main {public static void…

「AI学习笔记」深度学习进化史:从神经网络到“黑箱技术”(三)

在这篇文章中,我们将探讨深度学习(DL)这一领域的最新发展,以及它如何从传统机器学习(ML)中独立出来,成为一个独立的生态系统。深度学习的核心思想与我们大脑中的神经网络高度相似,因…

关于图像锐化的一份介绍

在这篇文章中,我将介绍有关图像锐化有关的知识,具体包括锐化的简单介绍、一阶锐化与二阶锐化等方面内容。 一、锐化 1.1 概念 锐化(sharpening)就是指将图象中灰度差增大的方法,一次来增强物体的轮廓与边缘。因为发…

【JavaScript】Web API事件流、事件委托

目录 1.事件流 1.1 事件流和两个阶段说明 1.2 事件捕获 1.3 事件冒泡 1.4 阻止冒泡 1.5 解绑事件 L0 事件解绑 L2 事件解绑 鼠标经过事件的区别 两种注册事件的区别 2.事件委托 案例 tab栏切换改造 3.其他事件 3.1 页面加载事件 3.2 页面滚动事件 3.2 页面滚…

Go学习:类型转换需注意的点 以及 类型别名

目录 1. 类型转换 2. 类型别名 1. 类型转换 在从前的学习中,知道布尔bool类型变量只有两种值true或false,C/C、Python、JAVA等编程语言中,如果将布尔类型bool变量转换为整型int变量,通常采用 “0为假,非0为真”的方…