pytorch分批加载大数据集

ops/2024/11/1 20:51:53/

pytorch分批加载大数据集

  • 本文处理的数据特点:
  • 加载图片
  • 表格数据

本文处理的数据特点:

(1)数据量大,无法一次读取到内存中
(2)数据是图片或者存储在csv中(每一行是一个sample,包括feature和label)

加载数据集需要继承torch.utils.data 的 Dataset类,并实现 __len__和__getitem__方法。其中

len:返回数据集总数,

getitem:返回指定的数的矩阵和标签。

加载图片

这段代码是一个使用 PyTorch 数据加载和处理机制的例子,主要用于从指定目录加载图片数据,并通过 DataLoader 进行批量处理。

python">from torch.utils.data import Dataset, DataLoader
import torch
import glob
import os
from PIL import Imageclass PictureLoad(Dataset):def __init__(self, paths, size=(10, 10)):self.paths = glob.glob(paths)self.size = sizedef __len__(self):return len(self.paths)def __getitem__(self, item):try:img = Image.open(self.paths[item]).resize(self.size)img_tensor = torch.from_numpy(np.asarray(img)).float() / 255.0  # 转为Tensor并归一化label = os.path.basename(self.paths[item]).split('.')[0]  # 更健壮的文件名提取方式return img_tensor, labelexcept IOError:print(f"Error opening file: {self.paths[item]}")  # 处理文件打开错误return None, Noneif __name__ == '__main__':root_path = os.path.join(os.path.dirname(os.getcwd()), "cap")pic_paths = os.path.join(root_path, '*.jpg')picture = Pictureload(pic_paths)dataloader = DataLoader(picture, batch_size=32, num_workers=2, timeout=2)for a, b in dataloader:print(b, a.shape)  # 输出标签和图片数据的尺寸,而不是原始数据

表格数据

确保数据以分批方式从文件中加载,且不会一次性将所有数据加载到内存中,适合处理大规模数据文件。

python">import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pdclass DataLoad(Dataset):def __init__(self, file_path, batch_size=3):'''初始化函数,设置文件路径和每批读取的数据大小。'''self.file_path = file_pathself.batch_size = batch_sizeself.total_data = self._get_total_len()def _get_total_len(self):'''辅助函数用于计算文件中的数据行数。'''with open(self.file_path, 'r') as file:return sum(1 for line in file)def __len__(self):'''返回数据集的总长度。'''return self.total_datadef __getitem__(self, idx):'''根据索引获取数据,每次从文件中动态加载数据。'''if idx * self.batch_size >= self.total_data:raise IndexError("Index out of range")skip_rows = idx * self.batch_size if idx > 0 else 0df = pd.read_csv(self.file_path, skiprows=skip_rows, nrows=self.batch_size, header=None)data_tensor = torch.tensor(df.values)return data_tensorif __name__ == "__main__":dataset = DataLoad('path_to_your_data.csv', batch_size=32)dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)for epoch in range(3):print(f"Epoch {epoch + 1}")for data in dataloader:print("Data Batch:")print(data)

对于两个batch_size的解释:假设 PretrainData 类每次通过其__getitem__ 方法返回一批数据,即32行数据(根据它的 batch_size=32 设定)。当您使用 DataLoader 并设置其 batch_size 为1时,意味着每次从 DataLoader 迭代得到的数据批将包含从 PretrainData 返回的1个独立批次。因此,每个从 DataLoader 返回的数据批将包含1*32=32条数据。


http://www.ppmy.cn/ops/130233.html

相关文章

【vue项目中添加告警音频提示音】

一、前提: 由于浏览器限制不能自动触发音频文件播放,所以实现此类功能时,需要添加触发事件,举例如下: 1、页面添加打开告警声音开关按钮 2、首次进入页面时添加交互弹窗提示:是否允许播放音频 以上两种方…

江协科技STM32学习- P29 实验- 串口收发HEX数据包/文本数据包

🚀write in front🚀 🔎大家好,我是黄桃罐头,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝​…

【hector mapping参数设置】

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 hector mapping部分参数介绍调整hector_mapping中的参数ros常见问题总结 hector mapping部分参数介绍 在wiki.ros.org/hector_mapping界面找到3.1.4 Parameters章节…

3D Gaussian Splatting代码详解(一):模型训练、数据加载

1 模型训练 这段代码实现了一个 3D 高斯模型的训练循环,旨在通过逐步优化模型参数,使其能够精确地渲染特定场景。以下是代码的详细解析: def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations,…

算法刷题-小猫爬山

本题来源165. 小猫爬山 - AcWing题库 翰翰和达达饲养了 NN 只小猫&#xff0c;这天&#xff0c;小猫们要去爬山。 经历了千辛万苦&#xff0c;小猫们终于爬上了山顶&#xff0c;但是疲倦的它们再也不想徒步走下山了&#xff08;呜咕>_<&#xff09;。 翰翰和达达只好花…

记录|SQL中日期查询出现的问题

目录 前言一、BETWEEN AND问题二、时间带有时分秒更新时间 前言 参考文章&#xff1a; 一、BETWEEN AND问题 假设这是我的表中信息&#xff1a; 我想查询2024-10-16到2024-10-17的数据&#xff0c;理论上用Between and就行&#xff0c;如下所示&#xff1a; SELECT create_t…

scRank从untreated数据推断药物有反应细胞类型

由于细胞簇之间存在异质性&#xff0c;细胞对药物的反应也各不相同。因此&#xff0c;识别对药物有反应的细胞簇对于探索药物作用至关重要&#xff0c;但这仍然是一个巨大的挑战。在这里&#xff0c;作者使用 scRank 解决了这个问题&#xff0c;它采用靶标扰动基因调控网络&…

大数据之VIP(Virtual IP,虚拟IP)负载均衡

VIP&#xff08;Virtual IP&#xff0c;虚拟IP&#xff09;负载均衡是一种在计算机网络中常用的技术&#xff0c;用于将网络请求流量均匀地分散到多个服务器上&#xff0c;以提高系统的可扩展性、可靠性和性能。以下是对VIP负载均衡的详细解释&#xff1a; 一、VIP负载均衡的基…