【PyTorch】关于torchvision中的数据集以及dataloader的使用

server/2024/9/29 15:42:56/

前提文章目录

【PyTorch】深度学习PyTorch环境配置及安装【详细清晰】
【PyTorch】深度学习PyTorch加载数据
【PyTorch】关于Tensorboard的简单使用
【PyTorch】关于Transforms的简单使用


文章目录

  • ``前提``文章目录
    • 数据集简介
    • 程序中下载数据集
    • 读取数据集
    • 结合transform进行读取数据集
    • dataloader
      • batch_size的参数测试
      • drop_last的参数测试
      • shuffle的参数测试

数据集简介

pytorch官网:https://pytorch.org/
cifar数据集网址:https://pytorch.org/vision/0.9/datasets.html#cifar
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

程序中下载数据集

在这里插入图片描述

使用CIFAR10数据集 按ctrl+p可以提示括号内需要输入什么参数
root:希望数据集存放在什么位置;
train 默认为true,为true表示该数据集为训练集 为false表示该数据集为测试集;
download: 为true就是让它进行一个下载

python">import torchvision
# 训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

运行程序,可以看到如下图,说明已经开始下载了
在这里插入图片描述
下载速度慢的解决办法:
直接点击蓝色的下载链接进行下载或者将蓝色的下载链接放在迅雷里面进行下载
coco数据集大概30多个G,这个CIFAR数据集才一百多兆,比较时候练手
在这里插入图片描述
下载完毕!
在这里插入图片描述

读取数据集

python">import torchvision# 训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)# 查看数据集
print(test_set[0])  # (<PIL.Image.Image image mode=RGB size=32x32 at 0x27E68307FA0>, 3)  # 图片组成部分 第一个是图片 第二个是target
print(test_set.classes)  # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']img, target = test_set[0]
print(img)  # <PIL.Image.Image image mode=RGB size=32x32 at 0x27E21B660E0>
print(target)  # 3   实际上对应的就是test_set.classes的第4个也就是cat,因为它是从0开始算的,说明这张图片对应的是猫
print(test_set.classes[target])  # cat
# 展示图片  PIL的Image可以直接show
img.show()

数据集的图片比较小,只有32x32的像素
展示图片结果:
在这里插入图片描述

结合transform进行读取数据集

python">import torchvision# PIL图片类型需要转为tensor类型
from torch.utils.tensorboard import SummaryWriterdataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])
# 训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)# # 查看数据集
# print(test_set[0])  # (<PIL.Image.Image image mode=RGB size=32x32 at 0x27E68307FA0>, 3)  # 图片组成部分 第一个是图片 第二个是target
# print(test_set.classes)  # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
#
# img, target = test_set[0]
# print(img)  # <PIL.Image.Image image mode=RGB size=32x32 at 0x27E21B660E0>
# print(target)  # 3   实际上对应的就是test_set.classes的第4个也就是cat,因为它是从0开始算的,说明这张图片对应的是猫
# print(test_set.classes[target])  # cat
# # 展示图片
# img.show()# print(test_set[0])  # tensor类型的  表明可以用tensorboard进行显示
writer = SummaryWriter("p10")
for i in range(10):img, target = test_set[i]writer.add_image("test_set", img, i)
writer.close()

运行程序后,打开终端,用tensorboard进行读取
在这里插入图片描述
在这里插入图片描述

dataloader

在这里插入图片描述
dataset:是一堆数据集,不知道里面每一张数据的内容
dataloader:将数据集加载到神经网络当中,做的是每次去dataset去取数据,每次取多少,怎么取,是由dataloader当中的参数设置的
dataloader官网:https://pytorch.org/docs/1.8.1/data.html?highlight=dataloader#torch.utils.data.DataLoader
在这里插入图片描述

常见参数设置

  • batch_size:每次取数据取的数量
  • shuffle:打乱数据,为true打乱顺序和原来不一样,false打乱顺序和原来一样
  • num_workers:多进程进行加载 默认为0,表示用主进程进行加载 ,有时候在Windows会有一些问题,num_workers>0,有时候在Windows下会出现一些错误
  • drop_last:对于取整数据集剩余除不尽的数据集的个数是采用舍去还是不舍去,为true舍去,false不舍去

示例测试:

batch_size的参数测试

python">import torchvision.datasetsfrom torch.utils.data import DataLoader# 准备测试数据集
from torch.utils.tensorboard import SummaryWritertest_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
# batch_size=4每次从数据集中取4张进行打包
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)  # torch.Size([3, 32, 32])
print(target)  # 3writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:imgs, targets = data# print(imgs.shape)  # torch.Size([4, 3, 32, 32]) 4张图片 3通道 32x32的图片# print(targets)  # tensor([9, 6, 5, 7])  ; 9, 6, 5, 7这些数字是随机抓取的targetwriter.add_images("test_data", imgs, step) #注意这里是add_images有一个s的step += 1
writer.close()

