【Pytorch】加载数据

devtools/2024/11/14 9:24:36/

数据集获取:链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码: 5suq

本文基于P5. PyTorch加载数据初认识_哔哩哔哩_bilibili 

dataset:提供一种方式去获取数据及其label值,解释:Pytorch中的dataset类——创建适应任意模型的数据集接口_datasetpath-CSDN博客

dataloader:为网络提供不同的数据形式

首先新建一个python文件:read_data

把数据集文件与代码文件放在同一目录下

找到图片,复制路径。

read_data文件代码:

python">from torch.utils.data import Dataset
# 读取图片
from PIL import Image
import os# Dataset 是 PyTorch 的数据集基类。
# Image 用于打开和处理图片。
# os 用于处理文件路径。# MyData 类继承自 PyTorch 的 Dataset 类,需要实现三个方法:__init__()、__getitem__() 和 __len__()。
class MyData(Dataset):# 初始化sdef __init__(self, root_dir, label_dir):# self.root_dir和self.label_dir分别保存图像数据的根目录和标签目录。# self.path是root_dir 和 label_dir的连接路径。# self.img_path是指定目录下所有文件的列表,即图像文件的名称。# 路径self.root_dir = root_dir# 标签名self.label_dir = label_dir# 拼接成路径名self.path = os.path.join(self.root_dir, self.label_dir)# 获取所有图片的编号self.img_path = os.listdir(self.path)# 传编号def __getitem__(self, idx):# idx是数据集中的索引。# img_name是根据索引获取的图像文件名称。# img_item_path是图像的完整路径。# Image.open(img_item_path)用于打开图像文件。# label是图像的标签(在这个例子中,标签是目录名)。# return img, label返回图像和标签的元组。# 当前图片的名字img_name = self.img_path[idx]# 当前图片的地址img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)# 打开图片# Image.open()返回值是PIL类型格式,可以直接图片展示img = Image.open(img_item_path)label = self.label_dir# 返回样本对{x:y}return img, labeldef __len__(self):# 返回数据集中图像的数量,即img_path列表的长度。# 返回长度return len(self.img_path)# root_dir 是数据的根目录。
# ants_label_dir 和 bees_label_dir 是两个标签目录,分别代表蚂蚁和蜜蜂的图像数据。
# ants_dataset 和 bees_dataset 分别是两个 MyData 实例,表示蚂蚁和蜜蜂的图像数据集。
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)train_dataset = ants_dataset + bees_dataset

进阶版:

from torch.utils.data import Dataset, DataLoader
from torch.utils.data import ConcatDataset
import numpy as np
from PIL import Image
import os
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid# Dataset 和 DataLoader 用于创建和加载数据集。
# ConcatDataset 用于合并多个数据集。
# Image 用于打开和处理图像。
# os 用于处理文件路径。
# transforms 用于图像预处理。
# SummaryWriter 用于 TensorBoard 日志记录。
# make_grid 用于将多个图像合并成一个网格图像。writer = SummaryWriter("logs")class MyData(Dataset):def __init__(self, root_dir, image_dir, label_dir, transform):self.root_dir = root_dirself.image_dir = image_dirself.label_dir = label_dirself.label_path = os.path.join(self.root_dir, self.label_dir)self.image_path = os.path.join(self.root_dir, self.image_dir)self.image_list = os.listdir(self.image_path)self.label_list = os.listdir(self.label_path)# 应用于图像的转换操作(如调整大小和转换为 Tensor)self.transform = transform# 因为label 和 Image文件名相同,进行一样的排序,可以保证取出的数据和label是一一对应的self.image_list.sort()self.label_list.sort()def __getitem__(self, idx):# 根据索引idx获取图像和标签。# img_item_path和label_item_path是图像和标签的完整路径。# Image.open(img_item_path)# 打开图像文件。img_name = self.image_list[idx]label_name = self.label_list[idx]img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)#获取图片文件img = Image.open(img_item_path)# 读取标签文件的内容。with open(label_item_path, 'r') as f:label = f.readline()# 应用转换操作self.transform。img = self.transform(img)# 返回一个字典,包含图像和标签。sample = {'img': img, 'label': label}return sampledef __len__(self):# 确保图像和标签的数量相同。# 返回数据集中图像的数量。assert len(self.image_list) == len(self.label_list)return len(self.image_list)if __name__ == '__main__':# transform定义了图像预处理操作。transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])root_dir = "dataset/train"image_ants = "ants_image"label_ants = "ants_label"ants_dataset = MyData(root_dir, image_ants, label_ants, transform)image_bees = "bees_image"label_bees = "bees_label"bees_dataset = MyData(root_dir, image_bees, label_bees, transform)train_dataset = ants_dataset + bees_dataset# 使用DataLoader创建一个数据加载器,batch_size = 1和num_workers = 2。dataloader = DataLoader(train_dataset, batch_size=1, num_workers=2)# 使用SummaryWriter将索引为119的图像写入TensorBoard。writer.add_image('error', train_dataset[119]['img'])writer.close()


