DataLoader,DataSet和Sampler

news/2024/11/27 20:58:20/

DataLoader、DataSet和Sampler之间的关系

Sample和DataSet是DataLoader的两个子模块。Sampler的功能主要是生成索引。也就是样本的序号
DatasetDatasetDataset是根据索引去读取数据以及对应的标签。DataLoader负责以特定的方式从数据集中迭代的产生一个一个batchbatchbatch集合。其中。DataLoader和Dataset是pytorch中数据读取的核心
(以特定的方式从数据集中迭代产生一个一个的batch集合》

DataLoader

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None)

实例化一个DataLoader所需参数如上所示:其中:

  • dataset: 定义好的MapMapMap式,或者lterablelterablelterable式数据集。
  • batch_size: 一个batch中的样本个数,默认为1.
  • shuffle: 每一个epoch的batch样本是相同还是随机的。
  • sampler: 数据集中的采样方法,如果有,则shuffle参数必须为false。
  • batch_sampler和sampler类似,但是一次返回的是一个batch内所有样本的index。
  • num_workers: 多少个子进程同时工作来获取数据,多线程。
  • collate_fn: 合并样本列表以形成小批量。
  • pin_menory: 如果为True,数据加载器在返回前将张量复制到CUDA固定内存中。
  • drop_last:如果数据集大小不能被batch_size整除,设置为True,可能删除最后一个不完整的批处理。如果设为False。并且数据集大小不能被batch_size整除,则最后一个batch将更小。
  • timeout:如果是正数,表明等待从worker进程中收集一个batch等待时间,若超出设定时间还没有收集到,那就不收集这个内容了,numeric应总是大于等于0.

Dataset

Dataset就是一个负责处理索引(index)到样本(sample)映射的一个类(class).

index→sampleindex \rightarrow sampleindexsample
torch.utils.data.Datasettorch.utils.data.Datasettorch.utils.data.Dataset 是一个表示数据集的抽象的类,任何自定义的数据集都需要继承这个类并腹泻相关方法

pytorch:提供两种数据集:Map式数据集Map式数据集Maplterable式数据集lterable式数据集lterable

Map 数据集

一个Map式的数据集必须要重写getitem(self, index),len(self) 两个内建方法,用来表示从索引到样本的映射(Map)
getitem(self,index),len(self)getitem(self, index),len(self) getitem(self,index),len(self)两个内建方法。

用来表示从索引到样本的映射(Map).
这样一个数据集dataset。举个例子,当使用dataset[idx]dataset[idx]dataset[idx]命令时,可以在你的硬盘中读取数据集中的第idxidxidx张图片以及其标签,
len(dataset)len(dataset)len(dataset):则返回这个数据集的容量。

自定义类结构一般如下:

class CustomDataset(data.Dataset):#需要继承data.Datasetdef __init__(self):# TODO# 1. Initialize file path or list of file names.passdef __getitem__(self, index):# TODO# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).# 2. Preprocess the data (e.g. torchvision.Transform).# 3. Return a data pair (e.g. image and label).#这里需要注意的是,第一步:read one data,是一个datapassdef __len__(self):# You should change 0 to the total size of your dataset.return 0

getitem最主要的方法是,其规定了如何读取数据,但是又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假定你定义好一个dataset,那么你可以直接通过Dataset[0]来访问第一个数据。

lterable式数据集

一个lterable式数据集,是抽象类:data.IterableDataset的子类,
并且腹泻了iter方法成为一个迭代器。这种数据集主要用于数据大小未知,或者以流的形式输入。本地文件不固定的情况,需要以迭代的方式来获取样本索引

迭代器

迭代器是一个可以记住遍历的位置的对象,迭代器对象从集合的第一个元素开始访问。直到所有的元素被访问完结束,迭代器只能往前而不会后退,迭代器两个基本方法:iter()iter()iter()next()next()next().
iter()iter()iter():方法返回一个特殊的迭代器对象,这个迭代器对象实现了next()方法,并通过$StopIteration $异常标识迭代的完成。
next()next()next() 方法会返回迭代器的输出。

class MyNumbers:def __iter__(self):self.a = 1return selfdef __next__(self):if self.a <= 20:x = self.aself.a += 1return xelse:raise StopIteration
#StopIteration 异常用于标识迭代的完成,防止出现无限循环的情况,在 next() 方法中我们可以设置在完成指定循环次数后触发 StopIteration 异常来结束迭代。myclass = MyNumbers()
myiter = iter(myclass)for x in myiter:print(x)

Sampler

