从简单到复杂,训练神经网络的秘诀

ops/2024/11/9 16:45:48/

据称,开始训练神经网络非常简单。许多库和框架都以展示 30 行神奇的代码片段来解决问题而自豪,给人一种这些东西是即插即用的(错误)印象。常见的情况如下:

>>> your_data = # plug your awesome dataset here
>>> model = SuperCrossValidator(SuperDuper.fit, your_data, ResNet50, SGDOptimizer)

#conquer world here
这些库和示例激活了我们大脑中熟悉标准软件的部分 - 通常可以获得干净的 API 和抽象。请求库演示:

>>> r = requests.get('https://api.github.com/user', auth=('user', 'pass'))
>>> r.status_code
200

不幸的是,神经网络并非如此。当您稍微偏离训练 ImageNet 分类器时,它们就不是“现成的”技术。

当您破坏或错误配置代码时,通常会遇到某种异常。您插入了一个整数,而某个函数需要字符串。该函数只需要 3 个参数。此导入失败。该键不存在。两个列表中的元素数量不相等。此外,通常可以为某个功能创建单元测试。

这只是训练神经网络的开始。从语法上来说,一切都可能是正确的,但整个过程并没有得到正确的安排,而且很难分辨。“可能的错误面”很大,符合逻辑(而不是语法),并且很难进行单元测试。例如,也许你在数据增强期间左右翻转图像时忘记翻转标签。你的网络仍然可以(令人震惊地)很好地工作,因为你的网络可以内部学习检测翻转的图像,然后左右翻转它的预测。或者,也许你的自回归模型由于一次错误而意外地将它试图预测的东西作为输入。或者你试图削减你的梯度,但却削减了损失,导致训练期间异常示例被忽略。或者你从预训练检查点初始化了你的权重,但没有使用原始平均值。或者你只是搞砸了正则化强度、学习率、衰减率、模型大小等的设置。因此,只有在幸运的情况下,配置错误的神经网络才会抛出异常;大多数时候它会进行训练,但默默地工作得更糟。

因此,“快速而激烈”的神经网络训练方法行不通,只会导致痛苦。现在,痛苦是让神经网络良好运作的一个非常自然的部分,但可以通过彻底、防御、偏执和痴迷于可视化几乎所有可能的事情来减轻痛苦。深度学习成功最密切相关的品质是耐心和对细节的关注。

应对方案

从简单到复杂,对每一步中将会发生什么做出具体的假设,然后通过实验验证它们或进行调查,直到发现一些问题。要极力避免的是一次性引入大量“未经验证”的复杂性,这必然会引入错误/错误配置,而这些错误/配置将需要很长时间才能找到(如果永远找不到的话)。如果编写神经网络代码就像训练代码一样,你会希望使用非常小的学习率,并在每次迭代后猜测然后评估完整的测试集。

1. 与数据融为一体

训练神经网络的第一步是完全不接触任何神经网络代码,而是从彻底检查数据开始。这一步至关重要。扫描数据,了解它们的分布并寻找模式。是否包含重复的示例,是否有损坏的图像/标签。寻找数据不平衡和偏差。

一旦你有了定性分析,编写一些简单的代码来搜索/过滤/排序,无论你想到什么(例如标签类型、注释大小、注释数量等),并可视化它们的分布和沿任何轴的异常值。尤其是异常值几乎总是能揭示数据质量或预处理中的一些错误。

2. 建立端到端的训练/评估框架

现在了解了数据,下一步是建立一个完整的训练 + 评估框架,并通过一系列实验获得对其正确性的信任。在这个阶段,最好选择一些简单的模型,训练它,可视化损失/任何其他指标(例如准确性)、模型预测,并在此过程中使用明确的假设执行一系列实验。

