pytorch中的TensorDataset和DataLoader

server/2024/12/22 0:59:57/

TensorDataset 详解

TensorDataset 主要用于将多个 Tensor 组合在一起,方便对数据进行统一处理。它可以用于简单地将特征和标签配对,也可以将多个特征张量组合在一起。

1. 将特征和标签组合

假设我们有一组图像数据(特征)和对应的标签,我们可以将它们组合成一个 TensorDataset

import torch
from torch.utils.data import TensorDataset# 创建输入数据(图像)和标签
images = torch.randn(100, 3, 28, 28)  # 100张图像,每张图像3通道,28x28像素
labels = torch.randint(0, 10, (100,))  # 100个标签,范围在0到9之间# 创建 TensorDataset
dataset = TensorDataset(images, labels)# 访问数据集中的特定样本
sample_image, sample_label = dataset[0]
print(f"Sample Image Shape: {sample_image.shape}")  # 输出: Sample Image Shape: torch.Size([3, 28, 28])
print(f"Sample Label: {sample_label}")  # 输出: Sample Label: 3

在这个例子中,我们创建了一个包含100张图像和对应标签的 TensorDataset。通过 dataset[0],我们可以访问第一个样本的图像和标签。

2. 组合多个特征张量

除了将特征和标签组合,TensorDataset 还可以将多个特征张量组合在一起。例如,假设我们有两个不同的特征张量,我们可以将它们组合成一个 TensorDataset

# 创建两个特征张量
feature1 = torch.randn(100, 50)  # 100个样本,每个样本50维
feature2 = torch.randn(100, 30)  # 100个样本,每个样本30维# 创建 TensorDataset
dataset = TensorDataset(feature1, feature2)# 访问数据集中的特定样本
sample_feature1, sample_feature2 = dataset[0]
print(f"Sample Feature1 Shape: {sample_feature1.shape}")  # 输出: Sample Feature1 Shape: torch.Size([50])
print(f"Sample Feature2 Shape: {sample_feature2.shape}")  # 输出: Sample Feature2 Shape: torch.Size([30])

在这个例子中,我们创建了一个包含两个特征张量的 TensorDataset,并通过 dataset[0] 访问第一个样本的两个特征。

DataLoader 详解

DataLoader 主要用于批量加载数据,并支持多种数据处理功能,如随机打乱、多线程加载等。

1. 批量处理数据

DataLoader 可以将数据集划分为多个批次(batch),便于模型训练。

from torch.utils.data import DataLoader# 创建 DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=False)# 遍历 DataLoader
for batch_features, batch_labels in train_loader:print(f"Batch Features Shape: {batch_features.shape}")  # 输出: Batch Features Shape: torch.Size([32, 3, 28, 28])print(f"Batch Labels Shape: {batch_labels.shape}")  # 输出: Batch Labels Shape: torch.Size([32])# 这里可以进行训练操作,如前向传播、反向传播等

在这个例子中,train_loader 将数据集划分为大小为32的批次。通过遍历 train_loader,我们可以轻松地获取每个批次的特征和标签。

2. 数据打乱

DataLoader 可以通过设置 shuffle=True 来在每个 epoch 开始时随机打乱数据,避免模型学习到数据的顺序。

# 创建 DataLoader,并设置 shuffle=True
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)# 遍历 DataLoader
for epoch in range(2):  # 假设我们要训练两个 epochfor batch_features, batch_labels in train_loader:print(f"Epoch {epoch}, Batch Features Shape: {batch_features.shape}")# 这里可以进行训练操作

在这个例子中,每次 epoch 开始时,数据都会被随机打乱,确保模型不会受到数据顺序的影响。

3. 多线程加载

DataLoader 支持通过设置 num_workers 参数来使用多线程并行加载数据,加快数据读取速度。

# 创建 DataLoader,并设置 num_workers=4
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)# 遍历 DataLoader
for batch_features, batch_labels in train_loader:print(f"Batch Features Shape: {batch_features.shape}")# 这里可以进行训练操作

在这个例子中,我们设置了 num_workers=4,表示使用4个线程来并行加载数据,从而加快数据读取速度。

结合使用 TensorDataset 和 DataLoader

以下是一个完整的示例,展示了如何结合使用 TensorDataset 和 DataLoader 进行数据加载和训练。

