03_Datase和DataLoader 之间的配合

news/2025/1/15 7:42:36/

Datase和 DataLoader 之间的配合,用于构建数据集,以及训练过程中,读取数据的过程;

1. Dataset和DataLoader

1.1 各自的作用

Dataset: 定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素.

在绝大部分情况下,用户只需在Dataset的子类中实现__len__方法和__getitem__方法,就可以轻松构建自己的数据集,并用默认数据管道进行加载。

DataLoader: 定义了按batch加载数据集的方法,它是一个实现了__iter__方法的可迭代对象,每次迭代输出一个batch的数据。

DataLoader:

  1. 控制batch的大小,
  2. batch中元素的采样方法,
  3. 规定 collate_fn()方法:将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。

1.2 获取一个batch 时步骤

获取一个batch数据的步骤:

让我们考虑一下从一个数据集中获取一个batch的数据需要哪些步骤。

(假定数据集的特征和标签分别表示为张量X和Y,数据集可以表示为(X,Y), 假定batch大小为m)

  1. 首先我们要确定数据集的长度n。结果类似:n = 1000。

  2. 然后我们从0到n-1的范围中抽样出m个数(batch大小)。假定m=4, 拿到的结果是一个列表,类似:indices = [1,4,8,9]

  3. 接着我们从数据集中去取这m个数对应下标的元素。拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]

  4. 最后我们将结果整理成两个张量作为输出。

拿到的结果是两个张量,类似batch = (features,labels) ,

其中 features = torch.stack([X[1],X[4],X[8],X[9]])

labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])

1.3 完成上述步骤,Dataset与DataLoader的分工

  • 上述第1个步骤确定数据集的长度是由 Dataset的__len__ 方法实现的。

  • 第2个步骤从0到n-1的范围中抽样出m个数的方法是由 DataLoader的 sampler和 batch_sampler参数指定的。

sampler参数指定单个元素抽样方法,一般无需用户设置,程序默认在DataLoader的参数shuffle=True时采用随机抽样,shuffle=False时采用顺序抽样。

batch_sampler参数将多个抽样的元素整理成一个列表,一般无需用户设置,默认方法在DataLoader的参数drop_last=True时会丢弃数据集最后一个长度不能被batch大小整除的批次,在drop_last=False时保留最后一个批次。

  • 第3个步骤的核心逻辑根据下标取数据集中的元素 是由 Dataset的 __getitem__方法实现的。

  • 第4个步骤的逻辑由DataLoader的参数collate_fn指定。一般情况下也无需用户设置。

1.4 Dataset和DataLoader的一般使用方式

如下:

import torch   
from torch.utils.data import TensorDataset,Dataset,DataLoader  
from torch.utils.data import RandomSampler,BatchSampler   ds = TensorDataset(torch.randn(1000,3),  torch.randint(low=0,high=2,size=(1000,)).float())  
dl = DataLoader(ds,batch_size=4,drop_last = False)  
features,labels = next(iter(dl))  
print("features = ",features )  
print("labels = ",labels )

将DataLoader内部调用方式步骤拆解如下:

# step1: 确定数据集长度 (Dataset的 __len__ 方法实现)  
ds = TensorDataset(torch.randn(1000,3),  torch.randint(low=0,high=2,size=(1000,)).float())  
print("n = ", len(ds)) # len(ds)等价于 ds.__len__()  # step2: 确定抽样indices (DataLoader中的 Sampler和BatchSampler实现)  
sampler = RandomSampler(data_source = ds)  
batch_sampler = BatchSampler(sampler = sampler,   batch_size = 4, drop_last = False)  
for idxs in batch_sampler:  indices = idxs  break   
print("indices = ",indices)  # step3: 取出一批样本batch (Dataset的 __getitem__ 方法实现)  
batch = [ds[i] for i in  indices]  #  ds[i] 等价于 ds.__getitem__(i)  
print("batch = ", batch)  # step4: 整理成features和labels (DataLoader 的 collate_fn 方法实现)  
def collate_fn(batch):  features = torch.stack([sample[0] for sample in batch])  labels = torch.stack([sample[1] for sample in batch])  return features,labels   features,labels = collate_fn(batch)  
print("features = ",features)  
print("labels = ",labels) 

1.5 Dataset和DataLoader的核心源码

以下是 Dataset和 DataLoader的核心源码,省略了为了提升性能而引入的诸如多进程读取数据相关的代码。

import torch   
class Dataset(object):  def __init__(self):  pass  def __len__(self):  raise NotImplementedError  def __getitem__(self,index):  raise NotImplementedError  class DataLoader(object):  def __init__(self,dataset,batch_size,collate_fn = None,shuffle = True,drop_last = False):  self.dataset = dataset  self.sampler =torch.utils.data.RandomSampler if shuffle else \  torch.utils.data.SequentialSampler  self.batch_sampler = torch.utils.data.BatchSampler  self.sample_iter = self.batch_sampler(  self.sampler(self.dataset),  batch_size = batch_size,drop_last = drop_last)  self.collate_fn = collate_fn if collate_fn is not None else \  torch.utils.data._utils.collate.default_collate  def __next__(self):  indices = next(iter(self.sample_iter))  batch = self.collate_fn([self.dataset[i] for i in indices])  return batch  def __iter__(self):  return self 

