torch 的数据加载 Datasets DataLoaders

news/2024/12/22 9:31:32/

点赞收藏关注!
如需要转载,请注明出处!

torch的模型加载有两种方式:
Datasets & DataLoaders

torch本身可以提供两数据加载函数:
torch.utils.data.DataLoader()和torch.utils.data.Dataset()

其中torch.utils.data 是PyTorch提供的一个模块,用于处理和加载数据。该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。

加载函数后可以实现数据集代码与模型训练代码分离,以获得更好的可读性和模块化
Dataset定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。制作自己的数据集必须要实现三个函数:

  • init()函数在实例化Dataset对象时运行一次
  • len()函数返回数据集中样本的数量
  • getitem()函数的作用是:从给定索引 index,从数据集中加载并返回一个样本并将其转换为张量。
import torch
from torch.utils.data import Datasetclass CreateDataset(Dataset):def __init__(self, data):self.data = datadef __getitem__(self, index):# 根据索引获取样本return self.data[index]def __len__(self):# 返回数据集大小return len(self.data)# 创建数据集对象
data = [[255,255,255],[255,245,235],[225,226,227]]
dataset = CreateDataset(data)# 根据索引获取样本
sample = dataset[1]
print(sample)
# [255,245,235]

数据处理模块其他的功能:

  • TensorDataset: 继承自 Dataset 类,用于将张量数据打包成数据集。它接受多个张量作为输入,并按照第一个输入张量的大小来确定数据集的大小。对 tensor 进行打包,就好像 python 中的 zip 功能。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。
from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoadera = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99]])
b = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
train_ids = TensorDataset(a, b)for x_train, y_label in train_ids:print(x_train, y_label)##############################################################################################
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
  • DataLoader: 数据加载器类,用于批量加载数据集。它接受一个数据集对象作为输入,并提供多种数据加载和预处理的功能,如设置批量大小、多线程数据加载和数据打乱等。DataLoader中最重要的参数就是dataset,它决定了要装载的数据集。

  • Subset: 数据集的子集类,用于从数据集中选择指定的样本。定义了一个子集的索引列表indices,它可以根据需要进行调整。然后,我们使用Subset类创建了一个名为subset的子集对象,它接受两个参数:原始数据集dataset和子集的索引列表indices。

indices = [0, 2, 4]  # 子集的索引列表
subset = Subset(dataset, indices)
  • random_split: 将一个数据集随机划分为多个子集,可以指定划分的比例或指定每个子集的大小。
import torch
import torchvision
# from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
# 准备数据集
from torch import nn
from torch.utils.data import DataLoader# 定义训练的设备
device = torch.device("cuda")
#读取数据
data_transform = transforms.Compose([transforms.Resize(size=(224,224)),transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5])
])
full_dataset = ImageFolder(r'D:\PythonSpace\data\trainTest',transform = data_transform)
# length 数据集总长度
full_data_size = len(full_dataset)
print("总数据集的长度为:{}".format(full_data_size))
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size#在这里
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
#在这里train_data_size = len(train_dataset)
test_data_size = len(test_dataset)
# 如果train_data_size=10, 训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
>>>
总数据集的长度为:100
训练数据集的长度为:80
测试数据集的长度为:20
  • ConcatDataset: 将多个数据集连接在一起形成一个更大的数据集。
#链接两个数据集
dataset = torch.utils.data.ConcatDataset([celeba_dataset, digiface_dataset]) 
#导入数据集
loader = torch.utils.data.DataLoader( dataset=dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.n_workers)
  • get_worker_info: 获取当前数据加载器所在的进程信息。torch.utils.data.get_worker_info() 在worker进程中返回各种有用的信息(包括worker id、dataset replica、initial seed等),在main进程中返回None。用户可以在数据集代码和/或 worker_init_fn 中使用此函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。分片数据集特别有用。

如有帮助,点赞收藏关注!


http://www.ppmy.cn/news/1230870.html

相关文章

11 redis中分布式锁的实现

单机锁代码 import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.con…

如何解决requests库自动确定认证arded 类型

requests 库是一种非常强大的爬虫工具,可以用于快速构建高效和稳定的网络爬虫程序。对于经常使用爬虫IP用来网站爬虫反爬策略的我来说,下面遇到的问题应当值得我们思考一番。 问题背景 在使用requests库进行网络请求时,有时会遇到需要对目标服务进行认证…

c++模式之单例模式详解

c模式之单例模式详解 1.概念2.懒汉模式示例(缺点)3.懒汉模式线程安全4.饿汉式创建单例5.饿汉模式线程示例 1.概念 单例模式是指在整个系统生命周期内,保证一个类只能产生一个实例,确保该类的唯一性. 使用单例两个原因&#xff1a…

在 Windows 中关闭 Nginx 所有进程

在 Windows 中关闭 Nginx 所有进程并强制重启的命令如下: 打开命令提示符(CMD)。 输入以下命令来查找 Nginx 进程的 PID: tasklist /fi "imagename eq nginx.exe"此命令将列出所有名为 nginx.exe 的进程以及它们的 PID…

Visio免费版!Visio国产平替软件,终于被我找到啦!

作为一个职场人士,我经常需要绘制各种流程图和图表,而Visio一直是我使用的首选工具。但是,随着公司的发展和工作的需要,我逐渐发现了Visio的优点和不足。 首先,让我们来看看Visio的优点。Visio是一个专业的流程图和图…

JVS低代码表单设计:数据联动详解(多级数据级、数据回显等)

在这信息化时代,表单作为数据的收集和展示工具,已经渗透到不同的角落。JVS低代码对表单的设计和操作进行了不断的优化和创新。其中,联动回显作为一项重要的功能,无论是多级数据级联控制、组件的联动控制,还是多表的数据…

数据中心走向绿色低碳,液冷存储舍我其谁

引言:没有最冷,只有更冷,绿色低碳早已成为行业关键词。 【全球存储观察 | 科技热点关注】 每一次存储行业的创新,其根源离不开行业端的用户需求驱动。 近些年从数据中心建设的整体发展情况来看,从风冷到…

基于单片机GPS轨迹定位和里程统计系统

**单片机设计介绍, 基于单片机GPS轨迹定位和里程统计系统 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 一个基于单片机、GPS和里程计的轨迹定位和里程统计系统可以被设计成能够在移动的交通工具中精确定位车辆的位置…