文章目录
- Dataset 类
- DataLoader 类
Dataset 类
概念:
Dataset
是一个抽象类,用于表示数据集。它定义了如何获取数据集中的单个样本和标签。
作用:
- 为数据集提供统一的接口,便于数据的读取、预处理和管理。
关键方法:
__len__(self)
: 返回数据集的大小(样本数量)。__getitem__(self, index)
: 根据索引index
返回对应的样本和标签。
自定义 Dataset:
需要继承 torch.utils.data.Dataset
并实现上述两个方法。
示例(PyTorch):
import torch
from torch.utils.data import Datasetclass Dataset(Dataset):def __init__(self, datas, labels):self.datas = datas # 数据文件路径列表self.labels = labels # 标签列表def __len__(self):return len(self.data)def __getitem__(self, idx):# 加载数据,例如读取图像文件data = self.data[idx]label = self.labels[idx]# 一系列的处理return data, label
DataLoader 类
概念:
DataLoader
是一个数据迭代器,用于包装Dataset
,以便于批量(batch)加载数据。
作用:
- 提供批量数据、数据打乱(shuffle)、并行加载(多线程/多进程)等功能,提高数据加载的效率。
关键参数:
dataset
: 要加载的数据集(Dataset
实例)。batch_size
: 每个批次的样本数量。shuffle
: 是否在每个 epoch 开始时打乱数据。num_workers
: 使用多少子进程来加载数据(0
表示不使用多进程)。collate_fn
: 指定如何将一批样本组合成一个批次。
工作流程:
- 从
Dataset
中按索引取出样本。 - 使用
collate_fn
将多个样本组合成一个批次。 - 迭代返回批量数据供模型训练或评估。
示例(PyTorch):
from torch.utils.data import DataLoader# 创建 Dataset 实例
dataset = MyDataset(datas, labels)# 创建 DataLoader 实例
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)