http://www.ppmy.cn/devtools/111377.html

相关文章

Vscode中启动Vue2.x项目运行正常但templete部分UI组件红色波浪线报错 ts(2339)

Vscode中启动Vue2.x项目运行正常但templete部分UI组件红色波浪线报错 错误示例 原因 Vue - Official 插件升级导致的问题(具体原因有待查询) 解决方案 打开Vscode软件 —> 找到扩展插件 —> 选择Vue - Official —> 安装特定版本(版本 < V2.0.28就行) —> 重…

MongoDB与Pymongo深度实践:从基础概念到无限级评论应用示例

文章目录 前言一、MongoDB1.基本介绍2.概念解析3.常见的数据类型4.Docker 安装5.常用命令 二、Pymongo1.基本操作&#xff08;连接、数据库、集合&#xff09;2.基本操作&#xff08;增删改查&#xff09; 三、MongoDB应用示例&#xff1a;无限级评论1.MongoDB 工具类2.实现无限…

支付宝开放平台-开发者社区——AI 日报「9 月 13 日」

1 OpenAl推出了一个新的大语言模型一 OpenAl o1 前沿技术瞭望官&#xff5c;阅读原文 新的模型主要体现在下面几个方面&#xff0c;思维链&#xff1a;o1在回答问题前会产生一个内部的思维链&#xff0c;这使得它能够进行更深入的推理。强化学习&#xff1a;通过大规模强化学…

【SQL】百题计划:SQL内置函数“LENGTH“的使用

【SQL】百题计划-20240912 方法一&#xff1a; Select tweet_id from Tweets where LENGTH(content) > 15;– 方法二&#xff1a; Select tweet_id from Tweets where CHAR_LENGTH(content)> 15;

Java开发安全及防护

目录 一、开发安全 二、XSS介绍及防范措施 2.1何为XSS 2.2XSS分类 2.3常用方法 三、SQL注入介绍及防范措施 3.1何为SQL注入 3.2常用方法 四、重放介绍及防范措施 4.1何为重放 4.2常用方法 一、开发安全 在学习安全之前&#xff0c;我们首先学习漏洞&#xff0c;知道漏…

前端单独实现 vue 动态路由

前端单独实现 vue 动态路由 Vue 动态路由权限是指在 Vue 应用程序中&#xff0c;根据用户的权限动态生成和控制路由的行为。这意味着不是所有的路由都在应用启动时就被硬编码到路由配置中&#xff0c;而是根据用户的权限信息&#xff0c;在运行时动态地决定哪些路由应该被加载…

如何在 Selenium 中获取网络调用请求?

引言 捕获网络请求对于理解网站的工作方式以及传输的数据至关重要。Selenium 作为一种 Web 自动化工具,可以用于捕获网络请求。本文将讨论如何使用 Selenium 在 Java 中捕获网络请求并从网站检索数据。 我们可以使用浏览器开发者工具轻松捕获网络请求或日志。大多数现代 Web…

Redis访问工具

使用Redis存储缓存数据&#xff0c;如何通过Java去访问Redis&#xff1f; 防止后面看晕&#xff0c;先来张图。 1. Redis的客户端库 Redis的客户端库是Redis官方提供的&#xff0c;用于让Java等编程语言与Redis服务器进行通信的工具包。常见的Redis客户端库有多个&#xff0c…