Pytorch指定数据加载器使用子进程

news/2025/1/15 12:57:08/
torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)

num_workers 参数是 DataLoader 类的一个参数,它指定了数据加载器使用的子进程数量。通过增加 num_workers 的数量,可以并行地读取和预处理数据,从而提高数据加载的速度。

通常情况下,增加 num_workers 的数量可以提高数据加载的效率,因为它可以使数据加载和预处理工作在多个进程中同时进行。然而,当 num_workers 的数量超过一定阈值时,增加更多的进程可能不会再带来更多的性能提升,甚至可能会导致性能下降。

这是因为增加 num_workers 的数量也会增加进程间通信的开销。当 num_workers 的数量过多时,进程间通信的开销可能会超过并行化所带来的收益,从而导致性能下降。

此外,还需要考虑到计算机硬件的限制。如果你的计算机 CPU 核心数量有限,增加 num_workers 的数量也可能会导致性能下降,因为每个进程需要占用 CPU 核心资源。

因此,对于 num_workers 参数的设置,需要根据具体情况进行调整和优化。通常情况下,一个合理的 num_workers 值应该在 2 到 8 之间,具体取决于你的计算机硬件配置和数据集大小等因素。在实际应用中,可以通过尝试不同的 num_workers 值来找到最优的配置。

综上所述,当 num_workers 的值从 4 增加到 8 时,如果你的计算机硬件配置和数据集大小等因素没有发生变化,那么两者之间的性能差异可能会很小,或者甚至没有显著差异。

测试代码如下

import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import timeif __name__ == '__main__':mp.freeze_support()train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU...')else:print('CUDA is available! Training on GPU...')device = torch.device("cuda" if torch.cuda.is_available() else "cpu")batch_size = 4# 设置数据预处理的转换transform = torchvision.transforms.Compose([torchvision.transforms.Resize((512,512)),  # 调整图像大小为 224x224torchvision.transforms.ToTensor(),  # 转换为张量torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化])dataset = torchvision.datasets.ImageFolder('C:\\Users\\ASUS\\PycharmProjects\\pythonProject1\\cats_and_dogs_train',transform=transform)val_ratio = 0.2val_size = int(len(dataset) * val_ratio)train_size = len(dataset) - val_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)val_dataset = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)model = models.resnet18()num_classes = 2for param in model.parameters():param.requires_grad = Falsemodel.fc = nn.Sequential(nn.Dropout(),nn.Linear(model.fc.in_features, num_classes),nn.LogSoftmax(dim=1))optimizer = optim.Adam(model.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss().to(device)model.to(device)filename = "recognize_cats_and_dogs.pt"def save_checkpoint(epoch, model, optimizer, filename):checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}torch.save(checkpoint, filename)num_epochs = 3train_loss = []for epoch in range(num_epochs):running_loss = 0correct = 0total = 0epoch_start_time = time.time()for i, (inputs, labels) in enumerate(train_dataset):# 将数据放到设备上inputs, labels = inputs.to(device), labels.to(device)# 前向计算outputs = model(inputs)# 计算损失和梯度loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()# 更新模型参数optimizer.step()# 记录损失和准确率running_loss += loss.item()train_loss.append(loss.item())_, predicted = torch.max(outputs.data, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy_train = 100 * correct / total# 在测试集上计算准确率with torch.no_grad():running_loss_test = 0correct_test = 0total_test = 0for inputs, labels in val_dataset:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss_test += loss.item()_, predicted = torch.max(outputs.data, 1)correct_test += (predicted == labels).sum().item()total_test += labels.size(0)accuracy_test = 100 * correct_test / total_test# 输出每个 epoch 的损失和准确率epoch_end_time = time.time()epoch_time = epoch_end_time - epoch_start_timeprint("Epoch [{}/{}], Time: {:.4f}s, Loss: {:.4f}, Train Accuracy: {:.2f}%, Loss: {:.4f}, Test Accuracy: {:.2f}%".format(epoch + 1, num_epochs,epoch_time,running_loss / len(val_dataset),accuracy_train, running_loss_test / len(val_dataset), accuracy_test))save_checkpoint(epoch, model, optimizer, filename)plt.plot(train_loss, label='Train Loss')# 添加图例和标签plt.legend()plt.xlabel('Epochs')plt.ylabel('Loss')plt.title('Training Loss')# 显示图形plt.show()

不同num_workers的结果如下


http://www.ppmy.cn/news/1178719.html

相关文章

(二开)Flink 修改源码拓展 SQL 语法

1、Flink 扩展 calcite 中的语法解析 1)定义需要的 SqlNode 节点类-以 SqlShowCatalogs 为例 a)类位置 flink/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/dql/SqlShowCatalogs.java 核心方法: Override pu…

Qt 实现侧边栏滑出菜单效果

1.效果图 2.实现原理 这里做了两个widget,一个是 展示底图widget,一个是 展示动画widget。 这两个widget需要重合。动画widget需要设置属性叠加到底图widget上面,设置如下属性: setWindowFlags(Qt::FramelessWindowHint | Qt::…

软考系列(系统架构师)- 2014年系统架构师软考案例分析考点

试题一 软件架构(MYC 架构、扩展接口模式) MVC架构风格最初是Smalltalk-80中用来构建用户界面时采用的架构设计风格。其中M代表模型(Model),V代表视图(View),C代表控制器(Controller)。在该风格…

ubuntu安装配置svn

目录 简介安装SVN 启动模式方式1:单库svnserve方式方式2:多库svnserve方式 SVN 创建版本库1.svn 服务配置文件 svnserve.conf2.用户名口令文件 passwd3.权限配置文件4.多库方式运行 SVN 检出操作SVN 解决冲突SVN 提交操作SVN 版本回退SVN 查看历史信息1.svn log2.svn diff3.svn…

通过阿里云创建accessKeyId和accessKeySecret

我们想实现服务端向个人发送短信验证码 需要通过accessKeyId和accessKeySecret 这里可以白嫖阿里云的 这里 我们先访问阿里云官网 阿里云地址 进入后搜索并进入短信服务 如果没登录 就 登录一下先 然后在搜索框搜索短信服务 点击进入 因为我也是第一次操作 我们一起点免费开…

DBeaver连接数据库报错:Public Key Retrieval is not allowed 的解决方案

写在前面: DBeaver是一款免费的数据库管理工具,安装也是傻瓜式一键安装,比较推荐。 DBeaver官网(加载有点慢,耐心等待):DBeaver Community | Free Universal Database Tool 报错详情&#xff…

【UE】抓取物体

目录 效果 步骤 一、制作准心 二、简单的第三人称视角偏移 三、基于屏幕正中央的打点与射线 四、物理抓取的实现(抓取、放下、丢出) 效果 步骤 一、制作准心 1. 新建一个HUD,这里命名为“HUD_ZhunXin”,同时复制一个第三人…

Banana Pi BPI-W3(Armsom W3)RK3588开当板之调试UART

前言 本文主要讲解如何关于RK3588开发板UART的使用和调试方法,包括UART作为普通串口和控制台两种不同使用场景 一. 功能特点 Rockchip UART (Universal Asynchronous Receiver/Transmitter) 基于16550A串口标准,完整模块支持以下功能: 支…