【PyTorch】(基础五)---- 图像数据集加载

news/2024/12/15 4:17:26/

数据集

torchvision数据集

Torchvision在torchvision.datasets模块中提供了许多内置的数据集,以及用于构建您自己的数据集的实用程序类。关于一些内置数据集目录如下,点击进去之后会有详细的数据集介绍,包括数据集大小、分类类型、以及下载方式等。

在这里插入图片描述

接下来我都会使用CIFAR-10数据集为例进行展示,因为其数据量不是很大,下载到本地后对存储的压力不是很大,CIFAR-10数据集由10个类别的60000张32 x32彩色图像组成,每个类别6000张图像。有50000张训练图像和10000张测试图像,用于完成图像分类任务,点进去之后看到详细页面

在这里插入图片描述

以下是CIFAR-10数据集的类别和部分图片展示:

在这里插入图片描述

torchvision提供的下载CIFAR10数据集的语法如下:

torchvision.datasets.CIFAR10(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)# CIFAR10为数据集的名字
# root为存放数据的目录
# train表示是否为训练数据集
# download表示是否将数据下载到本地
# transform可以指定如何对数据集进行预处理,通常都是使用ToTensor# 我们在选择好这些属性参数后的命令如下:
dataset = torchvision.datasets.CIFAR10(root=./dataset”, train = True, download = True,transform=torchvision.transforms.ToTensor())

运行后得到的下载链接可以放到迅雷中使用,下载速度更快,下载后将压缩文件复制到当前的目录下面即可,系统再次运行时会对其进行解压使用

在这里插入图片描述

第一次下载后,即使继续打开download运行,系统在检测到之后就不需要继续进行下载了,所以download经常处于打开状态后续也无需进行修改

DataLoader

datalorader用于加载数据集并提供迭代器(指定每次取多少个数据,是否随机取数等),使得模型训练过程中的数据读取更加高效。其基本语法如下:

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)# dataset指定从哪个DataSet中进行读取数据
# batch_size指定每次读取数据的多少
# shuffle指定是否进行随机抽取
# num_workers指定的多线程数量,在Windows中设置大于0可能出问题
# drop_last表示是否保留最后余出的几个数据,比如一共有1024张图片,batch_size设置为100,最终会余出24个数据

【代码示例】使用DataLoader读取CIFAR10数据集,每次读取64张图片

import torchvision# 准备的测试数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_data = torchvision.datasets.CIFAR10("./dataset/CIFAR10", train=False, transform=torchvision.transforms.ToTensor())test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)writer = SummaryWriter("dataloader")
# 观察每一次随机的结果
for epoch in range(2):step = 0for data in test_loader:imgs, targets = data# print(imgs.shape)# print(targets)writer.add_images("Epoch: {}".format(epoch), imgs, step)step = step + 1writer.close()

运行结果:

在这里插入图片描述

自定义数据集

要实现自定义的数据集类,首先,需要创建一个数据集类。这个类需要继承 torch.utils.data.Dataset 并实现两个方法:__len____getitem__

这里使用一个蚂蚁和蜜蜂的分类数据集(网盘下载),观察一下其目录结构,整个数据集分成训练集(train)和验证集(val)两部分,每部分包含ants和bees两个文件夹,每个文件夹中都是若干图片,其中的文件夹名字就是其图片的label

我们可以创建自定义的Dataset读取具体的数据。

import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from torch.utils.tensorboard import SummaryWriter# 定义自定义 Dataset 类
class AntsBeesDataset(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.transform = transformself.image_paths = []self.labels = []# 遍历根目录下的 ants 和 bees 文件夹for label in ['ants', 'bees']:label_dir = os.path.join(root_dir, label)for filename in os.listdir(label_dir):if filename.endswith('.jpg'):self.image_paths.append(os.path.join(label_dir, filename))self.labels.append(0 if label == 'ants' else 1)def __len__(self):return len(self.image_paths)def __getitem__(self, idx):image_path = self.image_paths[idx]label = self.labels[idx]image = Image.open(image_path).convert('RGB')if self.transform:image = self.transform(image)return image, label# 定义数据变换
transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 创建训练数据集
train_dataset = AntsBeesDataset(root_dir='./data_antAndBee/train', transform=transform)# 创建验证数据集
val_dataset = AntsBeesDataset(root_dir='./data_antAndBee/val', transform=transform)# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True,num_workers=0,  # 使用多进程pin_memory=True,  # 使用 pinned 内存
)val_loader = DataLoader(dataset=val_dataset,batch_size=32,shuffle=False,num_workers=0,  # 使用多进程pin_memory=True,  # 使用 pinned 内存
)# 检查数据集和数据加载器
print(train_dataset.__len__())  # 输出数据集长度
print(train_dataset.__getitem__(0))  # 输出第一个样本# TensorBoard Writer
writer = SummaryWriter('logs/log8')# 每轮读取的结果都是随机的
for epoch in range(2):step = 0for data in train_loader:imgs, targets = data# 如果需要将图像堆叠成一个张量,可以在这里进行处理# 例如,使用 pad_sequence 或者自定义的填充方法# imgs = torch.stack([F.pad(img, (0, max_width - img.size(-1), 0, max_height - img.size(-2))) for img in imgs])writer.add_images("Epoch: {}".format(epoch), imgs, step)step += 1writer.close()