我们来测试一番

class ToyDataset(Dataset):  def __init__(self,X,Y):  self.X = X  self.Y = Y   def __len__(self):  return len(self.X)  def __getitem__(self,index):  return self.X[index],self.Y[index]  X,Y = torch.randn(1000,3),torch.randint(low=0,high=2,size=(1000,)).float()  
ds = ToyDataset(X,Y)  dl = DataLoader(ds,batch_size=4,drop_last = False)  
features,labels = next(iter(dl))   
print("features = ",features )  
print("labels = ",labels )

2. Dataset: 用于数据集的创建

Dataset创建数据集常用的方法有:

  • 使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。
  • 使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。
  • 继承 torch.utils.data.Dataset 创建自定义数据集。

此外,还可以通过:

  • torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。
  • 调用Dataset的加法运算符(+)将多个数据集合并成一个数据集。

2.1 根据 Tensor 创建数据集

import numpy as np   
import torch   
from torch.utils.data import TensorDataset,Dataset,DataLoader,random_split   # 根据Tensor创建数据集  from sklearn import datasets   
iris = datasets.load_iris()  
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))  # 分割成训练集和预测集  
n_train = int(len(ds_iris)*0.8)  
n_val = len(ds_iris) - n_train  
ds_train,ds_val = random_split(ds_iris,[n_train,n_val])  print(type(ds_iris))  
print(type(ds_train)) # 使用DataLoader加载数据集  
dl_train,dl_val = DataLoader(ds_train,batch_size = 8),DataLoader(ds_val,batch_size = 8)  for features,labels in dl_train:  print(features,labels)  break  # 演示加法运算符(`+`)的合并作用  ds_data = ds_train + ds_val  print('len(ds_train) = ',len(ds_train))  
print('len(ds_valid) = ',len(ds_val))  
print('len(ds_train+ds_valid) = ',len(ds_data))  print(type(ds_data))  

2.2 根据图片目录创建图片数据集

import numpy as np   
import torch   
from torch.utils.data import DataLoader  
from torchvision import transforms,datasets   #演示一些常用的图片增强操作  from PIL import Image  
img = Image.open('./data/cat.jpeg')  
img  # 随机数值翻转  
transforms.RandomVerticalFlip()(img)  #随机旋转  
transforms.RandomRotation(45)(img)  # 定义图片增强操作  transform_train = transforms.Compose([  transforms.RandomHorizontalFlip(), #随机水平翻转  transforms.RandomVerticalFlip(), #随机垂直翻转  transforms.RandomRotation(45),  #随机在45度角度内旋转  transforms.ToTensor() #转换成张量  ]  
)   transform_valid = transforms.Compose([  transforms.ToTensor()  ]  
)  # 根据图片目录创建数据集  def transform_label(x):  return torch.tensor([x]).float()  ds_train = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/train/",  transform = transform_train,target_transform= transform_label)  
ds_val = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/test/",  transform = transform_valid,  target_transform= transform_label)  print(ds_train.class_to_idx)  # 使用DataLoader加载数据集  dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)  
dl_val = DataLoader(ds_val,batch_size = 50,shuffle = True)  for features,labels in dl_train:  print(features.shape)  print(labels.shape)  break  

2.3 创建自定义数据集

下面我们通过另外一种方式,即继承 torch.utils.data.Dataset 创建自定义数据集的方式来对 cifar2构建 数据管道。

from pathlib import Path   
from PIL import Image   class Cifar2Dataset(Dataset):  def __init__(self,imgs_dir,img_transform):  self.files = list(Path(imgs_dir).rglob("*.jpg"))  self.transform = img_transform  def __len__(self,):  return len(self.files)  def __getitem__(self,i):  file_i = str(self.files[i])  img = Image.open(file_i)  tensor = self.transform(img)  label = torch.tensor([1.0]) if  "1_automobile" in file_i else torch.tensor([0.0])  return tensor,label   train_dir = "./eat_pytorch_datasets/cifar2/train/"  
test_dir = "./eat_pytorch_datasets/cifar2/test/"  # 定义图片增强  
transform_train = transforms.Compose([  transforms.RandomHorizontalFlip(), #随机水平翻转  transforms.RandomVerticalFlip(), #随机垂直翻转  transforms.RandomRotation(45),  #随机在45度角度内旋转  transforms.ToTensor() #转换成张量  ]  
)   transform_val = transforms.Compose([  transforms.ToTensor()  ]  
)  ds_train = Cifar2Dataset(train_dir,transform_train)  
ds_val = Cifar2Dataset(test_dir,transform_val)  dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)  
dl_val = DataLoader(ds_val,batch_size = 50,shuffle = True)  for features,labels in dl_train:  print(features.shape)  print(labels.shape)  break  