sampler类的源代码主要由三种方法,如下:

class Sampler(object):r"""Base class for all Samplers.Every Sampler subclass has to provide an __iter__ method, providing a wayto iterate over indices of dataset elements, and a __len__ method thatreturns the length of the returned iterators."""# 一个 迭代器 基类def __init__(self, data_source):passdef __iter__(self):raise NotImplementedErrordef __len__(self):raise NotImplementedError
  • init方法:初始化。
  • iter:用来产生迭代索引值,也就是指定每个step需要读取那些数据。
  • len: 用来返回每次迭代器的长度。
    python提供了我们几种采样器,如下:

SequentialSampler

按顺序对数据集采样,其原理首先在初始化的时候,拿到数据集data_source,之后在__iter__方法中首先得到一个和data_source一样长度的range迭代器。每次只返回一个索引值。

RandomSampler

随机采样

SubsetRandomSampler

子集随机采样,用于训练,测试集和验证集合的划分。

WeightedRandomSampler

加权随机采样。

BatchSampler

前面的采样器每次只返回一个索引,但是我们在训练时是对批量数据进行训练。而这样的工作都需要BatchSampler来做。也就是说BatchSampler的作用就是将前面的Sampler采样得到的索引值进行合并,当数量等于一个batch大小后就将这一批的索引值返回

总结

慢慢的将各种采样方法,全部都将其搞定。慢慢的将其研究透彻,研究彻底。都行啦的样子与打算。


http://www.ppmy.cn/news/2637.html

相关文章

黑马程序员14套经典IT教程+面试宝典

很多同学对互联网比较感兴趣 &#xff0c;奈何苦恼不知道如何入门。今天免费给大家分享一波&#xff0c;黑马程序员14套经典IT教程程序员面试宝典&#xff01;涉及Java、前端、Python、大数据、软件测试、UI设计、新媒体短视频等。从厌学到学嗨&#xff0c;你只差一套黑马教程&…

如何自定义SpringBoot中的starter,并且使用它

目录 1 简介 2 规范 2.1 命名 2.2 模块划分 3 示例 1 简介 SpringBoot中的starter是一种非常重要的机制&#xff0c;能够抛弃以前繁琐的配置&#xff0c;将其统一集成进starter&#xff0c;应用者只需要在maven中引入starter依赖&#xff0c;SpringBoot就自动扫描到要加载…

Java线程实现

内容引用自《深入理解Java虚拟机&#xff1a;JVM高级特性与最佳实践&#xff08;第3版&#xff09;周志明》 线程的实现 我们知道&#xff0c;线程是比进程更轻量级的调度执行单位&#xff0c;线程的引入&#xff0c;可以把一个进程的资源分配和 执行调度分开&#xff0c;各个…

大学电子系C++模拟考试

随手附上一些代码&#xff0c;未必是最优解&#xff0c;仅供参考。 加密四位数 【问题描述】 输入一个四位数&#xff0c;将其加密后输出。方法是将该数每一位的数字加9&#xff0c;然后除以10取余作为该位上的新数字&#xff0c;最后将千位上的数字和十位上的数字互换&#…

Hbase和Mysql存储数据量对比

目录 前言 生成数据 转换成hbase能够识别的HFile文件 导入HFile到hbase中 导入数据到Mysql 总结 前言 由于想知道hbase和mysql存储同样的一份数据需要的存储是否一样&#xff0c;故做的一下实验。 生成数据 脚本如下&#xff1a; #!/bin/basharray_brand([1]huawei […

C语言第二十课:实用调试技巧

目录 前言&#xff1a; 一、Bug&#xff1a; 二、调试&#xff1a; 1.调试是什么&#xff1a; 2.调试的基本步骤&#xff1a; 3. Debug 与 Release &#xff1a; 三、在Windows环境下进行调试&#xff1a; 1.调试环境的准备&#xff1a; 2.调试的快捷键&#xff1a; 3.调试…

javascript基础小结(一)

今天突发奇想&#xff0c;想要垂直精学一段时间的javascript&#xff0c;用我的第一次「连载」来记录总结一些知识点吧。 知识点 原始类型的类型转换 类型转换 alert 会自动将任何值都转换为字符串以进行显示。算术运算符会将值转换为数字常用的类型转换&#xff1a;转换为 …

Privacy

For information collected and further processed under this Privacy Policy, the data controller is Toy Games. Toy Games Ltd (“Toy Games”, “owner”, “us”, “our” or “we”) is dedicated to protecting the privacy rights of our games and other services …