此阶段的提示和技巧:

  • 固定随机种子。始终使用固定随机种子来保证当您运行代码两次时,您将获得相同的结果。这消除了变化因素并有助于让您保持理智。
  • 简化。确保禁用任何不必要的花哨功能。例如,在此阶段一定要关闭任何数据增强。数据增强是一种正则化策略,我们可能会在以后纳入,但现在它只是引入一些愚蠢错误的另一个机会。
  • 在您的 eval 中添加有效数字。绘制测试损失时,对整个(大型)测试集进行评估。不要只绘制批次上的测试损失,然后依靠 Tensorboard 对其进行平滑处理。我们追求正确性,非常愿意放弃时间来保持理智。
  • 验证损失 @ init。验证您的损失是否从正确的损失值开始。例如,如果您正确初始化最后一层,则应-log(1/n_classes)在初始化时测量 softmax。可以为 L2 回归、Huber 损失等得出相同的默认值。
  • 初始化好。正确初始化最终层的权重。例如,如果您要回归一些平均值为 50 的值,则将最终偏差初始化为 50。如果您有一个不平衡的数据集,其正数:负数的比例为 1:10,则在 logits 上设置偏差,以便您的网络在初始化时预测概率为 0.1。正确设置这些将加快收敛速度​​并消除“曲棍球棒”损失曲线,在前几次迭代中,您的网络基本上只是在学习偏差。
  • 人类基线。监控除损失之外的可人类解释和检查的指标(例如准确性)。尽可能评估您自己的(人类)准确性并与之进行比较。或者,对测试数据进行两次注释,对于每个示例,将一个注释视为预测,将第二个注释视为基本事实。
  • 独立于输入的基线。训练独立于输入的基线(例如,最简单的方法是将所有输入设置为零)。这应该比您实际插入数据而不将其清零时的表现更差。是吗?即您的模型是否学会从输入中提取任何信息?
  • 过拟合一个批次。过拟合一个只有几个示例的批次(例如少至两个)。为此,我们增加了模型的容量(例如添加层或过滤器)并验证我们是否可以达到最低的可实现损失(例如零)。我还喜欢在同一个图中可视化标签和预测,并确保一旦达到最小损失,它们最终会完美对齐。如果它们不对齐,则说明某处有错误,我们无法继续下一阶段。
  • 验证训练损失是否减少。在此阶段,由于您正在使用玩具模型,因此您的数据集可能会出现拟合不足的情况。尝试稍微增加其容量。您的训练损失是否按预期下降?
  • 在网络之前进行可视化。可视化数据的正确位置绝对就在网络之前y_hat = model(x)(或sess.run在 tf 中)。也就是说,您希望准确地可视化进入网络的内容,将原始数据张量和标签解码为可视化内容。这是唯一的“真相来源”。我数不清有多少次这拯救了我,并揭示了数据预处理和增强中的问题。
  • 可视化预测动态。我喜欢在训练过程中可视化固定测试批次上的模型预测。这些预测如何移动的“动态”将为您提供有关训练进展的非常好的直觉。很多时候,如果网络以某种方式摆动过多,就会感觉到网络“难以”适应您的数据,从而暴露出不稳定性。非常低或非常高的学习率也很容易在抖动量中被注意到。
  • 使用反向传播来绘制依赖关系图。您的深度学习代码通常包含复杂、矢量化和广播操作。我遇到过几次的一个相对常见的错误是人们会弄错这一点(例如他们使用view而不是transpose/permute某处)并无意中混合了批次维度上的信息。令人沮丧的事实是,您的网络通常仍会训练正常,因为它会学会忽略来自其他示例的数据。调试此问题(和其他相关问题)的一种方法是将损失设置为一些微不足道的东西,例如示例i的所有输出的总和,一直向后传递到输入,并确保仅在第i 个输入上获得非零梯度。相同的策略可用于例如确保您的自回归模型在时间 t 仅依赖于 1…t-1。更一般地说,梯度为您提供了有关网络中什么依赖于什么的信息,这对于调试很有用。
  • 概括特殊情况。这更像是一个通用的编码技巧,但我经常看到人们在从头开始编写相对通用的功能时,会犯错误。我喜欢为我现在正在做的事情编写一个非常具体的函数,让它工作,然后稍后将其概括,确保得到相同的结果。这通常适用于矢量化代码,我几乎总是先写出完全循环的版本,然后一次一个循环地将其转换为矢量化代码。
3. 过度拟合

现在,阶段已经准备好迭代一个好的模型。可以采用两个阶段来寻找好的模型:首先获得一个足够大的模型,使其能够过拟合(即专注于训练损失),然后对其进行适当的正则化(放弃一些训练损失以改善验证损失)。