import torch
from torch.utils.data import TensorDataset, DataLoader# 创建输入数据和标签
images = torch.randn(1000, 3, 28, 28)  # 1000张图像,每张图像3通道,28x28像素
labels = torch.randint(0, 10, (1000,))  # 1000个标签,范围在0到9之间# 创建 TensorDataset
dataset = TensorDataset(images, labels)# 创建 DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)# 遍历 DataLoader 进行训练
for epoch in range(2):for batch_images, batch_labels in train_loader:print(f"Epoch {epoch}, Batch Images Shape: {batch_images.shape}")print(f"Epoch {epoch}, Batch Labels Shape: {batch_labels.shape}")# 这里可以进行训练操作,如前向传播、反向传播等

在这个例子中,我们首先使用 TensorDataset 将图像和标签组合在一起,然后通过 DataLoader 进行批量加载和训练。通过设置 shuffle=True 和 num_workers=4,我们实现了数据的随机打乱和多线程加载。

总结

  • TensorDataset 用于将多个 Tensor 组合在一起,方便对数据进行统一处理。
    • 可以组合特征和标签。
    • 可以组合多个特征张量。
  • DataLoader 用于批量加载数据,支持多种数据处理功能。
    • 支持批量处理数据。
    • 支持数据打乱。
    • 支持多线程加载。


http://www.ppmy.cn/server/126806.html

相关文章

Unity初识+面板介绍

Unity版本使用 小版本号高,出现bug可能性更小;一台电脑可以安装多个版本的Unity,但是需要安装在不同路径;安装Unity时不能有中文路径;Unity项目路径也不要有中文。 Scene面板 相当于拍电影的片场,Unity程…

开发微信小程序 基础03

WXSS(类似CSS) 定义: WXSS (WeiXin Style Sheets)是一套样式语言,用于描述 WXML的组件样式,类似于网页开发中的 CSS。 分类: 全局样式:定义在 app.wxss 中的样式为全局样式,作用于每一个页面 局部样式&…

Linux——环境变量

文章目录 1.什么是环境变量2.常见环境变量3. 如何查看环境变量4. 测试PATH5.测试HOME6. 和环境变量有关的指令7.环境变量的组织方式8. 通过代码获取环境变量main函数的第3个参数 9. 环境变量具有全局属性 当我们在Linux操作系统进行操作时,我们会发现使用系统命令的…

Prompt:在AI时代,提问比答案更有价值

你好,我是三桥君 随着AI技术的飞速发展,我们进入了一个信息爆炸的时代。在这个时代,只要你会提问,AI就能为你提供满意的答案。这种现象让很多人开始思考:在这个答案触手可及的时代,答案的价值是否还像以前…

如何设置 IIS 用以运行Delphi 编译的 CGI 程序

使用 Delphi 的 WebBroker 架构,可以非常方便地开发 Web 服务器程序。 结合一些好的前端库,可以很简单地作出非常漂亮功能强大的基于 WEB 页面的程序。 具体做法这里就不细说了。 在 Delphi 里面新建一个 Web Server 的工程,选择 IIS CGI …

《深度学习》OpenCV 图像拼接 拼接原理、参数解析、案例实现

目录 一、图像拼接 1、直接看案例 图1与图2展示: 合并完结果: 2、什么是图像拼接 3、图像拼接步骤 1)加载图像 2)特征点检测与描述 3)特征点匹配 4)图像配准 5)图像变换和拼接 6&am…

滚雪球学MySQL[8.1讲]:MySQL扩展功能

全文目录: 前言8. MySQL扩展功能8.1 存储过程与函数8.1.1 存储过程8.1.2 函数 8.2 触发器8.2.1 创建触发器 8.3 事件调度8.3.1 创建事件8.3.2 管理事件 8.4 JSON与全文检索8.4.1 JSON数据类型8.4.2 全文检索 下期内容预告 前言 在上一期的文章中,我们深…

YOLOv8 Flask整合问题

YOLOv8 Flask整合问题 yolov8 flask 后代码没有进行推理问题。 Bug model.predict()pyinstallerHTTPServer/flask: not executing yolov8是异步线程调用了,flask打包exe后会应该异步问题,model.predict()不会进行返回,导致没有看着没有执行而…