深度学习数据集定义与加载

news/2024/11/29 8:44:07/

深度学习数据集定义与加载
深度学习模型在训练时需要大量的数据来完成模型调优,这个过程均是数字的计算,无法直接使用原始图片和文本等来完成计算。因此与需要对原始的各种数据文件进行处理,转换成深度学习模型可以使用的数据类型。
一、框架自带数据集
飞桨框架将深度学习任务中常用到的数据集作为领域API开放,对应API所在目录为paddle.vision.datasets与paddle.text.datasets,可以通过以下代码飞桨框架中提供了哪些数据集。
import paddle
print(‘视觉相关数据集:’, paddle.vision.datasets.all)
print(‘自然语言相关数据集:’, paddle.text.datasets.all)
视觉相关数据集: [‘DatasetFolder’, ‘ImageFolder’, ‘MNIST’, ‘FashionMNIST’, ‘Flowers’, ‘Cifar10’, ‘Cifar100’, ‘VOC2012’]
自然语言相关数据集: [‘Conll05st’, ‘Imdb’, ‘Imikolov’, ‘Movielens’, ‘UCIHousing’, ‘WMT14’, ‘WMT16’]
警告
除paddle.vision.dataset与paddle.text.dataset外,飞桨框架还内置了另一套数据集,路径为paddle.dataset.*,但是该数据集的使用方式较老,会在未来的版本废弃,尽量不要使用该目录下数据集的API。
这里可以定义手写数字体的数据集,其它数据集的使用方式也都类似。用mode来标识训练集与测试集。数据集接口会自动从远端下载数据集到本机缓存目录~/.cache/paddle/dataset。
from paddle.vision.transforms import ToTensor

训练数据集 用ToTensor将数据格式转为Tensor

train_dataset = paddle.vision.datasets.MNIST(mode=‘train’, transform=ToTensor())

验证数据集

val_dataset = paddle.vision.datasets.MNIST(mode=‘test’, transform=ToTensor())
二、自定义数据集
在实际的场景中,更多需要使用已有的相关数据来定义数据集。可以使用飞桨提供的paddle.io.Dataset基类,来快速实现自定义数据集。
import paddle
from paddle.io import Dataset

BATCH_SIZE = 64
BATCH_NUM = 20

IMAGE_SIZE = (28, 28)
CLASS_NUM = 10

class MyDataset(Dataset):
“”"
步骤一:继承paddle.io.Dataset类
“”"
def init(self, num_samples):
“”"
步骤二:实现构造函数,定义数据集大小
“”"
super(MyDataset, self).init()
self.num_samples = num_samples

def __getitem__(self, index):"""步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)"""data = paddle.uniform(IMAGE_SIZE, dtype='float32')label = paddle.randint(0, CLASS_NUM-1, dtype='int64')return data, labeldef __len__(self):"""步骤四:实现__len__方法,返回数据集总数目"""return self.num_samples

测试定义的数据集

custom_dataset = MyDataset(BATCH_SIZE * BATCH_NUM)

print(’=custom dataset=’)
for data, label in custom_dataset:
print(data.shape, label.shape)
break
=custom dataset=
[28, 28] [1]
通过以上的方式,就可以根据实际场景,构造自己的数据集。
三、数据加载
飞桨推荐使用paddle.io.DataLoader完成数据的加载。简单的示例如下:
train_loader = paddle.io.DataLoader(custom_dataset, batch_size=BATCH_SIZE, shuffle=True)

如果要加载内置数据集,将 custom_dataset 换为 train_dataset 即可

for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]

print(x_data.shape)
print(y_data.shape)
break

[64, 28, 28]
[64, 1]
通过上述的方法,就定义了一个数据迭代器train_loader, 用于加载训练数据。通过batch_size=64设置了数据集的批大小为64,通过shuffle=True,在取数据前会打乱数据。此外,还可以通过设置num_workers来开启多进程数据加载,提升加载速度。
注解
DataLoader 默认用异步加载数据的方式来读取数据,一方面可以提升数据加载的速度,另一方面也会占据更少的内存。如果需要同时加载全部数据到内存中,设置use_buffer_reader=False。


http://www.ppmy.cn/news/609391.html

相关文章

为什么静态方法无法直接调用非静态成员变量和方法

静态方法无法直接调用非静态成员变量和方法 看到这句话,要想到形容的是这样的如下 静态方法里面无法调用非静态变量 下面在写一个对比非静态的方法和静态方法调用变量对比 问题原因 静态变量和静态的方法是属于类,不属于对象,调用的时候不需要实例化(当然如果你非要实例化之后…

latex用法总结

画彩色直线 $\textcolor[rgb]{1,0,0}{\rule[1.5pt]{0.5cm}{0.2em}}$表格和图在同一行 \begin{figure*}\begin{minipage}{0.63\linewidth}\includegraphics[width1.0\hsize]{HPatches_curve.pdf}\end{minipage}\hfill\begin{minipage}{0.34\linewidth}\tiny\renewcommand\arra…

大数据Spark(十六):Spark Core的RDD算子练习

文章目录 RDD算子练习 map 算子 filter 算子 flatMap 算子

适定、超定和欠定方程的概念

矩阵的每一行代表一个方程&#xff0c;m行代表m个线性联立方程。 n列代表n个变量。如果m是独立方程数&#xff0c;根据m<n、mn、m>n确定方程是 ‘欠定’、‘适定’ 还是 ‘超定’。 超定方程组&#xff1a;方程个数大于未知量个数的方程组。 对于方程组Ray&#xff0c;R为…

深度学习数据预处理

深度学习数据预处理 训练过程中有时会遇到过拟合的问题&#xff0c;其中一个解决方法就是对训练数据做增强&#xff0c;对数据进行处理得到不同的图像&#xff0c;从而泛化数据集。数据增强API是定义在领域目录的transofrms下&#xff0c;这里介绍两种使用方式&#xff0c;一种…

大数据Spark(十七):Spark Core的RDD持久化

文章目录 RDD 持久化 引入 API 缓存/持久化函数 缓存/持久化级别

【Android学习笔记】Android布局属性大全

第一类:属性值为true或false android:layout_centerHrizontal 水平居中 android:layout_centerVertical 垂直居中 android:layout_centerInparent 相对于父元素完全居中 android:layout_alignParentBottom 贴紧父元素的下边缘 android:layout_alignParentLeft 贴紧父元素的…

Android 内容观察者 ContentObserver 类

ContentObserver——内容观察者&#xff0c;目的是观察特定Uri引起的数据库的变化 这个是官方的文档,将的也是比较少 一般使用分为四步, 1、创建内容观察者 ContentObserver 2、注册监听 registerContentObserver 3、刷新数据库改变 onChange 4、注销监听 unregisterConten…