此阶段的一些提示和技巧:

  • 选择模型。为了达到良好的训练损失,需要为数据选择合适的架构。将神经网络工具箱的乐高积木堆叠成各种的奇特架构,在项目的早期阶段要坚决抵制这种诱惑,而是应该找到最相关的,最简单的、能实现良好性能的架构。例如,如果你正在对图像进行分类,只需在第一次运行时复制粘贴 ResNet-50。然后可以做一些更自定义的事情并击败它。

  • adam 是安全的。在设定基线的早期阶段,建议使用学习率为3e-4的 Adam 。Adam 对超参数的容忍度更高,包括糟糕的学习率。对于 ConvNets,经过良好调整的 SGD 几乎总是会略胜 Adam,但最佳学习率区域要窄得多,而且针对具体问题。(注意:如果您使用的是 RNN 和相关序列模型,则更常使用 Adam。在项目的初始阶段,再次强调,不要妄自尊大,而要遵循最相关的论文。)

  • 每次只复杂化一个。如果有多个信号要插入分类器,建议逐个插入它们,并确保每次都能获得预期的性能提升。不要一开始就将整个模型都扔到水槽里。还有其他方法可以增加复杂性 - 例如,您可以尝试先插入较小的图像,然后再将它们放大,等等。

  • 不要相信学习率衰减默认值。如果你正在重新利用来自其他领域的代码,请始终非常小心学习率衰减。 在典型的实现中,计划将基于当前的时期数,这可能会根据数据集的大小而有很大差异。例如,ImageNet 在第 30 个时期会衰减 10。如果不小心,你的代码可能会偷偷地将你的学习率过早地降至零,从而不允许你的模型收敛。

4. 正则化

理想情况下,我们现在拥有了一个至少适合训练集的大型模型。现在是时候对其进行正则化并通过放弃一些训练精度来获得一些验证精度。一些提示和技巧:

  • 获取更多数据。首先,在任何实际环境中,迄今为止最好的和首选的正则化模型的方法是添加更多真实训练数据。花费大量工程周期试图从小数据集中榨取价值是一个非常常见的错误,添加更多数据几乎是无限期地单调提高配置良好的神经网络性能的唯一保证方法。另一种方法是集成(如果你能负担得起的话),但在大约 5 个模型之后就达到顶峰了。
  • 数据增强。除了真实数据之外,最好的选择就是半假数据 - 尝试更积极的数据增强。
  • 创造性增强。如果半假数据不起作用,假数据也可能起作用。人们正在寻找扩展数据集的创造性方法;例如,域随机化、使用模拟、巧妙的混合(例如将(潜在模拟的)数据插入场景)甚至 GAN。
  • 预训练。如果可以的话,使用预训练网络几乎不会有什么坏处,即使你- 有足够的数据。
  • 坚持监督学习。不要对无监督预训练过度兴奋。
  • 较小的输入维度。删除可能包含杂散信号的特征。如果数据集较小,任何添加的杂散输入都只是过度拟合的另一个机会。同样,如果低级细节并不重要,请尝试输入较小的图像。
  • 更小的模型尺寸。在许多情况下,您可以使用网络领域的知识约束来减小其尺寸。例如,过去在 ImageNet 的主干顶部使用全连接层是一种流行做法,但这些层后来被简单的平均池化所取代,从而消除了大量的参数。
  • 减小批量大小。由于批量规范中的规范化,较小的批量大小在某种程度上对应于更强的正则化。这是因为批量经验平均值/标准差是完整平均值/标准差的更近似版本,因此比例和偏移会使您的批量“摆动”更多。
  • drop。添加 dropout。对 ConvNets 使用 dropout2d(空间 dropout)。请谨慎/小心地使用它,因为 dropout似乎与批量标准化不太兼容。
  • 权重衰减。增加权重衰减惩罚。
  • 提前停止。根据测量到的验证损失停止训练,以便在模型即将过度拟合时抓住它。
  • 尝试更大的模型。我最后提到这一点,也是在提前停止之后,但我过去曾发现过几次,较大的模型最终当然会过度拟合,但它们的“提前停止”性能通常比较小的模型好得多。

最后,为了进一步确信得到的网络是一个合理的分类器,可以将网络的第一层权重可视化,并确保得到合理的边缘。如果的第一层过滤器看起来像噪音,那么可能有些不对劲。同样,网络内的激活有时会显示奇怪的伪影并暗示存在问题。

