MicroDiffusion——采用新的掩码方法和改进的 Transformer 架构,实现了低预算的扩散模型

ops/2024/12/27 8:25:10/

介绍

论文地址:https://arxiv.org/abs/2407.15811
现代图像生成模型擅长创建自然、高质量的内容,每年生成的图像超过十亿幅。然而,从头开始训练这些模型极其昂贵和耗时。文本到图像(T2I)扩散模型降低了部分计算成本,但仍需要大量资源。

目前最先进的技术需要大约 18 000 个 A100 GPU 小时,而使用 8 个 H100 GPU 进行训练则需要一个多月的时间。此外,该技术通常依赖于大型或专有数据集,因此难以普及。

在这篇评论性论文中,我们开发了一种低成本、端到端文本到图像扩散建模管道,目的是在没有大型数据集的情况下显著降低成本。它侧重于基于视觉变换器的潜在扩散模型,利用其简单的设计和广泛的适用性。为了降低计算成本,可通过随机屏蔽输入标记来减少每幅图像需要处理的斑块数量。本文克服了现有遮蔽方法在高遮蔽率下性能下降的难题。

为了克服文本到图像扩散模型性能不佳的问题,本文提出了一种 "延迟掩蔽 "策略。通过在轻量级补丁混合器中处理补丁,然后将其输入扩散变换器,可以低成本实现可靠的训练,同时即使在高掩蔽率下也能保留语义信息。它还结合了变压器架构的最新发展,以提高大规模训练的性能。

该实验训练了一个 1.16 亿参数的稀疏扩散变换器,预算仅为 1,890 美元、3,700 万张图像和 75% 的屏蔽率。结果,在 COCO 数据集上零镜头生成的 FID 达到了 12.7。在一台 8×H100 GPU 机器上的训练时间仅为 2.6 天,与目前最先进的方法(37.6 天,GPU 成本 28,400 美元)相比缩短了 14 倍。

建议方法

延迟掩蔽

由于变换器的计算复杂度与序列的长度成正比,降低训练成本的一种方法是通过使用大尺寸补丁来减少序列,如图 1-b 所示。使用大尺寸补丁可使每幅图像的补丁数量呈二次方减少,但由于图像的大区域会被主动压缩成单个补丁,因此会显著降低性能。

一种方法是使用遮罩去除变换器输入层中的一些斑块,如图 1-c 所示,同时保持斑块大小。这种方法类似于卷积网络中的随机裁剪训练,但遮蔽补丁可以在图像的非连续区域进行训练。这种方法被广泛应用于视觉和语言领域。

图 1-d 中的MaskDiT还增加了补充自编码损失,以鼓励从遮蔽的斑块中学习表示法,从而促进遮蔽斑块的重建。这种方法屏蔽了 75% 的输入图像,大大降低了计算成本。

图 1.压缩补丁序列以降低计算成本。

然而,高遮罩率会大大降低转换器的整体性能:即使使用 MaskDiT,也只能看到与简单遮罩相比的微弱改进。这是因为即使采用这种方法,大部分图像斑块也会在输入层被去除。

本文引入了一个名为 "补丁混合器 "的预处理模块,用于在屏蔽之前处理补丁嵌入。这可确保未屏蔽的补丁保留整个图像的信息,从而提高学习效率。这种方法有可能提高性能,同时在计算上与现有的 MaskDiT 策略相当。

补丁混合器和学习障碍

补片混合器指的是任何能够融合单个补片嵌入的神经架构。在变换器模型中,这一目标自然可以通过注意机制和前馈层的结合来实现。因此,本文使用轻量级变压器(只有几层)作为补片混合器。输入序列标记经补丁混合器处理后,将被屏蔽(图 2e)。假定掩码为二进制 m,则使用以下损失函数对模型进行训练。

引入专家混合(MoE)和分层缩放的变换器架构

论文采用了先进变压器结构的创新技术,在计算受限的情况下提高了模型性能。

  • 专家混合物(MoE,Zhou 等人,2022 年):使用 MoE 层扩展模型的参数和表现力,同时避免训练成本的显著增加 简化的 MoE 层与专家选择路由允许额外的辅助损失函数无需调整负载。
  • 分层缩放 (Mehta 等人,2024 年 ):这种方法已被证明能提高大型语言模型的性能,其中变换器块的宽度(隐藏层的维度)随深度线性增加。更多的参数被分配给更深的层,以学习更复杂的特征。

整体架构如图 2 所示。

图 2:拟议方法的总体概览。

试验

验证延迟遮蔽和补丁混合器的有效性

当许多补丁被遮蔽时,遮蔽性能会下降;Zheng 等人(2024 年)指出,当遮蔽率超过 50%,MaskDiT 的性能会显著下降。本文评估了遮蔽率高达 87.5% 时的性能,并将其与不使用补丁混合器的传统天真遮蔽方法进行了比较。本文中的 "延迟掩蔽 "使用了一个四层变压器块贴片混合器,其参数小于主干变压器参数的 10%。两者都使用了设置完全相同的 AdamW 优化器。

