PyTorch from Beginner to Expert, Part 2: Dataset and DataLoader


Data is the foundation of deep learning: in general, the more data you have, the stronger the trained model. So once you have some data, how do you feed it into a model? PyTorch provides Dataset and DataLoader for exactly this purpose, and this post walks through both with several worked examples.

Table of contents

  • 1. Dataset
  • 2. Inspecting Dataset
  • 3. Listing the files in a folder with os
  • 4. Dataset in practice
    • Dataset Example 1
    • Dataset Example 2
    • Dataset Example 3
  • 5. DataLoader
    • Defining a custom Dataset and loading it with DataLoader
  • 6. A few os operations

1. Dataset

A Dataset provides a way to access the data and its labels. Concretely, it answers two questions:
● how to fetch each individual sample and its label
● how many samples there are
A minimal sketch is shown below.
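As a quick orientation, here is a minimal map-style dataset sketch. It is not code from the original post; the tensors are made up purely for illustration. It answers both questions by implementing __getitem__ and __len__:

import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """A tiny in-memory dataset: each sample is a (feature, label) pair."""
    def __init__(self):
        self.features = torch.arange(10, dtype=torch.float32).reshape(10, 1)  # 10 samples
        self.labels = torch.zeros(10, dtype=torch.long)                       # 10 labels

    def __getitem__(self, idx):
        # how to fetch a single sample and its label
        return self.features[idx], self.labels[idx]

    def __len__(self):
        # how many samples there are
        return len(self.features)

toy = ToyDataset()
print(len(toy), toy[0])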
Check whether CUDA (GPU support) is available:

import torch
print(torch.cuda.is_available())  # check whether CUDA is currently available
True

2. Inspecting Dataset

from torch.utils.data import Dataset
help(Dataset)  # inspect Dataset through its built-in documentation

Help on class Dataset in module torch.utils.data.dataset:
class Dataset(typing.Generic)
| Dataset(*args, **kwds)
|
| An abstract class representing a :class:`Dataset`.
|
| All datasets that represent a map from keys to data samples should subclass
| it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
| data sample for a given key. Subclasses could also optionally overwrite
| :meth:`__len__`, which is expected to return the size of the dataset by many
| :class:`~torch.utils.data.Sampler` implementations and the default options
| of :class:`~torch.utils.data.DataLoader`.
|
| .. note::
|   :class:`~torch.utils.data.DataLoader` by default constructs an index
|   sampler that yields integral indices. To make it work with a map-style
|   dataset with non-integral indices/keys, a custom sampler must be provided.
|
| Method resolution order:
|     Dataset
|     typing.Generic
|     builtins.object
|
| Methods defined here:
|
| __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
|
| __getattr__(self, attribute_name)
|
| __getitem__(self, index) -> +T_co
|
| ----------------------------------------------------------------------
| Class methods defined here:
|
| register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from builtins.type
|
| register_function(function_name, function) from builtins.type
|
| ----------------------------------------------------------------------
| Data descriptors defined here:
|
| __dict__
|     dictionary for instance variables (if defined)
|
| __weakref__
|     list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| __annotations__ = {'functions': typing.Dict[str, typing.Callable]}
|
| __orig_bases__ = (typing.Generic[+T_co],)
|
| __parameters__ = (+T_co,)
|
| functions = {'concat': functools.partial(<function Dataset.register_da...
|
| ----------------------------------------------------------------------
| Class methods inherited from typing.Generic:
|
| __class_getitem__(params) from builtins.type
|
| __init_subclass__(*args, **kwargs) from builtins.type
|     This method is called when a class is subclassed.
|
|     The default implementation does nothing. It may be
|     overridden to extend subclasses.
|
| ----------------------------------------------------------------------
| Static methods inherited from typing.Generic:
|
| __new__(cls, *args, **kwds)
|     Create and return a new object. See help(type) for accurate signature.

3. Listing the files in a folder with os

import os
dir_path = "hymenoptera_data\\hymenoptera_data\\train\\ants"  # folder path
data_dir = os.listdir(dir_path)  # list the entries inside the folder
data_dir

[‘0013035.jpg’,
‘1030023514_aad5c608f9.jpg’,
‘1095476100_3906d8afde.jpg’,
‘1099452230_d1949d3250.jpg’,
‘116570827_e9c126745d.jpg’,
‘1225872729_6f0856588f.jpg’,
‘1262877379_64fcada201.jpg’,
‘1269756697_0bce92cdab.jpg’,
‘1286984635_5119e80de1.jpg’,
‘132478121_2a430adea2.jpg’,
‘1360291657_dc248c5eea.jpg’,
‘1368913450_e146e2fb6d.jpg’,
‘1473187633_63ccaacea6.jpg’,
‘148715752_302c84f5a4.jpg’,
‘1489674356_09d48dde0a.jpg’,
‘149244013_c529578289.jpg’,
‘150801003_3390b73135.jpg’,
‘150801171_cd86f17ed8.jpg’,
‘154124431_65460430f2.jpg’,
‘162603798_40b51f1654.jpg’,
‘1660097129_384bf54490.jpg’,
‘167890289_dd5ba923f3.jpg’,
‘1693954099_46d4c20605.jpg’,
‘175998972.jpg’,
‘178538489_bec7649292.jpg’,
‘1804095607_0341701e1c.jpg’,
‘1808777855_2a895621d7.jpg’,
‘188552436_605cc9b36b.jpg’,
‘1917341202_d00a7f9af5.jpg’,
‘1924473702_daa9aacdbe.jpg’,
‘196057951_63bf063b92.jpg’,
‘196757565_326437f5fe.jpg’,
‘201558278_fe4caecc76.jpg’,
‘201790779_527f4c0168.jpg’,
‘2019439677_2db655d361.jpg’,
‘207947948_3ab29d7207.jpg’,
‘20935278_9190345f6b.jpg’,
‘224655713_3956f7d39a.jpg’,
‘2265824718_2c96f485da.jpg’,
‘2265825502_fff99cfd2d.jpg’,
‘226951206_d6bf946504.jpg’,
‘2278278459_6b99605e50.jpg’,
‘2288450226_a6e96e8fdf.jpg’,
‘2288481644_83ff7e4572.jpg’,
‘2292213964_ca51ce4bef.jpg’,
‘24335309_c5ea483bb8.jpg’,
‘245647475_9523dfd13e.jpg’,
‘255434217_1b2b3fe0a4.jpg’,
‘258217966_d9d90d18d3.jpg’,
‘275429470_b2d7d9290b.jpg’,
‘28847243_e79fe052cd.jpg’,
‘318052216_84dff3f98a.jpg’,
‘334167043_cbd1adaeb9.jpg’,
‘339670531_94b75ae47a.jpg’,
‘342438950_a3da61deab.jpg’,
‘36439863_0bec9f554f.jpg’,
‘374435068_7eee412ec4.jpg’,
‘382971067_0bfd33afe0.jpg’,
‘384191229_5779cf591b.jpg’,
‘386190770_672743c9a7.jpg’,
‘392382602_1b7bed32fa.jpg’,
‘403746349_71384f5b58.jpg’,
‘408393566_b5b694119b.jpg’,
‘424119020_6d57481dab.jpg’,
‘424873399_47658a91fb.jpg’,
‘450057712_771b3bfc91.jpg’,
‘45472593_bfd624f8dc.jpg’,
‘459694881_ac657d3187.jpg’,
‘460372577_f2f6a8c9fc.jpg’,
‘460874319_0a45ab4d05.jpg’,
‘466430434_4000737de9.jpg’,
‘470127037_513711fd21.jpg’,
‘474806473_ca6caab245.jpg’,
‘475961153_b8c13fd405.jpg’,
‘484293231_e53cfc0c89.jpg’,
‘49375974_e28ba6f17e.jpg’,
‘506249802_207cd979b4.jpg’,
‘506249836_717b73f540.jpg’,
‘512164029_c0a66b8498.jpg’,
‘512863248_43c8ce579b.jpg’,
‘518773929_734dbc5ff4.jpg’,
‘522163566_fec115ca66.jpg’,
‘522415432_2218f34bf8.jpg’,
‘531979952_bde12b3bc0.jpg’,
‘533848102_70a85ad6dd.jpg’,
‘535522953_308353a07c.jpg’,
‘540889389_48bb588b21.jpg’,
‘541630764_dbd285d63c.jpg’,
‘543417860_b14237f569.jpg’,
‘560966032_988f4d7bc4.jpg’,
‘5650366_e22b7e1065.jpg’,
‘6240329_72c01e663e.jpg’,
‘6240338_93729615ec.jpg’,
‘649026570_e58656104b.jpg’,
‘662541407_ff8db781e7.jpg’,
‘67270775_e9fdf77e9d.jpg’,
‘6743948_2b8c096dda.jpg’,
‘684133190_35b62c0c1d.jpg’,
‘69639610_95e0de17aa.jpg’,
‘707895295_009cf23188.jpg’,
‘7759525_1363d24e88.jpg’,
‘795000156_a9900a4a71.jpg’,
‘822537660_caf4ba5514.jpg’,
‘82852639_52b7f7f5e3.jpg’,
‘841049277_b28e58ad05.jpg’,
‘886401651_f878e888cd.jpg’,
‘892108839_f1aad4ca46.jpg’,
‘938946700_ca1c669085.jpg’,
‘957233405_25c1d1187b.jpg’,
‘9715481_b3cb4114ff.jpg’,
‘998118368_6ac1d91f81.jpg’,
‘ant photos.jpg’,
‘Ant_1.jpg’,
‘army-ants-red-picture.jpg’,
‘formica.jpeg’,
‘hormiga_co_por.jpg’,
‘imageNotFound.gif’,
‘kurokusa.jpg’,
‘MehdiabadiAnt2_600.jpg’,
‘Nepenthes_rafflesiana_ant.jpg’,
‘swiss-army-ant.jpg’,
‘termite-vs-ant.jpg’,
‘trap-jaw-ant-insect-bg.jpg’,
‘VietnameseAntMimicSpider.jpg’]
Note: on Windows, backslashes in string literals must be escaped, so each path separator is written as a double backslash (\\).
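As an aside not in the original post, there are standard alternatives that avoid the manual escaping; the paths below are just the one used above, rewritten three ways:

import os

# raw string: backslashes are taken literally, no doubling needed
dir_path = r"hymenoptera_data\hymenoptera_data\train\ants"

# forward slashes also work with Python's file APIs on Windows
dir_path = "hymenoptera_data/hymenoptera_data/train/ants"

# or let os.path.join pick the separator for the current OS
dir_path = os.path.join("hymenoptera_data", "hymenoptera_data", "train", "ants")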

4. Dataset in practice

Dataset Example 1

from torch.utils.data import Dataset
import os
from PIL import Image

class Mydata(Dataset):
    def __init__(self, root_path, label_path):
        self.root_path = root_path    # e.g. hymenoptera_data/hymenoptera_data/train
        self.label_path = label_path  # e.g. ants
        self.path = os.path.join(self.root_path, self.label_path)  # full path to the class folder
        self.image_path = os.listdir(self.path)  # list of image file names under .../train/ants

    def __getitem__(self, idx):
        image_name = self.image_path[idx]  # a single image file name
        image_item_path = os.path.join(self.root_path, self.label_path, image_name)
        img = Image.open(image_item_path)
        label = self.label_path
        return img, label

    def __len__(self):
        return len(self.image_path)

ants_root_path = "hymenoptera_data\\hymenoptera_data\\train"
ants_label_path = "ants"
Ants = Mydata(ants_root_path,ants_label_path)
Ants[0][0].show()  # the first [0] selects the first (image, label) pair, the second [0] takes the image, and .show() displays it

D:\anaconda\envs\Gpu-Pytorch\lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm

bee_root_path = "hymenoptera_data\\hymenoptera_data\\train"  # same root folder as the ants
bee_label_path = "bees"
Bees = Mydata(bee_root_path,bee_label_path)
Bees[0][0].show()


# build the training set
train = Ants + Bees   # simply add the two datasets together
print("the length of Ants is ",Ants.__len__())
print("the length of Bees is ",Bees.__len__())
print("the length of train is ",train.__len__())
the length of Ants is  124
the length of Bees is  121
the length of train is  245
# sanity check
train[123][0].show() # index 123 should still be an ant
train[124][0].show() # index 124 should be the first bee

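The + works because Dataset.__add__ (visible in the help output above) returns a ConcatDataset. The line below is an equivalent, more explicit spelling, added here for illustration rather than taken from the original post:

from torch.utils.data import ConcatDataset

train = ConcatDataset([Ants, Bees])  # same result as Ants + Bees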

Dataset Example 2

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch学习 
@File    :task_3.py
@IDE     :PyCharm 
@Author  :咋
@Date    :2023/6/29 14:29 
"""
from torch.utils.data import Dataset
import os
from PIL import Image

class Mydata(Dataset):
    def __init__(self, root_path, image_path, label_path):
        self.root_path = root_path
        self.image_path = image_path
        self.label_path = label_path
        self.A_image_path = os.path.join(self.root_path, self.image_path)
        self.A_label_path = os.path.join(self.root_path, self.label_path)
        self.img_item = os.listdir(self.A_image_path)
        self.label_item = os.listdir(self.A_label_path)

    def __getitem__(self, idx):
        img_name = self.img_item[idx]
        img_path = os.path.join(self.A_image_path, img_name)
        # label files are named "<image name>.txt"; strip extensions to compare names
        label_list = [i.split(".")[0] for i in self.label_item if i.count(".") == 1]
        # print(label_list)
        if img_name.split(".")[0] in label_list:
            img = Image.open(img_path)
            label_path = os.path.join(self.A_label_path, img_name.split(".")[0])
            label_path += ".txt"
            file = open(label_path, 'r')
            label = file.read()
            file.close()
            return img, label
        else:
            print("{0} has no matching label file".format(img_name))
            return 0

    def __len__(self):
        return len(self.img_item)

train_ants_root_path = "练手数据集\\train"
train_ants_image_path = "ants_image"
train_ants_label_path = "ants_label"
Ants = Mydata(train_ants_root_path,train_ants_image_path,train_ants_label_path)
for i in range(Ants.__len__()):
    try:
        print(Ants[i][1])
    except TypeError:
        print("Skipping this image!")
# Ants[122][0].show()
# print(Ants[122][1])

ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
ants
formica.jpeg has no matching label file
Skipping this image!
ants
imageNotFound.gif has no matching label file
Skipping this image!
ants
ants
ants
ants
ants
ants
With the exception handling added, images that have no matching label file are simply skipped instead of crashing the loop.
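A cleaner alternative (a sketch of mine, not from the original post) is to filter out unlabeled images once in __init__, so __getitem__ always returns a valid (image, label) pair and no exception handling is needed:

class MydataFiltered(Mydata):
    def __init__(self, root_path, image_path, label_path):
        super().__init__(root_path, image_path, label_path)
        labeled = {i.split(".")[0] for i in self.label_item if i.count(".") == 1}
        # keep only images that actually have a label file
        self.img_item = [name for name in self.img_item if name.split(".")[0] in labeled]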

Dataset Example 3

Create a dataset from one of the built-in datasets in torchvision.

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn 
@File    :dataset_3.py
@IDE     :PyCharm 
@Author  :咋
@Date    :2023/7/2 14:58 
"""
import torchvision
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torchvision import transforms
dataset = torchvision.datasets.MNIST("./Mnist",train=True,download=True,transform=transforms.ToTensor())
dataloader = DataLoader(dataset,batch_size=64,shuffle=False,num_workers=0)
# use tensorboard to visualize the batches produced by the dataloader
'''Option 1
# write = SummaryWriter("log_2")
# count = 0
# for data in dataloader:
#     image,label = data
#     # print(data[1])
#     # print(image.shape)
#     write.add_images("dataloader",image,count)
#     count += 1
'''

# Option 2
write = SummaryWriter("log_3")
for i, data in enumerate(dataloader):
    image, label = data
    write.add_images("dataloader", image, i)
write.close()
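To view the logged batches, launch TensorBoard pointed at the log directory, e.g. tensorboard --logdir log_3, and open the address it prints in a browser; this launch command is standard TensorBoard usage rather than something shown in the original post.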

enumerate returns each item of an iterable together with its index:

For a sequence seq, it yields:
(0, seq[0]), (1, seq[1]), (2, seq[2]), ...
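A quick illustration in plain Python (the list is hypothetical, just for demonstration):

seq = ["ant", "bee", "wasp"]
print(list(enumerate(seq)))  # [(0, 'ant'), (1, 'bee'), (2, 'wasp')]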

5. DataLoader

A DataLoader wraps a Dataset and serves it to the downstream network in batches, with optional shuffling and multi-process loading. A minimal usage sketch is shown below.
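This sketch is self-contained and uses made-up toy tensors rather than anything from the original post; it shows the DataLoader parameters that appear throughout the examples:

import torch
from torch.utils.data import TensorDataset, DataLoader

# a tiny made-up dataset: 10 one-feature samples with integer labels
features = torch.arange(10, dtype=torch.float32).reshape(10, 1)
labels = torch.zeros(10, dtype=torch.long)
dataset = TensorDataset(features, labels)

# batch_size: samples per batch; shuffle: reshuffle every epoch; num_workers: loader subprocesses
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for batch_features, batch_labels in loader:
    print(batch_features.shape, batch_labels.shape)  # e.g. torch.Size([4, 1]) torch.Size([4])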

Defining a custom Dataset and loading it with DataLoader

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from net import Net
import softmax
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np

transform_tool = transforms.ToTensor()  # create a ToTensor transform once
# image_tensor = transform_tool(image)

with open("mnist-label.txt", 'r') as f:
    label_str = f.read().strip()   # read the whole label file into memory once

class Mydata(Dataset):
    def __init__(self, image_path):
        self.image_path = image_path
        # self.label_path = label_path
        self.image = os.listdir(self.image_path)  # list of image file names under image_path

    def __getitem__(self, idx):
        image_name = self.image[idx]  # a single image file name
        image_item_path = os.path.join(self.image_path, image_name)
        img = Image.open(image_item_path)
        # transform_tool = transforms.ToTensor()
        img = transform_tool(img)
        labels_list = [int(label) for label in label_str.split(',')]  # parse the cached label string, no need to reopen the file
        labels = np.array(labels_list)
        label = labels[idx]
        return img, label

    def __len__(self):
        return len(self.image)

# trainset = Mydata("mnist-dataset")

# training parameters
batch_size = 32
epochs = 5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# dataset
# transform = transforms.Compose([transforms.ToTensor(),
#                                 transforms.Normalize((0.5,), (0.5,))])
# trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainset = Mydata("mnist-dataset")
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=0)
print(len(trainloader))

# print the configuration
print("batch_size:", batch_size)
print("data_batches:", len(trainloader))
print("epochs:", epochs)

# neural network
net = Net().to(device)
# net.load_state_dict(torch.load('./model/model.pth'))

# loss function and optimizer
# negative log-likelihood loss
criterion = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0005, momentum=0.9)

total_correct = 0
total_samples = 0

# train the network
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = Variable(inputs).to(device), Variable(labels).to(device)
        # backward pass and parameter update
        optimizer.zero_grad()
        outputs = net(inputs)
        # outputs = int(net(inputs))
        # print(outputs)
        labels = labels.long()
        # print(labels)
        # print(type(labels))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # accumulate accuracy over this reporting window
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()
        if i % 5 == 0:    # print loss and accuracy every few batches
            accuracy = 100.0 * total_correct / total_samples
            print('[epoch: %d, batches: %d] loss: %.5f accuracy: %.2f%%' %
                  (epoch + 1, i + 1, running_loss / 2000, accuracy))
            total_correct = 0
            total_samples = 0
            running_loss = 0.0
    torch.save(net.state_dict(), 'model.pth')  # save the model parameters after every epoch

print('Finished Training')

The label file can be opened once, before the class definition, so its contents are cached in memory; __getitem__ then reads each label from the cache instead of reopening the file on every call.
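Going one step further (a sketch of mine, not from the original post), the label string can also be parsed into an array once in __init__, so __getitem__ only indexes into it. The file name mnist-label.txt is the same assumed comma-separated label file used above:

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class MnistFolderData(Dataset):
    def __init__(self, image_dir, label_file="mnist-label.txt"):
        self.image_dir = image_dir
        self.image_names = os.listdir(image_dir)
        with open(label_file, 'r') as f:
            # parse the comma-separated labels exactly once
            self.labels = np.array([int(x) for x in f.read().strip().split(',')])
        self.to_tensor = transforms.ToTensor()

    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.image_dir, self.image_names[idx]))
        return self.to_tensor(img), int(self.labels[idx])

    def __len__(self):
        return len(self.image_names)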

6. A few os operations

On Windows, path separators in string literals are written as a double backslash (\\).
import os
dir_path = "/home/aistudio"  # folder path
data_dir = os.listdir(dir_path)  # list the entries inside the folder
label_path = "label"
all_path = os.path.join(dir_path, label_path)
