Pytorch 代码复现终极指南【收藏】

news/2025/1/1 9:39:32/

修改自:https://zhuanlan.zhihu.com/p/532511514

我在刚接触pytorch的时候搜到了这个大佬的文章,当时最后天坑部分没有看的太明白,直到今天我也遇到的相同的问题,特来做一点点补充,方便大家理解。

上述大佬文章的简版内容:

  1. 入门版本

Pytorch复现的入门版本就是官方指南,需要设定好各种随机种子。

https://pytorch.org/docs/stable/notes/randomness.html

import random
import numpy as np
import torchrandom.seed(0)  # Python 随机种子
np.random.seed(0)  # Numpy 随机种子
torch.manual_seed(0)  # Pytorch 随机种子
torch.cuda.manual_seed(0)  # CUDA 随机种子
torch.cuda.manual_seed_all(0)  # CUDA 随机种子

2. Dataloader的并行

DataLoader启用多线程时(并行的线程数num_workers 大于1)也会出现随机现象,解决办法:

1. 禁用多线程:num_workers 设置为0。

2. 固定好worker的初始化方式,代码如下:

def seed_worker(worker_id):worker_seed = torch.initial_seed() % 2 ** 32numpy.random.seed(worker_seed)random.seed(worker_seed)g = torch.Generator()g.manual_seed(0)DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers,worker_init_fn=seed_worker,generator=g,)

3:算法的随机性

有些并行算法带有随机性,比如LSTM或者注意力机制,RNN等。

尤其是使用 CUDA Toolkit 10.2 或更高版本构建 cuDNN 库时,cuBLAS 库中新的缓冲区管理和启发式算法会带来随机性。在默认配置中使用两种缓冲区大小(16 KB 和 4 MB)时会发生这种情况。

解决办法就是在代码头部设置环境变量:

os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

如果是用到CNN的算法,同时要设置以下变量:

torch.backends.cudnn.benchmark = False  # 限制cuDNN算法选择的不确定性
torch.backends.cudnn.deterministic=True  # 固定cuDNN算法

设置完这些,基本99%的情况下都可以复现结果,如果无法复现,那就重启notebook 或者python。

天坑:for 循环内随机性

如果在一个for 循环内多次运行pytorch训练,就会出现随机性。

以下常见方式均无效

  • 强制每次train之前empty_cache;

  • 每次循环结束后,手动del 变量,并且用gc 回收;

  • 强制初始化模型的参数;

  • 强制设置set_rng_state;(https://discuss.pytorch.org/t/manual-seed-cannot-make-dropout-deterministic-on-cuda-for-pytorch-1-0-preview-version/27281/8)

  • 重启python文件和notebook;

知乎大佬的解决方案:

  1. 上面的随机种子设置最好在for 循环里面设置,否则可能白瞎。

  1. nn模型里面的dropout 在for 循环里面有随机性。解决办法是禁掉dropout或者显式的调用Dropout。

对于该天坑,本文作者的实验结果:

调用一次Dataloader就会影响下一个Dataloader的随机数生成。

解释:例如现在有两种模型的训练方式:

  1. 在train后面继续进行下一个Epoch的train。

  1. train后面进行val,再进行下一个Epoch的train。

这两种方式得到的训练结果从第二个Epoch开始就是不同的,且val前后模型的weights没变,那应该就是生成的随机数变了。因此,应该就是调用一次Dataloader就会有新的随机数。

总结

以上就是Pytorch代码的复现终极指南,保险起见的话,先把能加的都加上,然后看能否复现。

之后如果有强迫症的话,可以做减法,逐个筛检,直到保留必要的代码。

愿天下太平,代码无坑


http://www.ppmy.cn/news/29349.html

相关文章

2022 年度_职业项目总结_Java技术点归纳

Java技术点归纳目录概述需求:设计思路实现思路分析1.Structs 元工程改造2.个贷子系统开发3.架构的迭代开发,升级,部署,参考资料和推荐阅读Survive by day and develop by night. talk for import biz , show your perfect code,fu…

记录第一次接口上线过程

新入职一家公司后,前三天一直在学习公司内部各种制度文化以及考试。 一直到第三天组长突然叫我过去,给了一个需求的思维导图,按照这个需求写这样一个接口, 其实还不错,不用自己去分析需求,按照这上面直接开…

Python 虚拟环境的使用

PyCharm 创建的虚拟环境与使用 workon 命令创建的虚拟环境在本质上没有区别,它们都是 Python 的虚拟环境。 使用 PyCharm 创建工程时,使用可以使用曾经工程的虚拟环境,或者新建一个虚拟环境来安装 Python 的库,又或者使用 workon…

工业机器人编程调试怎么学

很多人觉得工业机器人很难学学,实际上机器人涉及的知识远比PLC要少。现简单说明一下初学者学习工业机器人编程调试的流程,以AUBO机器人为例: 首先我们需要知道工业机器人的调试学起来不难,远比编程更简单,示教器上的编…

【Servlet篇2】Servlet的工作过程,Servlet的api——HttpServletRequest

一、Servlet的工作过程 二、Tomcat的初始化 步骤1:寻找到当前目录下面所有需要加载的Servlet(也就是类) 步骤2:根据类加载的结果创建实例(通过反射),并且放入集合当中 步骤3:实例创建好之后,调用Servlet的init()方…

如何评估模糊测试工具-unibench的使用

unibench是一个用来评估模糊测试工具的benchmark。这个benchmark集成了20多个常用的测试程序,以及许多模糊测试工具。 这篇文章(https://zhuanlan.zhihu.com/p/421124258)对unibench进行了简单的介绍,本文就不再赘诉,…

leetcode 181. 超过经理收入的员工

SQL架构 表:Employee ---------------------- | Column Name | Type | ---------------------- | id | int | | name | varchar | | salary | int | | managerId | int | ---------------------- Id是该表的主键。 该表的…

【面试题】在JS循环中使用await会怎么样?

前言这个问题是这样产生的?某天,在学习异步的知识遇到这样一道题:使用Promise的方式,每隔一秒输出数组中一个值const arr [1, 2, 3] ​ arr.reduce((pre, cur) > {return pre.then(() > {returnnewPromise((resolve, rejec…