上面的方法实现起来很复杂,其实这种目录结构叫做“ImageFolder”格式,其结构如下所示:

root/
├── class1/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
├── class2/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
└── ...

torchvision.datasets为我们提供了一种方法用于快速便捷地处理这种格式的数据集,使用方法ImageFolder()可以快速加载对应的数据集

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transformsif __name__ == "__main__":# 定义数据变换transform = transforms.Compose([# 改成统一大小并转换成tensor类型transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),])# 创建训练数据集train_dataset = datasets.ImageFolder(root='./data_antAndBee/train',transform=transform)# 创建验证数据集val_dataset = datasets.ImageFolder(root='./data_antAndBee/val',transform=transform)# 创建数据加载器train_loader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True,num_workers=0,  # 使用多进程pin_memory=True  # 使用 pinned 内存)val_loader = DataLoader(dataset=val_dataset,batch_size=32,shuffle=False,num_workers=0,  # 使用多进程pin_memory=True  # 使用 pinned 内存)# 检查数据集和数据加载器print(train_dataset.classes)  # 输出类别名称print(train_dataset.class_to_idx)  # 输出类别到索引的映射# TensorBoard Writerwriter = SummaryWriter('logs/log7')# 每轮读取的结果都是随机的for epoch in range(2):step = 0for data in train_loader:imgs, targets = datawriter.add_images("Epoch: {}".format(epoch), imgs, step)step += 1writer.close()

http://www.ppmy.cn/news/1555202.html

相关文章

Sui 区块链 Move 语言基础:深入解析数据类型与模块概念

目录 前言Move 共学活动:快速上手 Move 开发一、整数类型1. Move 语言特性:强类型与类型安全2. 运算符3. 处理负数与小数 二、布尔类型三、地址类型1. 十六进制地址2. 命名地址 四、包和模块的概念1. 创建一个包2. 包名与配置文件一致性3. 模块名与文件名…

基于小程序实现日历课表、排班表、月份切换、快捷周切换、自定义课程内容、课程颜色、Mock数据开箱即用

目录 引言小程序开发背景本文目标:实现日历课表/排班表适用场景:学生课表、员工排班、日程安排等需求分析支持日历视图和课表/排班视图可以查看、添加、编辑、删除课表/排班项支持按周、月查看总结说明参考代码数据Mock引言 本文将介绍如何基于小程序实现一个日历课表和排班表…

文件系统--底层架构(图文详解)

一、文件系统的底层存储与寻址 当我们谈到文件系统的底层结构时,最关键的问题是:文件的数据与元数据(属性)如何存储在磁盘上,以及系统是如何定位这些数据的?在谈及文件系统之前,我们要先对储存…

【考前预习】1.计算机网络概述

往期推荐 子网掩码、网络地址、广播地址、子网划分及计算-CSDN博客 一文搞懂大数据流式计算引擎Flink【万字详解,史上最全】-CSDN博客 浅学React和JSX-CSDN博客 浅谈云原生--微服务、CICD、Serverless、服务网格_云原生 serverless-CSDN博客 浅谈维度建模、数据分析…

css矩形样式,两边圆形

废话不多说&#xff0c;代码如下&#xff0c;直接拷贝即可使用&#xff1a; index.vue文件 <template><view class"wrap"><view class"tabs"><view class"tab active"><view class"name">标签</view…

C语言程序设计P5-5【应用函数进行程序设计 | 第五节】—知识要点:变量的作用域和生存期

知识要点&#xff1a;变量的作用域和生存期 视频&#xff1a; 目录 一、任务分析 二、必备知识与理论 三、任务实施 一、任务分析 有一个一维数组&#xff0c;内放 10 个学生成绩&#xff0c;写一个函数&#xff0c;求出平均分、最高分和最低分。 任务要求用一个函数来完…

MacOs 日常故障排除troubleshooting

1. 关闭开机自启动 app X macOs 15.1 System settings -> General -> Login Items & Extensions->Open at Login -> Select app X and click -

OpenCV DCT图像去噪

文章目录 一、简介二、实现代码三、实现效果参考文献一、简介 DCT(离散余弦变换)图像去噪是一种基于频域变换的去噪方法,其主要思想是通过将图像从空间域转换到频域,对频域中的高频成分(通常与噪声相关)进行滤波,从而实现去噪。 DCT图像去噪的基本步骤: 1.将图像转换到…