PyTorch使用教程(11)-cuda的使用方法

server/2025/1/23 11:48:11/

1. 基本概念

CUDA(Compute Unified Device Architecture)是NVIDIA开发的一种并行计算平台和编程模型,专为图形处理器(GPU)设计,旨在加速科学计算、工程计算和机器学习等领域的高性能计算任务。CUDA允许开发人员使用GPU进行通用计算(也称为GPGPU,General-Purpose computing on Graphics Processing Units)。
在这里插入图片描述

2.Torch与CUDA

Torch是一个流行的深度学习库,由PyTorch开发团队创建,主要用于Python编程环境。当Torch结合CUDA时,它可以显著提升训练深度神经网络的速度。通过将数据和计算转移到GPU上,利用GPU的大量并行核心处理大量矩阵运算,实现对大规模数据集的高效处理。

3. 核心功能

(1)、torch.cuda.device
torch.cuda.device是一个上下文管理器,用于更改所选设备。它允许你在代码块内指定张量或模型应在哪个GPU上创建或执行。

(2)、 torch.cuda.is_available
torch.cuda.is_available()函数用于检查CUDA是否可用。如果系统中安装了NVIDIA的显卡驱动和CUDA工具包,并且PyTorch版本支持CUDA,那么该函数将返回True。

(3)、torch.device
torch.device是一个对象,表示张量可以存放的设备。它可以是CPU或某个GPU。通过指定torch.device(“cuda”),你告诉PyTorch你希望在一个支持CUDA的NVIDIA GPU上执行张量运算。如果有多个GPU,可以通过指定GPU的索引来选择其中一个,例如torch.device(“cuda:0”)表示第一个GPU,torch.device(“cuda:1”)表示第二个GPU,依此类推。

(4)、张量移动
在PyTorch中,你可以使用.to(‘cuda’)或.cuda()函数将张量(Tensor)从CPU移动到GPU。同样,你也可以使用这些方法将模型参数和优化器移动到GPU上。

4.功能示例

(1)、检查CUDA是否可用

python">import torchif torch.cuda.is_available():print("CUDA is available. Number of GPUs:", torch.cuda.device_count())
else:print("CUDA is not available.")

(2)、创建张量并移动到GPU

python">import torch# 在CPU上创建一个张量
x = torch.randn(3, 3)# 检查CUDA是否可用
if torch.cuda.is_available():# 将张量移动到GPUdevice = torch.device("cuda")x_gpu = x.to(device)print(x_gpu)  # 这将显示张量的设备为 "cuda:0"# 直接在GPU上创建另一个张量y = torch.randn(3, 3, device=device)z = x_gpu + y  # 这个加法操作在GPU上执行print(z)

(3)、在不同GPU上创建和操作张量

python">import torch# 在默认GPU上创建一个张量
x = torch.cuda.FloatTensor(1)
print("x.get_device() ==", x.get_device())  # 输出 0# 在GPU 1上创建一个张量
with torch.cuda.device(1):a = torch.cuda.FloatTensor(1)print("a.get_device() ==", a.get_device())  # 输出 1# 将CPU张量转移到GPU 1b = torch.FloatTensor(1).cuda()print("b.get_device() ==", b.get_device())  # 输出 1c = a + bprint("c.get_device() ==", c.get_device())  # 输出 1# 在GPU 0上的张量操作
z = x + x  # 仍然在GPU 0上
print("z.get_device() ==", z.get_device())  # 输出 0# 在特定GPU上创建张量
d = torch.randn(2).cuda(2)
print("d.get_device() ==", d.get_device())  # 输出 2

(4)、将模型和优化器移动到GPU

python">import torch
import torch.nn as nn
import torch.optim as optim# 创建一个简单的神经网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(3, 2)self.fc2 = nn.Linear(2, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xnet = Net()# 检查CUDA是否可用
if torch.cuda.is_available():# 将模型参数和优化器移动到GPUdevice = torch.device("cuda")net = net.to(device)print(net)optimizer = optim.SGD(net.parameters(), lr=0.01)optimizer = optimizer.to(device)  # 注意:优化器通常不需要显式移动到GPU# 创建一些假数据并移动到GPU
inputs = torch.randn(20, 3).to(device)
targets = torch.randint(0, 2, (20,)).to(device)# 定义损失函数
criterion = nn.CrossEntropyLoss()# 训练模型
net.train()
for epoch in range(5):optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

5. 使用注意事项