在这里插入图片描述
在这里插入图片描述

drop_last的参数测试

修改代码

python">test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
python">writer.add_images("test_data_drop_last", imgs, step)

在这里插入图片描述

shuffle的参数测试

shuffle为False测试:第一次取的顺序和第二次取的顺序不打乱,是取得顺序是一样的

python">test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)
python">for epoch in range(2): # epoch 取值0、1    shuffle=False两轮摸一样的。不打乱step = 0for data in test_loader:imgs, targets = datawriter.add_images("Epoch:{}".format(epoch), imgs, step)step += 1

在这里插入图片描述
shuffle为True测试:第一次取的顺序和第二次取的顺序会进行一个打乱

python">test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

在这里插入图片描述
全部代码:

python">import torchvision.datasetsfrom torch.utils.data import DataLoader# 准备测试数据集
from torch.utils.tensorboard import SummaryWritertest_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
# batch_size=4每次从数据集中取4张进行打包
#test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)  # torch.Size([3, 32, 32])
print(target)  # 3writer = SummaryWriter("dataloader")
for epoch in range(2): # epoch 取值0、1    shuffle=False两轮摸一样的。不打乱step = 0for data in test_loader:imgs, targets = data# print(imgs.shape)  # torch.Size([4, 3, 32, 32]) 4张图片 3通道 32x32的图片# print(targets)  # tensor([9, 6, 5, 7])  ; 9, 6, 5, 7这些数字是随机抓取的target# writer.add_images("test_data", imgs, step)# writer.add_images("test_data_drop_last", imgs, step)writer.add_images("Epoch:{}".format(epoch), imgs, step)step += 1
writer.close()

http://www.ppmy.cn/server/106114.html

相关文章

Win11搭建Angular开发环境

作为一名后端程序员&#xff0c;无论当前的工作是否需要&#xff0c;会一点点前端无疑对自己是有帮助的。今天就来介绍一下如何搭建Angular的开发环境。我也是摸着石头过河&#xff0c;所以很多东西也不熟悉&#xff0c;先按照Angular官网的介绍来配置吧。 这个是Angular最新版…

以简单的例子从头开始建spring boot web多模块项目(一)

目的&#xff1a;从头梳理&#xff0c;如何手工从头建立多模块项目。 步骤&#xff1a; 1、建立maven项目,类型&#xff1a;maven Archetype&#xff0c;Name:ParentDemo 选择JDK版本&#xff0c;Archetype&#xff1a;org.apache.maven.archetypes:maven-archetype-quickstart…

Qml Image 截取一部分图片形式

Qml Image 截取一部分 &#xff1a;每次只显示一张图片的一部分&#xff0c;以有不同的状态显示 将这个图形每次只显示一部分出来&#xff1a; 想了好久&#xff0c;才找到实现的方法&#xff1a; 运行效果&#xff1a; 再来一个升级版的小程序&#xff1a; import QtQuick…

014、架构_配置文件_LDS(loadserver.ini)

loadserver.ini 文件说明配置文件“loadserver.ini”用于配置LoadServer模块运行时参数。文件位于用户“~/etc”目录下.配置文件分为: [general]段:配置LoadServer运行中数据文件目录。[transfer]段:配置Transfer运行参数,Transfer安装在数据分片用户下,以LoadServer进程名…

网络协议的基础知识

了解OSI模型和TCP/IP模型 在上一篇关于互联网的工作原理的数据传输中&#xff0c;我们了解到&#xff0c;两台计算机之间传输数据时&#xff0c;需要将数据封装成数据包。这些数据包中不仅包含我们实际要传输的信息&#xff0c;还包括很多额外的内容&#xff0c;比如目标地址、…

搭建 PXE 远程安装服务器和设置 Kickstart 无人值守安装

目录 搭建 PXE 远程安装服务器 1.安装并启用 TFTP 服务 2.安装并启用 DHCP 服务 3.准备 Linux 内核、初始化镜像文件 4.准备 PXE 引导程序 5.安装FTP服务&#xff0c;准备CentOS 7 安装源 6.配置启动菜单文件 7.关闭防火墙&#xff0c;验证 PXE 网络安装 设置 Kicksta…

MATLAB 沿任意方向分层点云(82)

MATLAB 沿任意方向分层点云(82) 一、算法介绍二、算法实现1.代码2.效果更多内容参考: MATLAB点云处理学习 一、算法介绍 沿着某个方向,将点云分割为多层,每层点云使用不同颜色进行可视化显示,具体代码和不同方向的分层效果如下: 二、算法实现 1.代码 % Load point c…

视频截取中的UI小组件

引言 视频截取在社交类 APP 中十分常见。有了上传视频的功能&#xff0c;就不可避免地需要提供截取和编辑的选项。如果我们过度依赖第三方库&#xff0c;项目的代码可能会变得异常臃肿&#xff0c;因为这些库往往包含许多我们用不到的功能&#xff0c;而且它们的 UI 样式和功能…