PyTorch深度学习——数据输入和预处理

devtools/2024/11/12 2:43:45/

pytorch数据载入

数据载入

在使用pytorch构建和训练模型的过程中,需要经常把原始数据(比如图片、音频)转化为张量的格式,为了方便地批量处理图片数据,pytorch引入了一系列工具来对这个过程进行包装

torch.utils.data.DataLoader

pytorch提供的一个用于数据加载的工具类,用于批量加载数据并为模型提供输入。它可以将数据集包装成一个可迭代的对象,方便地进行数据加载和批处理操作

Pytorch torch.utils.data.DataLoader 用法详细介绍-CSDN博客

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, *, prefetch_factor=2,persistent_workers=False)

参数说明

  • dataset :要从中加载数据的数据集(一个torch.utils.data.DataLoader的实例
  • batch_size:每批次要装载多少样品(迷你批次的大小)
  • shuffle :设置为True以使数据在每个时期都重新洗牌
  • sampler :定义从数据集中抽取样本的策略
  • batch_sampler:类似于采样器sampler,但一次返回一个迷你批次的索引,sampler只返回一个下标索引, 与batch_size,shuffle,sampler和drop_last互斥
  • num_workers :多少个子流程用于数据加载。 0表示将在主进程中加载数据 (默认值:0)
  • collate_fn :把一批 dataset 的实例转化为包含迷你批次数据的张量
  • pin_memory :如果为True,则数据加载器在将张量返回之前将其复制到CUDA固定的内存中。 如果您的数据元素是自定义类型,或者您的collate_fn返回的是一个自定义类型的批处理
  • drop_last :决定是否将最后一个迷你批次的数据丢掉
  • timeout :如果为正,则为从工作人员收集批次的超时值。 应始终为非负数。 (默认值:0)
  • worker_init_fn:如果非None,这个函数将在每个工作子进程上被调用,并接收工作进程ID(一个在[0, num_workers - 1]范围内的整数)作为输入,它在设置随机种子之后、但在数据加载之前被调用。(默认:None)
  • prefetch_factor :每个子流程预先加载的样本数。 2表示将在所有子流程中预取总共2 * num_workers个样本。 (默认值:2)
  • persistent_workers :如果为True,则一次使用数据集后,数据加载器将不会关闭工作进程。 这样可以使Worker Dataset实例保持活动状态。 (默认值:False)

映射类型的数据集

为了能够使用 DataLoader 类,首先需要构造关于单个数据的 torch.ulits.data.Dataset 类,这个类有两种:一种是映射类型(Map-Style),对于这个类型,每个数据都有一个对应的索引,通过输入具体的索引,就能得到对应的数据

import torch.utils.data as dataclass MyDataset(data.Dataset):def __init__(self, data_list):self.data_list = data_listdef __len__(self):return len(self.data_list)def __getitem__(self, index):return self.data_list[index]

一般来说,对于这个类,主要需要重写两个方法:一个是 __getitem__ ,该方法是python内置的操作符方法,对应的操作符是索引操作符 [],通过输入整数数据索引,其大小在0至N-1之间(N微数据的总数目),返回具体的某一条数据记录,这就是该方法需要完成的任务,而具体的逻辑需要根据数据集的类型来决定,另一个方法是 __len__ ,该方法返回数据的总数

在python,如果一个Dataset类重写了该方法,可以通过使用 len 内置函数来获取数据的数目

torchvision工具包的使用

PyTorch:Torchvision的简单介绍与使用-CSDN博客

可迭代类型的数据集

from torch.utils.data import IterableDatasetclass MyIterableDataset(IterableDataset):def __init__(self, file_path):self.file_path = file_pathdef __iter__(self):with open(self.file_path, 'r') as file_obj:for line in file_obj:line_data = line.strip('\n').split(',')yield line_dataif __name__ == '__main__':dataset = MyIterableDataset('test_csv.csv')for data in dataset:print(data)

pytorch模型的保存和加载

序列化和反序列化

由于pytorch的模块和张量的本质是 torch.nn.Module 和 torch.tensor 类的实例,而pytorch自带了一系列的方法,可以将这些类的实例转化为字符串,所以这些势力可以通过python序列化方法进行序列化(serialization)和反序列化(unserialization)

pytorch的实现里集成了python自带的pickle包对模块和张量进行序列化,张量序列化的本质是把张量的信息,包括数据类型和存储位置,以及携带的数据等转化为字符串,而这些字符串时候可以通过使用python自带的文件IO函数进行存储,这个过程是可逆的,即可以通过文件IO函数来读取存储的字符串,然后将字符串逆向解析成pytorch的模块和张量

torch.save(obj,f,pickle_module=pickle,pickle_protocol=2)
torch.load(f,map_location=None,pickle_module=pickle,**pickle_load_args)

torch.save 函数传入的第一个参数是pytorch中可以被序列化的对象,包括模型和张量等,第二个参数是存储文件的路径,序列化的结果将会被保留在这个路径里面,第三个参数是默认的,传入的是序列化的库,可以使用pytorch默认的序列化库pickle,第四个参数是pickle协议,即如何把对象转化为字符串的规范,上述使用的协议版本是2

与 torch.save 函数对应的是 torch.load 函数,该函数在给定序列化后的文件路径之后,就能输出 pytorch 的对象,第一个参数是文件路径之后,第二个参数是张量存储位置的映射,如果存储时的模型在CPU上,可以直接使用默认参数,但当存储的模型在GPU上,torch.load 的默认行为是先把模型载入CPU中,然后转移到保存时的GPU上,加入载入模型的时候是在另外一台计算机上,而计算机没有GPU或GPU的型号对不上就会报错

此时可以使用 map_loactin 函数,设置 map_loactin = 'CPU',这样就会把模型保留在CPU里面,不再移动到GPU中,pickle_module 参数和 torch.save 里的同名参数的作用一致

pytorch中,模型的保存方法有两种,第一种是直接保存模型的实例(因为模型本身可以被序列化),第二种是保存模型的状态字典(State Dict),一个模型的状态字典包含模型所有参数的名字以及名字对应的张量,通过调用 state_dict 方法,就可获取当前模型的状态字典

状态字典的保存和载入

由于pytorch模块的实现依赖具体的pytorch版本,所以会存在一种情况:使用某一个版本保存的序列化文件无法被另一个版本的pytorch载入,相比之下,pytorch的张量变动较小,二状态字典只含有张量参数的名字和张量参数的具体信息,预模块的实现关联较小,因此更加推荐使用 state_dict 方法来获取状态字典,然后保存该张量字典来保存模型,这样可以实现最大限度地减小代码对pytorch版本的依赖性

另外在训练的时候,不仅要保存模型的相关信息,还要保存优化器的相关信息,因为可能需要从存储的检查点出发,继续进行训练,pytorch中参数:当前的学习率,当前梯度的指数移动平均等,通过调用优化器的 state_dict 方法和 load_state_dict 方法,可以让优化器输出和载入相关的状态信息

save_info = { # 保存的信息"iter_num":iter_num,  # 迭代步数"optimizer":optimizer.state_dict,  # 优化器的状态字典"model":model.state_dict(),  # 模型的状态字典
}
# 保存信息
torch.save(save_info,save_path)
# 载入信息
save_info = torch.load(save_path)
optimizer.load_stste_dict(save_info["optimizer"])
model.load_stste_dict(save_info["model"])

pytorch数据可视化

tensorboard是一个数据可视化工具,能直观的显示深度学习中张量的变化,从这个变幻的过程中很容易的可以了解到模型在训练中的行为,包括但不限于损失函数的下降趋势是否合理,张量分量的分布是否在训练过程中发生变化

Pytorch:Tensorboard的安装及常用类的使用【图表+图片方法的使用】-CSDN博客

pytorch进阶 可视化工具TensorBoard的使用_pip install future tensorboard-CSDN博客

PyCharm中TensorBoard的安装和使用_phyton怎么安装 tensorborad-CSDN博客

Tensorboard的使用 ---- SummaryWriter类(pytorch版)-CSDN博客

pytorch模型的并行化

多GPU训练:PyTorch中的数据并行与模型并行-CSDN博客


http://www.ppmy.cn/devtools/31611.html

相关文章

Nacos的开源背景和它的主要贡献者是谁?

在微服务架构的浪潮中,服务注册与发现、动态配置管理等功能日益成为支撑微服务稳定运行的核心组件。而Nacos,作为阿里巴巴开源的一个明星项目,自诞生之初就凭借其强大的功能和灵活性,迅速成为云原生领域的佼佼者。 一、Nacos的开…

使用Python实现二维码生成工具

二维码的本质是什么? 二维码本质上,就是一段字符串。 我们可以把任意的字符串,制作成一个二维码图片。 生活中使用的二维码,更多的是一个 URL 网址。 需要用到的模块 先看一下Python标准库,貌似没有实现这个功能的…

A5资源网有哪些类型的资源可以下载?

A5资源网提供了广泛的资源下载,包括但不限于以下类型: 设计素材:包括各类图标、矢量图、背景素材、UI界面元素等,适用于网页设计、平面设计等领域。 图片素材:提供高质量的照片、插图、摄影作品等,可用于…

计算机组成实验(5)

一、实验目的和要求 1.1 实验目的 1. 复习二进制加减、乘除的基本法则 2. 掌握补码的基本原理和作用 3. 了解浮点数的表示方法及加法运算法则 4. 进一步了解计算机系统的复杂运算操作 1.2 实验要求 1. 熟悉二进制原码补码的概念,了解二进制加减乘除的原理与操作实现。 …

Windows 系统运维常用命令

目标:通过本文可以快速实现windows 网络问题定位。 ipconfig:查看本机网络配置情况 C:\Users\zzg>ipconfigWindows IP 配置以太网适配器 以太网:媒体状态 . . . . . . . . . . . . : 媒体已断开连接连接特定的 DNS 后缀 . . . . . . . :无线局域网适配器 本地…

Ubuntu22安装docker

安装步骤 1. 更新软件包索引 首先,打开终端并更新你的软件包列表以确保访问到最新的软件包版本: sudo apt-get update 2. 安装必要的依赖项 安装几个必需的软件包,这些软件包让apt能够通过HTTPS使用仓库: bash sudo apt-ge…

【深度学习】第一门课 神经网络和深度学习 Week 3 浅层神经网络

🚀Write In Front🚀 📝个人主页:令夏二十三 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝 📣系列专栏:深度学习 💬总结:希望你看完之后,能对…

【Android学习】简易计算器的实现

1.项目基础目录 新增dimens.xml 用于控制全部按钮的尺寸。图片资源放在drawable中。 另外 themes.xml中原来的 <style name"Theme.Learn" parent"Theme.MaterialComponents.DayNight.DarkActionBar">变为了&#xff0c;加上后可针对button中增加图片…