(1)、GPU内存限制
显卡的内存是有限的,如果模型或数据过大,可能会导致内存不足的问题。可以通过减小批量大小、使用更小的模型或者使用分布式训练等方式来解决。

(2)、数据类型匹配
在使用CUDA加速时,需要确保模型和数据的数据类型匹配。通常情况下,模型和数据都应该使用torch.cuda.FloatTensor类型。

(3)、CUDA版本和驱动兼容性
确保安装了适用于CUDA的PyTorch版本以及相应版本的NVIDIA显卡驱动。不同版本的CUDA和PyTorch之间可能存在兼容性问题。

(4)、避免跨GPU操作
默认情况下,PyTorch不支持跨GPU操作。如果需要对分布在不同设备上的张量进行操作,需要显式地进行数据传输,这可能会引入额外的开销。

(5)、异步数据传输
为了将数据传输与计算重叠,可以使用异步的GPU副本。只需在调用cuda()时传递一个额外的async=True参数。此外,通过将pin_memory=True传递给DataLoader的构造函数,可以使DataLoader将batch返回到固定内存中,从而加快主机到GPU的复制速度。

(6)、多GPU训练
对于多GPU训练,PyTorch提供了nn.DataParallel等工具和函数来简化这一过程。然而,在使用多进程进行CUDA模型训练时需要注意线程安全和资源竞争等问题。

6、小结

torch.cuda是PyTorch中用于在NVIDIA GPU上进行加速计算的重要模块。通过合理利用CUDA的并行计算能力,可以显著提升深度学习模型的训练和推理速度。然而,在使用CUDA时也需要注意一些细节和限制,以确保程序的正确性和性能。通过本文的介绍和示例代码,希望读者能够更好地理解和使用torch.cuda进行深度学习开发。


http://www.ppmy.cn/server/160725.html

相关文章

Kafka中bin目录下面kafka-run-class.sh脚本中的JAVA_HOME

在Kafka中,bin目录下面的kafka-run-class.sh脚本中关于JAVA_HOME的脚本如下: # Which java to use if [ -z "$JAVA_HOME" ]; thenJAVA"java" elseJAVA"$JAVA_HOME/bin/java" fi 这段脚本是关于决定在执行 Kafka 时应该使…

HTML 基础入门:核心标签全解析

在网页开发的世界里,HTML(超文本标记语言)是基石般的存在。它负责构建网页的基本结构,为用户呈现出丰富多样的内容。今天,就让我们一起深入了解 HTML 中几个极为关键的基础标签,开启网页创作的第一步。 一…

[Spring] OpenFeign的使用

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

【大模型】ChatGPT 高效处理图片技巧使用详解

目录 一、前言 二、ChatGPT 4 图片处理介绍 2.1 ChatGPT 4 图片处理概述 2.1.1 图像识别与分类 2.1.2 图像搜索 2.1.3 图像生成 2.1.4 多模态理解 2.1.5 细粒度图像识别 2.1.6 生成式图像任务处理 2.1.7 图像与文本互动 2.2 ChatGPT 4 图片处理应用场景 三、文生图操…

maven常见知识点

1、maven是什么? maven是Java的包管理工具,因为java包太多了,使用工具统一管理。 2、引入同一个包时使用哪个? 会遵循 路径最短优先 和 声明顺序优先 两大原则。解决这个问题的过程也被称为 Maven 依赖调解。 3、什么是 POM&…

某大厂一面:说说ThreadLocal的实现原理

ThreadLocal 是 Java 中一个非常有用的类,它提供了线程本地存储的功能。其作用是为每个线程提供独立的变量副本,使得不同线程访问时互不干扰。以下是 ThreadLocal 的详细原理: 1. ThreadLocal 类的基本作用 ThreadLocal 通过保证每个线程都…

Web 音视频(三)在浏览器中创建视频

前言 ​ 在 WebCodecs 之前,由于编解码能力的缺失,几乎无法在纯浏览器中编辑、创建视频。 WebCodecs 补齐了编解码能力,相当于在浏览器中提供了视频创作能力。 预计 WebCodecs 将会像 HTML5 技术(Video、Audio、MSE...&#xff0…

arkime和elasticsearch 安装方法三

Ubuntu新机 sudo apt upgrade sudo apt install open-vm-tools-desktop -y sudo reboot 然后换源 cp /etc/apt/source.list /etc/apt/source.list.bak sudo apt update nano /etc/apt/source.list deb https://mirrors.aliyun.com/ubuntu/ jammy main restricted unive…