pytorch中的基础数据集

server/2025/3/16 4:42:35/
  1. 数据是深度学习核心之一
  2. pytorch基础数据集介绍
  3. 加载/读取/显示/使用
  4. 代码演示与解释

常见的数据集Pascal VOC/COCO

DataLoader

DataLoader( dataset,

  • 含义:指定要加载的数据集,它必须是 torch.utils.data.Dataset 类的子类实例。Dataset 类定义了如何获取单个样本以及数据集的长度,DataLoader 会基于这个数据集进行数据的批量加载。

batch_size=1,

  • batch_size
    • 含义:每个批次加载的样本数量,默认值为 1。在训练神经网络时,通常会将多个样本组成一个批次一起输入到模型中,以提高训练效率和稳定性。
    • 示例:dataloader = DataLoader(dataset, batch_size=10)

shuffle=False,

  • 含义:布尔类型参数,若设置为 True,则在每个 epoch 开始时打乱数据集的顺序,有助于模型更好地学习数据的分布,提高模型的泛化能力,默认值为 False
  • 示例dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

sampler=None,

  • sampler
  • 含义:自定义采样器,用于指定如何从数据集中采样样本。如果指定了 sampler,则 shuffle 参数将被忽略。采样器必须是 torch.utils.data.Sampler 类的子类实例。

batch_sampler=None,

batch_sampler

  • 含义:自定义批次采样器,用于指定如何将样本组合成批次。如果指定了 batch_sampler,则 batch_sizeshufflesampler 和 drop_last 参数将被忽略。批次采样器必须是 torch.utils.data.BatchSampler 类的子类实例。

num_workers=0,

  • num_workers
    • 含义:用于数据加载的子进程数量,设置为大于 0 的值可以实现并行数据加载,提高数据加载速度。默认值为 0,表示在主进程中加载数据。在使用多进程加载时,需要注意内存和 CPU 资源的使用情况。
    • 示例:dataloader = DataLoader(dataset, batch_size=10, num_workers=4)

collate_fn=None,

collate_fn

  • 含义:用于将多个样本组合成一个批次的函数。默认情况下,DataLoader 会使用一个简单的默认函数将样本组合成批次,但在处理一些特殊的数据类型或结构时,可能需要自定义 collate_fn 函数。

pin_memory=False, drop_last=False,

  • pin_memory
    • 含义:布尔类型参数,若设置为 True,则会将数据样本在返回前固定到 CUDA 页锁定内存中,这样可以加快数据从 CPU 到 GPU 的传输速度,适用于使用 GPU 进行训练的场景,默认值为 False
    • 示例dataloader = DataLoader(dataset, batch_size=10, pin_memory=True)

timeout=0, worker_init_fn=None,

  • timeout
    • 含义:在多进程数据加载时,等待从工作进程获取数据的超时时间(单位:秒)。如果在指定的时间内没有获取到数据,则会抛出 TimeoutError 异常。默认值为 0,表示不设置超时时间。
    • 示例dataloader = DataLoader(dataset, batch_size=10, num_workers=4, timeout=1)

multiprocessing_context=None,

multiprocessing_context

  • 含义:指定多进程的上下文,可以是 'fork''spawn' 或 'forkserver' 等。不同的上下文适用于不同的操作系统和场景,默认情况下会根据操作系统自动选择合适的上下文。

generator=None,

generator

  • 含义:用于生成随机数的生成器,可用于控制数据的打乱顺序等随机操作。可以传入一个 torch.Generator 实例,方便复现实验结果。

prefetch_factor=2,

prefetch_factor

  • 含义:每个工作进程预取的样本数量,默认值为 2。在多进程数据加载时,工作进程会预先加载一定数量的样本,以提高数据加载的连续性。

persistent_workers=False )

persistent_workers

  • 含义:布尔类型参数,若设置为 True,则在每个 epoch 结束后,工作进程不会被销毁,而是保持存活状态,这样可以减少进程创建和销毁的开销,提高数据加载效率,默认值为 False

  • torch.utils.data.Dataset的子集
  • torch.utils.data.DataLoader加载数据集


http://www.ppmy.cn/server/175325.html

相关文章

PyTorch多机训练Loss不一致问题排查指南:基于算子级一致性验证

比较二次训练过程中所有算子的误差,定位存在一致性问题的pytorch算子 一.背景二.技术方案1.核心思路2.关键技术点 三.代码 一.背景 在分布式训练场景中,观察到以下现象: 相同超参配置下,多次训练的Loss曲线存在显著差异(波动幅度…

移远通信联合德壹发布全球首款搭载端侧大模型的AI具身理疗机器人

在汹涌澎湃的人工智能浪潮中,具身智能正从实验室构想迈向现实应用。移远通信凭借突破性的端侧AI整体解决方案,为AI机器人强势赋能,助力其实现跨行业拓展,从工业制造到服务接待,再到医疗康养,不断改写各行业…

Qt信号与槽

1.信号与槽概述 在Qt中,用户和控件的每一次交互过程称为一个事件。比如“用户点击按钮”是一个事件,“用户关闭窗口”也是一个事件。 每个事件都会发出一个信号。例如用户点击按钮会发出“按钮被点击”的信号,用户关闭窗口会发出“窗口被关闭…

Android7上移植I2C-tools

一,下载源码 cd hardware/libhardware/tests git clone https://git.kernel.org/pub/scm/utils/i2c-tools/i2c-tools.git 二, 在 i2c-tools 目录添加 Android.mk 编译文件 LOCAL_PATH: $(call my-dir)################### i2c-tools ###############…

HCIA-11.以太网链路聚合与交换机堆叠、集群

链路聚合背景 拓扑组网时为了高可用,需要网络的冗余备份。但增加冗余容易后会出现环路,所以我们部署了STP协议来破除环路。 但是,根据实际业务的需要,为网络不停的增加冗余是现实需要的一部分。 那么,为了让网络冗余…

随笔小记-本人常用桌面应用(流程图-boardmix,截图-snipaste,文件比较-beyond compare,远程控制-向日葵,解压-360压缩)

1.流程图绘画-boardmix 2.快捷截图-snipaste 3.文件与文件夹比较工具(比较文件内容差异结构差异,可合并)-beyond compare 4.远程控制-向日葵远程控制 5.压缩包的解压缩-360压缩

微信小程序threejs三维开发

微信小程序threejs开发 import * as THREE from three; const { performance, document, window, HTMLCanvasElement, requestAnimationFrame, cancelAnimationFrame, core, Event, Event0 } THREE .DHTML import Stats from three/examples/jsm/libs/stats.module.js; im…

影刀RPA安装32位与64位的差别

1. 影刀RPA概述 1.1 产品简介 影刀RPA是一款由杭州分叉智能科技有限公司研发的RPA自动化软件,致力于为各行业客户提供RPA自动化机器人产品与解决方案,能够实现PC、手机上的任何软件自动化操作。其功能特性丰富,包括桌面软件自动化、网页自动…