图 3 对结果进行了总结。延迟掩蔽在所有指标上都明显优于天真掩蔽和MaskDiT,表明随着掩蔽率的增加,性能差异也在扩大。例如,在屏蔽率为 75% 时,原始屏蔽的 FID 得分为 80,MaskDiT 的FID 得分为16.5,而拟议方法的 FID 得分为 5.03,优于未屏蔽时的 3.79。

图 3:验证延迟屏蔽和贴片混频器的有效性。

验证 "专家混合 "和 "分层缩放 "的有效性

分层缩放: 使用 DiT-Tiny 架构进行的实验比较了分层缩放和恒宽变换器与天真屏蔽。两个模型都在相同的计算负荷下进行了相同时间的训练。在所有性能指标上,逐层缩放方法始终优于恒定宽度模型,而且在屏蔽训练中更为有效。

专家混合物(MoE): 测试了在交替区块中具有 MoE 层的 DiT-Tiny/2 变压器。总体性能与无 MoE 层的基线模型相似,Clip-score 略有提高(从 28.11 到 28.66),FID 分数有所下降(从 6.92 到 6.98)。改进幅度有限的原因是 60K 步的训练和每位专家看到的样本量较小。

与以往研究的比较

COCO 数据集(表 1)上的零镜头图像生成: 根据标题生成 30 000 幅图像,并使用 FID-30K 比较其与真实图像的分布情况。拟议方法的FID-30K 得分为 12.66,与之前的低成本训练方法相比,计算成本降低了 14 倍,而且不依赖于专有数据集。该方法的计算成本也比 Würstchen 低 19 倍(Pernias et al.)

表 1:在 COCO 数据集上生成零镜头图像

详细图像生成比较**(表 2)****:**GenEval(Ghosh 等人,2024 年)用于评估生成物体位置、共现、数量和颜色的能力。与 Stable-DiffusionXL-turbo 和 PixArt-α 模型相比,拟议方法在单个物体生成方面的准确性接近完美,与 Stable-Diffusion 变体相当,优于 Stable-Diffusion-1.5。在以下方面也表现出卓越的性能

表 2.详细图像生成对比

总结

本评论文章重点讨论了旨在降低扩散变换器训练计算成本的补丁掩蔽策略。本文提出了一种 "延迟掩蔽 "策略,以缓解现有掩蔽方法的不足,并显示了所有掩蔽比率下的显著性能改进。

特别是使用了 75% 的延迟掩蔽率,并在真实和合成图像数据集上进行了大规模训练。尽管与最先进的技术相比成本大大降低,但还是取得了具有竞争力的零镜头图像生成性能结果。希望这种低成本的训练机制能鼓励更多研究人员参与大规模扩散模型的训练和开发。


http://www.ppmy.cn/ops/145328.html

相关文章

每天40分玩转Django:Django部署概述

一、Django部署概述 在开发阶段,我们通常使用Django内置的轻量级开发服务器runserver。但在生产环境中,为了应对大量并发请求,需要使用高性能的WSGI服务器,如Gunicorn、uWSGI等。同时还要配置Nginx等Web服务器作为反向代理,实现负载均衡、静态文件处理等。下面是Django部署的整…

Java数组深入解析:定义、操作、常见问题与高频练习

一、数组的定义 1. 什么是数组 数组是一个容器,用来存储多个相同类型的数据。它属于引用数据类型,可以存储基本数据类型(如int、char)或者引用数据类型(如String、对象)。 2. 数组的定义方式 a. 动态初…

windows11家庭版安装docker无法识别基于wsl2的Ubuntu

软件环境:windows11家庭版安装WSL2,Ubuntu22.04,docker4.34.2 问题描述:安装docker时,设置阶段无法识别Ubuntu22.04. 原因:windows11家庭版本默认没有Hyper-V 解决方案:将下述代码保存在新建记事本中&am…

python+reportlab创建PDF文件

目录 字体导入 画布写入 创建画布对象 写入文本内容 写入图片内容 新增页 画线 表格 保存 模板写入 创建模板对象 段落及样式 表格及样式 画框 图片 页眉页脚 添加图形 构建pdf文件 reportlab库支持创建包含文本、图像、图形和表格的复杂PDF文档。 安装&…

JWT认证机制在Node.js中的详细阐述

一、概念 JWT(JSON Web Token)是一种基于Token的认证机制,它允许服务器无状态地验证用户身份。JWT是一个开放标准(RFC 7519),它定义了一种简洁的、自包含的用于各方之间安全传输信息的JSON对象。JWT通常被…

论文研读:AnimateDiff—通过微调SD,用图片生成动画

1.概述 AnimateDiff 设计了3个模块来微调通用的文生图Stable Diffusion预训练模型, 以较低的消耗实现图片到动画生成。 论文名:AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning 三大模块: 视频域适应…

vue前端项目中实现电子签名功能(附完整源码)

文章目录 一、具体思路二、所需依赖三、添加签名面板2.1 canvas 转base642.2 电子签名等比例缩小 四、html转cavas(原始文档)五、合成图片六、效果测试七、完整源码 一、具体思路 在vue项目中使用以下步骤思路去实现: 起初的原始文档的格式都…

Scala课堂小结

(一)数组: 1.不可变数组 2.创建数组