5. 调参

现在,应该已经“掌握”了数据集,并探索了广泛的模型空间,以找到可实现低验证损失的架构。此步骤的一些提示和技巧:

  • 随机搜索优于网格搜索。对于同时调整多个超参数,使用网格搜索来确保覆盖所有设置听起来很诱人,但请记住,最好使用随机搜索。直观地说,这是因为神经网络通常对某些参数比其他参数更敏感。在极限情况下,如果参数a很重要,但改变b没有效果,那么您宁愿更彻底地对a进行采样,而不是在几个固定点多次采样。
  • 超参数优化。目前有大量的贝叶斯超参数优化工具箱,我的一些朋友也报告说他们使用它们取得了成功,但我的个人经验是,探索良好而广泛的模型和超参数空间的最先进的方法是使用实​​习生 😃。开玩笑的。

6. 最后的手段
一旦找到了最佳类型的架构和超参数,你仍然可以使用一些技巧来从系统中进行改进:

  • 集成。模型集成几乎可以保证在任何方面获得 2% 的准确率。如果您在测试时无法承担计算量,请考虑使用暗知识将您的集成提炼成网络。
  • 让它继续训练。我经常看到人们在验证损失似乎趋于平稳时试图停止模型训练。而实际上,网络会持续训练很长时间。

http://www.ppmy.cn/ops/96789.html

相关文章

第二届海南大数据创新应用大赛 - 算法赛道冠军比赛攻略_海南新境界队

关联比赛: 第二届海南大数据创新应用大赛 - 智能算法赛 第二届海南大数据创新应用大赛 - 算法赛道冠军比赛攻略 首先很幸运能拿到这次初赛冠军,本着积极学习和提升自我的态度,团队成员通力合作是获胜关键,再次感谢。 赛题背景分析和理解 …

Vue 3 组合式 API 中的 nextTick 深入解析

Vue.js 是一个渐进式 JavaScript 框架,以其易学、高效和灵活的特点,成为构建交互式 Web 界面的理想选择。Vue 3 通过一系列性能提升、架构重构和改进开发体验等优点,进一步提高了 Vue.js 的优越性。在 Vue 3 中,组合式 API&#x…

第2章 C语言基础知识

第2章 C语言基础知识 1.printf()函数 在控制台输出数据,需要使用输出函数,C语言常用的输出函数为printf()。 printf()函数为格式化输出函数,其功能是按照用户指定的格式将数据输出到屏幕上。 printf(“格式控制字符串”,[输出列表]); 格式控…

基于Docker compose部署Confluence 8.3.4及设置数据持久化存储的总结

基于Docker compose部署Confluence 8.3.4及设置数据持久化存储的总结 一、环境信息二、安装部署三、向导 介绍如何基于Docker、Docker Compose的方式安装部署Confluence 8.3.4,并且设置数据的持久化存储。 一、环境信息 操作系统:CentOS 7.9 Docker Ver…

Redis系列之事务

概述 Redis事务提供一种将多个命令打包,然后一次性、按顺序地执行的机制,在事务执行的期间不会主动中断,服务器在执行完事务中的所有命令之后,才会继续处理其他客户端的其他命令。 三个重要的保证: 批量操作在发送E…

产品分析 | 便利蜂

​产品信息 产品名称:便利蜂 Slogan:小小的幸福 在你身边 版本号:V1.11.3 大小:23.6M 体验环境:Android6.0.1 品牌概述 便利蜂成立于2016年12月,算是起步较早的企业了,17年2月就开了第一家…

基于单片机的 GPS 信息处理系统

摘 要 : 介绍一种基于单片机的 GPS 信息处理系统 。 以 AT MEL 公司的单片机 AT 89C2051 作为核心控制器件 , LCD和键盘作为人机界面, 通过串行口接收 GPS 接收机输出的 NMEA 全球定位系统 ( Global Positioning System, GPS) 是美国从 20 世纪 70 年代开始研制…

基于微信小程序的课堂考勤系统的设计与实现(论文+源码)_kaic

基于微信小程序的课堂考勤系统的设计与实现 摘 要 在高校教育普及的今天,学生人数日益增多,为保证课堂质量,教师多要在课前进行考勤。因此本设计提出基于微信小程序的课堂考勤系统,增加了定位功能,避免了“假打卡”…