PyTorch使用教程(7)-数据集处理

devtools/2025/1/20 12:49:53/

1、基础概念

在PyTorch中,torch.utils.data模块是处理数据集和数据加载的核心工具。以下是该模块中一些基础概念的理解:
在这里插入图片描述

1.1 Dataset

  • 定义:Dataset是一个抽象类,用于表示数据集。用户需要通过继承Dataset类并实现其__len__和__getitem__方法来创建自定义的数据集。

  • 功能:Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,并能够用索引获取数据集中的元素。

  • 类型:Dataset主要分为两种类型:map-style和iterable-style。map-style数据集需要实现__getitem__和__len__方法,而iterable-style数据集则需要实现__iter__方法。

python">from typing import Generic, TypeVar, List_T_co = TypeVar('_T_co', covariant=True)class Dataset(Generic[_T_co]):def __getitem__(self, index: int) -> _T_co:raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")def __len__(self) -> int:raise NotImplementedError("Subclasses of Dataset should implement __len__.")def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":"""Adds two datasets. This can be useful when you have two datasets with potentiallyoverlapping elements and you want to treat the elements as distinct."""from .dataset_ops import ConcatDatasetreturn ConcatDataset([self, other])

1.2 DataLoader

  • 定义:DataLoader是一个迭代器,用于封装Dataset,并提供一个可迭代对象,方便进行批量加载、数据打乱、并行加载等操作。
  • 功能:DataLoader能够控制batch的大小、batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法。
  • 参数:常用的参数包括dataset(表示要加载的数据集对象)、batch_size(表示每个batch的大小)、shuffle(表示是否在每个epoch开始时打乱数据)、num_workers(表示用于数据加载的进程数)等。
python">DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, *, prefetch_factor=2,persistent_workers=False)

1.3 Sampler

  • 定义:Sampler是一个抽象类,用于从数据集中生成索引。
  • 功能:Sampler的作用是在Dataset上面进行抽样,抽样的方式有多种,如按顺序抽样、随机抽样、在子集合中随机抽样、带权重的抽样等。
  • 类型:包括SequentialSampler、RandomSampler、SubsetRandomSampler、WeightedRandomSampler、BatchSampler等。

1.4 Batching

  • 定义:Batching是指将数据集分成多个小批次(batch)进行处理的过程。
  • 功能:Batching可以提高数据处理的效率,并有助于模型训练过程中的梯度更新和收敛。
  • 实现:通过DataLoader的batch_size参数来实现批量加载。

1.5 Shuffling

  • 定义:Shuffling是指在每个epoch开始时打乱数据集中的元素顺序。
  • 功能:Shuffling有助于提高模型的泛化能力,防止模型对数据的顺序产生依赖。
  • 实现:通过DataLoader的shuffle参数来启用数据打乱功能。

1.6 Multi-process Data Loading

  • 定义:Multi-process Data Loading是指使用多个进程来并行加载数据的过程。
  • 功能:Multi-process Data Loading可以显著提高数据加载的速度,尤其是在处理大规模数据集时。
  • 实现:通过DataLoader的num_workers参数来设置并行加载的进程数。

2、创建数据集

在PyTorch中,创建数据集通常涉及继承torch.utils.data.Dataset类并实现其必需的方法。以下是一个详细的步骤指南,用于创建自定义数据集:

  1. 导入必要的库
    首先,确保你已经导入了PyTorch和其他可能需要的库。
python">import torch
from torch.utils.data import Dataset
  1. 继承Dataset类
    创建一个新的类,继承自Dataset。
python">class CustomDataset(Dataset):def __init__(self, data, labels, transform=None):# 初始化数据集,存储数据和标签self.data = dataself.labels = labelsself.transform = transform# 确保数据和标签的长度相同assert len(self.data) == len(self.labels), "Data and labels must have the same length"def __len__(self):# 返回数据集的大小return len(self.data)def __getitem__(self, idx):# 根据索引获取数据和标签sample = self.data[idx]label = self.labels[idx]# 如果定义了转换,则应用转换if self.transform:sample = self.transform(sample)return sample, label
  1. 准备数据和标签
    在创建CustomDataset实例之前,你需要准备好数据和标签。这些数据可以是图像、文本、数值等,具体取决于你的任务。
python"># 假设你有一些数据和标签(这里只是示例)
data = [torch.randn(3, 32, 32) for _ in range(100)]  # 100个3x32x32的随机图像
labels = [torch.tensor(i % 2) for i in range(100)]   # 100个标签,0或1
  1. 创建数据集实例
    使用你准备好的数据和标签来创建CustomDataset的实例。
python">dataset = CustomDataset(data, labels)
  1. (可选)应用转换

如果你需要对数据进行预处理或增强,可以定义一个转换函数,并在创建数据集实例时传递给它。

python"># 定义一个简单的转换函数(例如,将图像数据标准化)
def normalize(sample):return (sample - sample.mean()) / sample.std()# 创建数据集实例时应用转换
dataset = CustomDataset(data, labels, transform=normalize)
  1. 使用DataLoader加载数据
    最后,使用torch.utils.data.DataLoader来加载数据集,以便进行批量处理、打乱数据等。
python">from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=32, shuffle=True)# 现在你可以遍历dataloader来加载数据了
for batch_data, batch_labels in dataloader:# 在这里进行模型训练或评估pass

注意事项

  • 确保你的数据和标签是可索引的,通常它们应该是列表、NumPy数组或PyTorch张量。
  • 如果你的数据是图像,并且存储在文件系统中,你可能需要在__getitem__方法中实现图像读取和预处理逻辑。
  • 对于大型数据集,考虑使用torchvision.datasets中提供的预定义数据集类,它们通常包含了常见的图像数据集(如CIFAR、MNIST等)的加载逻辑。
  • 如果数据集太大无法全部加载到内存中,你可以考虑使用torch.utils.data.IterableDataset来创建一个可迭代的数据集,这样你就可以按需加载数据了。

http://www.ppmy.cn/devtools/152089.html

相关文章

基于海思soc的智能产品开发(高、中、低soc、以及和fpga的搭配)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 市场上关于图像、音频的soc其实非常多,这里面有高、中、低档,开发方式也不相同。之所以会这样,有价格的因素&am…

AF3 MSAModule类源码解读

AlphaFold3 中的MSAModule 类 是一个用于处理多序列比对(MSA)的模块,核心功能是通过 MSAModuleBlock 堆叠和梯度检查点优化,实现对 MSA 表征和配对表征的高效计算。调用该类最终返回更新后的配对表征z,更新后的z含有MSA特征和目的蛋白质序列信息。 源代码: class MSAMo…

自制游戏——国争

自制小游戏,分享给大家 //0——步兵(k) //1——弓箭手(k) //2——炮兵(k) //3——土地(平方公里) //4——能量(t) //5——钱(元宝…

【RK3588 docker编译问题】

问题集合 问题1: 编译lunch出现问题 12:31:21 Build sandboxing disabled due to nsjail error. 12:31:22 Build sandboxing disabled due to nsjail error. In file included from build/make/core/config.mk:313: In file included from build/make/core/envset…

从零创建一个 Django 项目

1. 准备环境 在开始之前,确保你的开发环境满足以下要求: 安装了 Python (推荐 3.8 或更高版本)。安装 pip 包管理工具。如果要使用 MySQL 或 PostgreSQL,确保对应的数据库已安装。 创建虚拟环境 在项目目录中创建并激活虚拟环境&#xff…

Java基础--类和对象

目录 什么是类? 什么是对象 为什么java会设计对象 Java对象该怎么用 程序执行流程 类的加载顺序 什么是类? 类是构建对象的模板,一个类可以创建多个对象,每个对象的数据的最初来源来自对象 public class Student{public in…

Python从0到100(八十四):神经网络-卷积神经网络训练CIFAR-10数据集

前言: 零基础学Python:Python从0到100最新最全教程。 想做这件事情很久了,这次我更新了自己所写过的所有博客,汇集成了Python从0到100,共一百节课,帮助大家一个月时间里从零基础到学习Python基础语法、Pyth…

Vue.js组件开发-解决PDF签章预览问题

在Vue.js组件开发中,解决PDF签章预览问题可能涉及多个方面,包括选择合适的PDF预览库、配置PDF.js(或其封装库如vue-pdf)以正确显示签章、以及处理可能的兼容性和性能问题。 步骤和建议: 1. 选择合适的PDF预览库 ‌vu…