Pytorch中的torch.utils.data.Dataset 类

server/2025/3/24 23:30:28/

1、使用方法

python">from torch.utils.data import Dataset

2、torch.utils.data.Dataset 类的定义

python">class Dataset(Generic[_T_co]):r"""An abstract class representing a :class:`Dataset`.All datasets that represent a map from keys to data samples should subclassit. All subclasses should overwrite :meth:`__getitem__`, supporting fetching adata sample for a given key. Subclasses could also optionally overwrite:meth:`__len__`, which is expected to return the size of the dataset by many:class:`~torch.utils.data.Sampler` implementations and the default optionsof :class:`~torch.utils.data.DataLoader`. Subclasses could alsooptionally implement :meth:`__getitems__`, for speedup batched samplesloading. This method accepts list of indices of samples of batch and returnslist of samples... note:::class:`~torch.utils.data.DataLoader` by default constructs an indexsampler that yields integral indices.  To make it work with a map-styledataset with non-integral indices/keys, a custom sampler must be provided."""def __getitem__(self, index) -> _T_co:raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")# def __getitems__(self, indices: List) -> List[_T_co]:# Not implemented to prevent false-positives in fetcher check in# torch.utils.data._utils.fetch._MapDatasetFetcherdef __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":return ConcatDataset([self, other])# No `def __len__(self)` default?# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]# in pytorch/torch/utils/data/sampler.py

原文解释:

  1. 所有表示从键到数据样本映射的数据集都应该继承自这个类。
    这意味着,如果你有一个数据集,它通过某些键(可能是整数、字符串等)来访问数据样本,那么你应该从 Dataset 类继承来创建你的数据集类。

  2. 所有的子类都应该重写 __getitem__ 方法,支持通过给定的键获取数据样本。
    __getitem__ 是 Python 的特殊方法,用于通过 dataset[key] 这样的语法来获取数据。在你的子类中,你需要实现这个方法,确保它能够返回与给定键对应的数据样本。

  3. 子类也可以选择性地重写 __len__ 方法,该方法通常被许多 Sampler 实现和 DataLoader 的默认选项所使用,用于返回数据集的大小。
    __len__ 方法用于返回数据集中样本的总数。虽然它不是强制要求的,但如果你希望使用 PyTorch 的 Sampler 或 DataLoader,通常需要实现这个方法。

  4. 子类还可以选择性地实现 __getitems__ 方法,以加速批量数据加载。这个方法接受一个包含批次样本索引的列表,并返回一个样本列表。
    __getitems__ 是一个可选的优化方法。如果你需要批量加载数据,实现这个方法可以提高效率。它接受一个索引列表,并返回对应的样本列表。

  5. DataLoader 默认构造一个生成整数索引的采样器(sampler)。要使它能够与具有非整数索引/键的 map-style 数据集一起工作,必须提供一个自定义的采样器。
    DataLoader 默认情况下假设你的数据集是可以通过整数索引访问的(即 dataset[0]dataset[1] 等)。如果你的数据集使用非整数键(比如字符串或其他类型),你需要提供一个自定义的采样器来生成这些键。

3、示例

示例 1:简单的整数索引数据集

 假设我们有一个数据集,数据储存在一个列表中,我们可以通过整数索引来访问。

python">from torch.utils.data import Datasetclass SimpleDataset(Dataset):def __init__(self, data):self.data = datadef __getitem__(self, index):return self.data[index]def __len__(self):return len(self.data)# 示例数据
data = [1, 2, 3, 4, 5]# 创建数据集
dataset = SimpleDataset(data)# 使用
print(dataset[0])  # 输出:1
print(len(dataset))  # 输出:5

示例 2:字符串键的数据集

假设我们有一个数据集,数据以字典形式存储,键是字符串。

python">from torch.utils.data import Datasetclass StringKeyDataset(Dataset):def __init__(self, data):self.data = dataself.keys = list(data.keys())def __getitem__(self, key):return self.data[key]def __len__(self):return len(self.keys)# 示例数据
data = {"a": 1, "b": 2, "c": 3}# 创建数据集
dataset = StringKeyDataset(data)# 使用
print(dataset["a"])  # 输出:1
print(len(dataset))  # 输出:3

注意:如果需要与 DataLoader 一起使用,必须提供一个自定义的采样器,因为默认的采样器生成整数索引。

示例 3:实现 __getitems__ 方法

为了实现批量加载数据,我们可以实现 __getitems__ 方法。

python">from torch.utils.data import Datasetclass BatchableDataset(Dataset):def __init__(self, data):self.data = datadef __getitem__(self, index):return self.data[index]def __getitems__(self, indices):return [self.data[i] for i in indices]def __len__(self):return len(self.data)# 示例数据
data = [10, 20, 30, 40, 50]# 创建数据集
dataset = BatchableDataset(data)# 使用
print(dataset[0])  # 输出:10
print(dataset.__getitems__([1, 3]))  # 输出:[20, 40]

示例 4:图像数据集

假设我们有一个图像数据集,图像路径存储在列表中。

python">from torch.utils.data import Dataset
from PIL import Image
import osclass ImageDataset(Dataset):def __init__(self, img_dir):self.img_dir = img_dirself.img_names = os.listdir(img_dir)def __getitem__(self, index):img_name = self.img_names[index]img_path = os.path.join(self.img_dir, img_name)image = Image.open(img_path)return imagedef __len__(self):return len(self.img_names)# 创建数据集
img_dir = "path/to/images"
dataset = ImageDataset(img_dir)# 使用
print(len(dataset))  # 输出图像数量
print(dataset[0])  # 输出第一张图像

示例 5:自定义采样器

如果你的数据集使用非整数键(如字符串),并且你想与 DataLoader 一起使用,可以定义一个自定义采样器。

python">from torch.utils.data import Dataset, DataLoader, Sampler
import randomclass StringKeyDataset(Dataset):def __init__(self, data):self.data = dataself.keys = list(data.keys())def __getitem__(self, key):return self.data[key]def __len__(self):return len(self.keys)class StringSampler(Sampler):def __init__(self, keys):self.keys = keys#每次调用时(如新的epoch开始),先打乱键的顺序,再返回迭代器。#实现数据加载时的随机化顺序。def __iter__(self):random.shuffle(self.keys)return iter(self.keys)def __len__(self):return len(self.keys)# 示例数据
data = {"a": 1, "b": 2, "c": 3}# 创建数据集和采样器
dataset = StringKeyDataset(data)
sampler = StringSampler(dataset.keys)# 使用 DataLoader
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)for batch in dataloader:print(batch)  # 输出批次数据


http://www.ppmy.cn/server/177104.html

相关文章

油候插件、idea、VsCode插件推荐(自用)

开发软件: 之前的文章: 开发必装最实用工具软件与网站 推荐一下我使用的开发工具 目前在用的 油候插件 AC-baidu-重定向优化百度搜狗谷歌必应搜索_favicon_双列 让查询变成多列,而且可以流式翻页 Github 增强 - 高速下载 github下载 TimerHo…

基于ArcGIS和ETOPO-2022 DEM数据分层绘制全球海陆分布

第〇部分 前言 一幅带有地理空间参考、且包含海陆分布的DEM图像在研究区的绘制中非常常见,本文将实现以下图像的绘制 关键步骤: (1)NOAA-NCEI官方下载最新的ETOPO-2022 DEM数据 (2)在ArcGIS(…

JavaIO流的使用和修饰器模式(直击心灵版)

系列文章目录 JavaIO流的使用和修饰器模式 文章目录 系列文章目录前言一、字节流: 1.FileInputStream(读取文件)2.FileOutputStream(写入文件) 二、字符流: 1..基础字符流:2.处理流:3.对象处理流:4.转换流: 三、修饰器…

【一起学Rust | Tauri2.0框架】基于 Rust 与 Tauri 2.0 框架实现全局状态管理

前言 在现代应用程序开发中,状态管理是构建复杂且可维护应用的关键。随着应用程序规模的增长,组件之间共享和同步状态变得越来越具有挑战性。如果处理不当,状态管理可能会导致代码混乱、难以调试,并最终影响应用程序的性能和可扩…

Lineageos 22.1(Android 15)实现负一屏

一、前言 方案是参考的这位大佬的,大家可以去付费订阅支持一波。我大概理一下Android15的修改。 大佬的方案代码 二、Android15适配调整 1.bp调整,加入aidl引入,这样make之后就可以索引代码了 filegroup {name: "launcher-src"…

arp -a命令输出详解

一、arp -a输出 C:\WINDOWS\system32>arp -a接口: 169.254.199.84 --- 0x2Internet 地址 物理地址 类型169.254.255.255 ff-ff-ff-ff-ff-ff 静态224.0.0.2 01-00-5e-00-00-02 静态224.0.0.22 01-00-5e-00-00-16…

如何在SQL中高效使用聚合函数、日期函数和字符串函数:实用技巧与案例解析

文章目录 聚合函数group by子句的使用实战OJ日期函数字符串函数数学函数其它函数 聚合函数 函数说明COUNT([DISTINCT] expr)返回查询到的数据的 数量SUM([DISTINCT] expr)返回查询到的数据的 总和,不是数字没有意义AVG([DISTINCT] expr)返回查询到的数据的 平均值&…

批量删除 PPT 空白幻灯片页面

如果我们需要删除 PPT 文档中的空白幻灯片页面,我们可以借助 Office 工具来完成,但是如果是大量的 PPT 文档需要批量删除空白幻灯片页面,那就需要使用专业的批量处理工具来完成,今天就给大家介绍一种批量删除 PPT 空白幻灯片页面的…