【论文阅读】Scalable Diffusion Models with Transformers

news/2025/2/11 9:02:14/

DiT:基于transformer架构的扩散模型。

paper:[2212.09748] Scalable Diffusion Models with Transformers (arxiv.org)

code:facebookresearch/DiT: Official PyTorch Implementation of "Scalable Diffusion Models with Transformers" (github.com)

有空看

1. 介绍

对于扩散模型来说,自2020年DDPM诞生以来,连续3年的工作仍然延续最初的经典U-Net架构,在网络结构设计上仍依赖早期的研究经验,有着很大的提升空间;

而Transformer一直被诟病的则是其“错误累积”问题,简单来说,错误扩散来源于Transformer“预测下一个词”的生成模式,如果说前面生成的词出现了错误,那么模型在生成后续的词时会“将错就错”,进而导致误差的累积,扩散模型由于同时对所有的像素去除噪声(这种范式我们称为非自回归,non-autoregressive),从生成范式上规避了这一问题。

如何同时解决好二者的存在的缺陷,成为了一个很好的研究课题。扩散模型基于早期工作的经验,在网络结构设计上仍有很大的提升空间。而这篇工作在隐空间扩散模型范式的启发下,成功将扩散模型中经典的U-Net结构替换成了Transformer,在进一步提升网络架构复杂度的前提下,能够显著提升生成图片的质量。

2. 方法

图3。DiT架构。左:我们训练条件潜在DiT模型。输入潜信号被分解成小块,并由多个DiT块进行处理。右:DiT块的详细信息。我们尝试了各种标准的变压器模块,这些模块通过自适应层规范、交叉注意和额外的输入令牌结合了条件反射。自适应层规范效果最好。 

DiT既有着扩散模型对图片加噪、去噪的特殊机制,又同时有Transformer强大的自注意力机制,以及Transformer“预测下一个词”的特点。给定输入图片时,DiT首先通过扩散模型标准的加噪过程对图像压缩后的特征进行污染,将带噪特征、条件特征、ground truth对应的特征拼接在一起输入Transformer后输出结果,完成一次DiT的前传过程。

  • DiT模块:在完成“块化”的操作之后,下一步要做的就是输入DiT进行相应的运算。DiT由DiT 模块的基本单元构成,其中,DiT模块中的大多数元素类似于标准的Transformer模块,包括多头自注意力机制(Multi-Head Self-Attention)、Layer Normalization、Feed Forward Layer等等。其中,为了融合外部的条件控制,DiT有三种变种形式,分别与In-Context Conditioning、Cross-Attention、adaLN-Zero相组合,它们都用于融合外部的标签条件,对应Diffusion Transformer模型架构图中由右到左的顺序:
    • In-Context Conditioning:从字面意思上来看,In-Context Conditioning可以翻译为”上下文条件化“。其实就是将条件拼接在输入词的后面。前面我们说到,Transformer的输出过程其实是在做“预测下一个词”,而in-context conditioning其实是给这一过程加上了一个前缀,在“预测下一个词”的过程中,模型会持续收到这个前缀的作用。具体来说,在DiT这篇工作的设定中,这个条件对应ImageNet中图片的类别,在模型生成图片“词”的时候,模型就会在知道生成图片的类别的前提之下完成输出过程。从技术层面上来看,这个设计跟ViT中的cls token大同小异。
    • Cross-Attention:跨注意力机制其实很简单,类似于经典latent diffusion models中U-Net的跨注意力机制,将条件对应的特征作为注意力机制的K和V,以图片特征作为Q进行运算,从而达到将条件融入图片生成过程中的效果。
    • adaLN-Zero:这个模块是这篇工作中的另一创新点,是针对Transformer原本layer normalization在图像生成任务上的一个创新。具体做法抛弃了layer norm原来直接学习增益(scale)和偏置(shift)的做法,而是通过自适应地学习加权系数(图中的 α1,α2,β1,β2),加权系数将输入条件的特征处理后,再加到每层layer norm的增益和偏置中去,以此完成条件的融合。
      • 这里adaLN-Zero的设计其实感觉跟SPADE[6]的思路类似。SPADE的提出是为了融合分割图的条件输入而提出了,其做法是将分割图处理成可学习的增益和偏置,再将增益和偏置加权到图像特征上,完成条件的融合。可以看到,SPADE和adaLN-Zero的异曲同工之妙,说明增益和偏置是融入条件信号是一个有效的方式。

参考:Diffusion Transformer Family:关于Sora和Stable Diffusion 3你需要知道的一切 - 知乎 (zhihu.com)


http://www.ppmy.cn/news/1391507.html

相关文章

主干网络篇 | YOLOv8更换主干网络之ShuffleNetV2

前言:Hello大家好,我是小哥谈。ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具…

Springboot+vue的医疗挂号管理系统+数据库+报告+免费远程调试

效果介绍: Springbootvue的医疗挂号管理系统,Javaee项目,springboot vue前后端分离项目 本文设计了一个基于Springbootvue的前后端分离的医疗挂号管理系统,采用M(model)V(view)C(con…

DC-1靶场

一.环境搭建 下载地址 http://www.five86.com/downloads/DC-1.zip 把桥接设置为nat模式,打开靶机的时候会提示几个错误,点击重试即可 启动靶机,如下图所示即可 二.开始打靶 1.信息收集 arp-scan -l 扫描跟kali(攻击机&…

C# 数组

C# 数组 一维数组初始化数组赋值给数组 二维数组初始化二维数组 交错数组交错数组的声明类型和其中的数组类型必须一致。 参数数组Array 类 一维数组 初始化数组 声明一个数组不会在内存中初始化数组。当初始化数组变量时,您可以赋值给数组。 数组是一个引用类型…

无服务器推理在大语言模型中的未来

服务器无服务器推理的未来:大型语言模型 摘要 随着大型语言模型(LLM)如GPT-4和PaLM的进步,自然语言任务的能力得到了显著提升。LLM被广泛应用于聊天机器人、搜索引擎和编程助手等场景。然而,由于LLM对GPU和内存的巨大需求,其在规…

Spring常用设计模式-实战篇之单例模式

实现案例,饿汉式 Double-Check机制 synchronized锁 /*** 以饿汉式为例* 使用Double-Check保证线程安全*/ public class Singleton {// 使用volatile保证多线程同一属性的可见性和指令重排序private static volatile Singleton instance;public static Singleton …

【洛谷 P8687】[蓝桥杯 2019 省 A] 糖果 题解(动态规划+位集合+位运算)

[蓝桥杯 2019 省 A] 糖果 题目描述 糖果店的老板一共有 M M M 种口味的糖果出售。为了方便描述,我们将 M M M 种口味编号 1 1 1 ∼ M M M。 小明希望能品尝到所有口味的糖果。遗憾的是老板并不单独出售糖果,而是 K K K 颗一包整包出售。 幸好糖…

【C++】1600. 请假时间计算

问题:1600. 请假时间计算 类型:基本运算、整数运算 题目描述: 假设小明的妈妈向公司请了 n 天的假,那么请问小明的妈妈总共请了多少小时的假,多少分钟的假?(提示: 1 天有 24 小时&…