3. DataLoader: 用于加载数据

DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。

DataLoader的函数签名如下:

DataLoader(  dataset,  batch_size=1,  shuffle=False,  sampler=None,  batch_sampler=None,  num_workers=0,  collate_fn=None,  pin_memory=False,  drop_last=False,  timeout=0,  worker_init_fn=None,  multiprocessing_context=None,  
)  

一般情况下,我们仅仅会配置 dataset, batch_size, shuffle, num_workers,pin_memory, drop_last这六个参数,

有时候对于一些复杂结构的数据集,还需要自定义collate_fn函数,其他参数一般使用默认值即可。

DataLoader除了可以加载我们前面讲的 torch.utils.data.Dataset 外,还能够加载另外一种数据集 torch.utils.data.IterableDataset。

和Dataset数据集相当于一种列表结构不同,IterableDataset相当于一种迭代器结构。它更加复杂,一般较少使用。

dataset : 数据集
batch_size: 批次大小
shuffle: 是否乱序
sampler: 样本采样函数,一般无需设置。
batch_sampler: 批次采样函数,一般无需设置。
num_workers: 使用多进程读取数据,设置的进程数。
collate_fn: 整理一个批次数据的函数。
pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
timeout: 加载一个数据批次的最长等待时间,一般无需设置。
worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使

#构建输入数据管道  
ds = TensorDataset(torch.arange(1,50))  
dl = DataLoader(ds,  batch_size = 10,  shuffle= True,  num_workers=2,  drop_last = True)  
#迭代数据  
for batch, in dl:  print(batch)  tensor([43, 44, 21, 36,  9,  5, 28, 16, 20, 14])  
tensor([23, 49, 35, 38,  2, 34, 45, 18, 15, 40])  
tensor([26,  6, 27, 39,  8,  4, 24, 19, 32, 17])  
tensor([ 1, 29, 11, 47, 12, 22, 48, 42, 10,  7])  

reference:

https://zhuanlan.zhihu.com/p/560502810


http://www.ppmy.cn/news/189185.html

相关文章

软考A计划-电子商务设计师-专业英语

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分享&am…

研报精选230424

目录 【行业230424华福证券】功率半导体行业深度报告:能源变革大时代,功率器件大市场 【行业230424华西证券】海外锂资源企业近况总结之锂辉石篇:2023年海外锂资源供应量同比增长45%,H2比H1增加13万吨LCE供应 【行业230424东吴证券…

07年25大IT创新

12月29日消息&#xff0c;据国外媒体报道&#xff0c;PC World日前评出了2007年25大创新<nobr><strong class"kgb" οnmοuseοverisShowAds true;isShowAds2 true;KeyGate_ads.Move(this,"","","-100836","产品"…

谁在使用Qt

关于“谁在使用Qt&#xff0c;那些产品是采用Qt开发的&#xff1f;”这个问题是很多Qt的使用者和学习者所关心的问题&#xff0c;这个问题关系到Qt的应用范围&#xff0c;Qt影响力&#xff0c;Qt的魅力等等。2010 Qt开发者大会茶歇的时候大家咨询过齐亮的一个问题。这个问题也同…

【经典箴言 || 人生感悟 】

人的一生中只有七次机会&#xff0c;平均每七年拥有一次&#xff0c;大概在25岁到75岁&#xff0c;第一次通常抓不到&#xff0c;因为太年轻&#xff0c;最后一次也抓不到&#xff0c;因为太老。中途还有2次因为自己错过&#xff0c;所以抓不到。所以对于人来说人真正才会有三次…

经典箴言和人生感悟

人的一生中只有七次机会&#xff0c;平均每七年拥有一次&#xff0c;大概在25岁到75岁&#xff0c;第一次通常抓不到&#xff0c;因为太年轻&#xff0c;最后一次也抓不到&#xff0c;因为太老。中途还有2次因为自己错过&#xff0c;所以抓不到。所以对于人来说人真正才会有三次…

ProGrade Digital宣布推出SDXC UHS-II V90存储卡

V90是数据密集型6K、4K、动态JPEG、高比特率MPEG-4视频采集的关键 此外&#xff0c;公司还宣布CFast & SD格式的USB 3.1, Gen.2双槽读卡器已开始发货 拉斯维加斯--(美国商业资讯)--美国广播电视展(NAB)—ProGrade Digital是一家以提供最优质专业级数字存储卡和工作流程解决…

51job爬取职位搜索下面的2000条职位信息

打了这么久的酱油&#xff0c;终于自己独立完成了网站信息的爬取&#xff0c;记录一下。 要求&#xff1a; https://search.51job.com/list/020000%252C00,000000,0000,00,9,99,%2B,2,1.html?langc&stype1&postchannel0000&workyear99&cotype99&degreef…