Prompt-to-Prompt Image Editing with Cross Attention Control
TL; DR:prompt2prompt 提出通过替换 UNet 中的交叉注意力图,在图像编辑过程中根据新的 prompt 语义生图的同时,保持图像整体布局结构不变。从而实现了基于纯文本(不用 mask 等额外信息)的图像编辑。
导语
基于文本控制图像生成模型取得了飞速的进展,一个自然的想法是如何实现基于文本控制的图像编辑模型。图像编辑需要在保持原始图片构图不变的情况下,对部分元素或整体风格进行修改。然而,在文生图模型中,文本 prompt 中一个小小的修改(比如替换一个单词),通常都会使得生图结果完全不同。现有的图像编辑方法一般是通过要求用户同时提供编辑区域的 mask,并只在 mask 内进行图像编辑,从而保持图像整体构图布局不变。
本文中,作者深入分析了现有的文生图模型,发现交叉注意力层是控制图像空间布局与 prompt 中每个单词之间关系的关键。基于此,作者通过替换交叉注意力图,构建了一种仅需修改 prompt (无需额外 mask) 的图像编辑框架。并基于该框架介绍了三种图像编辑应用,包括替换单词的局部编辑、添加描述的全局编辑以及更改单词权重来修改单词在生图结果中的反映程度。
下图展示了 prompt2prompt 图像编辑的一些结果,可以看到,仅通过在 prompt 中替换、删除或添加几个单词,就能完成对应的图像编辑,关键是在这个过程过程中保持了图像整体空间布局和其他背景元素不变。
方法
记 I \mathcal{I} I 为文生图模型在 prompt 为 P \mathcal{P} P ,随机种子为 s s s 下生成的一张图像,我们的目标是仅通过修改 prompt 为 P ∗ \mathcal{P}^* P∗ ,同时固定随机种子 s s s,生成出语义符合 P ∗ \mathcal{P}^* P∗ 且空间构图与 I \mathcal{I} I 一致的编辑图像 I ∗ \mathcal{I}^* I∗。
图文交叉注意力
prompt2prompt 的方法注意力替换技术如下图所示。带噪图像 ϕ ( z t ) \phi(z_t) ϕ(zt) 映射为 query 矩阵 Q = l Q ( ϕ ( z t ) Q=l_Q(\phi(z_t) Q=lQ(ϕ(zt) ,文本 prompt 嵌入 ψ ( P ) \psi(\mathcal{P}) ψ(P) 映射为 key 矩阵 K = l K ψ ( P ) K=l_K\psi(\mathcal{P}) K=lKψ(P) 和 value 矩阵 V = l V ψ ( P ) V=l_V\psi(\mathcal{P}) V=lVψ(P) ,其中 l Q , l K , l V l_Q,l_K,l_V lQ,lK,lV 是三个线性映射层。注意力图(attention maps) M M M 为:
M = Softmax ( Q K T d ) M=\text{Softmax}(\frac{QK^T}{\sqrt{d}}) M=Softmax(dQKT)
其中 d d d 是 key 和 value 的隐层维度,每个单元 M i j M_{ij} Mij 表示第 j j j 个文本 token 和第 i i i 个图像块 token 的相关性,该值会作为 value 矩阵的权重,从而交叉注意力层最终的输出为 ϕ ^ ( z t ) = M V \hat\phi(z_t)=MV ϕ^(zt)=MV 。
总结一下,交叉注意力层的输出 M V MV MV 是 value 矩阵 V V V 的加权平均,而各位置的权重就是注意力图 M M M , M M M 其实就是 query 矩阵 Q Q Q 和 key 矩阵 K K K 的各位置的相似度。
交叉注意力图替换
本文的关键发现是:生成图像的空间布局是由交叉注意力图决定的。该发现可由下图看出。图中展示了文本 prompt 中的各个单词对应的交叉注意力图,其中上方一行展示了交叉注意力图在所有时间步平均的结果,可以看到:像素位置与描述它的单词之间的注意力权重更高(如 bear、bird 都对应于各自的位置);下方两行展示了 bear 和 bird 两个单词在整个生图过程中每一时间步的对应注意力图,可以看到:文本单词与空间位置的对应关系在生图早期就已经确定下来了。这个可视化实验支撑了本文的关键发现:生成图像的空间布局是由交叉注意力图决定的。
既然已经知道了生图的空间构图是由交叉注意力图决定的,那要保持图像编辑过程中的空间构图的思路就很直接了:将第一次生图 I \mathcal{I} I 时的交叉注意力图 M M M 保存下来,在图像编辑生图 I ∗ \mathcal{I}^* I∗ 的过程中替换上去,让对应位置的单词还是在编辑生图的对应空间位置。这样就可以使得编辑图像 I ∗ \mathcal{I}^* I∗ 不仅语义与编辑 prompt P ∗ \mathcal{P}^* P∗ 一致,同时还能保持空间构图与原图 I \mathcal{I} I 一致。
接下来,本文提出了一种统一的图像编辑框架,并介绍了该框架下的三种图像编辑应用:分别是替换单词的局部编辑、添加描述的全局编辑以及更改单词权重来修改单词在生图结果中的反映程度。
记扩散模型的在时间步 t t t 的去噪过程为 D M ( z t , P , t , s ) DM(z_t,\mathcal{P},t,s) DM(zt,P,t,s) ,其输出为预测的上一时间步的噪声图像 z t − 1 z_{t-1} zt−1,我们同时保存该步去噪过程中的交叉注意力图,记为 M t M_t Mt。记 D M ( z t , P , t , s ) { M ← M ^ } DM(z_t,\mathcal{P},t,s)\{M\leftarrow \hat{M}\} DM(zt,P,t,s){M←M^} 为将交叉注意力图 M M M 替换为 M ^ \hat{M} M^ 的去噪步,注意其中的 value 矩阵 V V V 不做替换,仅替换交叉注意力图。我们定义 E d i t ( M t , M t ∗ , t ) Edit(M_t,M_t^*,t) Edit(Mt,Mt∗,t) 为交叉注意力图编辑函数,其中 M t ∗ M_t^* Mt∗ 为生成编辑图像时的交叉注意力图。编辑函数输入为:生成原始图片时的注意力图、生成编辑图片时的注意力图和时间步,输出为当前步的注意力图。
注意 prompt2prompt 需要保持生成原始图片时的随机种子和生成编辑图片时的随机种子一致,因为在扩散模型中,即使文本 prompt 相同,如果随机种子不同,生成结果也会非常不同。
通过在上述框架下使用不同的编辑函数,就能实现不同的图像编辑应用。本文提到的三种应用图示见图 2 下侧,以下分别具体介绍三种图像编辑应用。
应用
word swap
通过单词替换来改变原图中的物体或属性。比如原 prompt 为 “a big red car”,编辑 prompt 为 “a big red bike”。由于 prompt 中 token 数目没变,这种情况可以直接通过交叉注意力替换来实现语义与编辑 prompt 一致,而构图与原图一致。但是,在去噪过程全程都进行注意力替换约束可能会过强,尤其是当编辑前后 prompt 的轮廓发生很大变化时(如 car -> bike)。考虑到生图的空间位置关系是在早期的去噪步中确定的,这里通过指定在某个时间步 τ \tau τ 之前进行注意力替换,来使得整体约束相对宽松,给编辑图像生图一定的自由度:
E d i t ( M t , M t ∗ , t ) : = { M t ∗ if t < τ M t oherwise Edit(M_t,M_t^*,t):= \begin{cases} M_t^*\ \ \ \text{if}\ t<\tau \\ M_t\ \ \ \text{oherwise} \end{cases} Edit(Mt,Mt∗,t):={Mt∗ if t<τMt oherwise
如果编辑 prompt 中的新单词编码为多个 token,还可以对不同的 token 分别指定不同的 τ \tau τ ;如果编辑前后两个单词编码的 token 数不同,可以通过复制、平均或使用对齐函数来对齐。
adding a new phrase
用户还可以在原 prompt 中添加一些单词,得到编辑 prompt。比如原 prompt 为 “a castle next to a river”,编辑 prompt 为 “children drawing of a castle next to a river”。此时我们仅对编辑前后两个 prompt 中相同的单词进行注意力图替换。具体老说,我们使用一个对齐函数 A A A,将编辑 prompt 中的 token 的索引映射为原始 prompt 中相同 token 的索引,如果原始 prompt 没有就返回 None。此时编辑函数为:
( E d i t ( M t , M t ∗ , t ) ) i , j : = { ( M t ∗ ) i , j if A ( j ) = N o n e ( M t ) i , j oherwise (Edit(M_t,M_t^*,t))_{i,j}:= \begin{cases} (M_t^*)_{i,j}\ \ \ \text{if}\ A(j)=None \\ (M_t)_{i,j}\ \ \ \text{oherwise} \end{cases} (Edit(Mt,Mt∗,t))i,j:={(Mt∗)i,j if A(j)=None(Mt)i,j oherwise
这种情况下同时可以结合情况 1 来控制使用注意力替换的时机 τ \tau τ 。这种编辑方式可以实现整体图像风格的修改,图中某物体属性的修改等。
attention re-weighting
还有一种情况,用户可能想要强化或者弱化某个单词在最终的生图结果中反映的程度。比如原始 prompt 为 “a fluffy ball”,可能会想要控制生图结果中 ball fluffy 的程度。这种情况下的编辑函数其实不是注意力替换了,而是对 prompt 中某些 token 的注意力值乘上一个权重 c ∈ [ − 2 , 2 ] c\in[-2,2] c∈[−2,2]:
( E d i t ( M t , M t ∗ , t ) ) i , j : = { c ⋅ ( M t ) i , j if j = j ∗ ( M t ) i , j oherwise (Edit(M_t,M_t^*,t))_{i,j}:= \begin{cases} c\cdot(M_t)_{i,j}\ \ \ \text{if}\ j=j^* \\ (M_t)_{i,j}\ \ \ \text{oherwise} \end{cases} (Edit(Mt,Mt∗,t))i,j:={c⋅(Mt)i,j if j=j∗(Mt)i,j oherwise
接下来作者介绍了 prompt2prompt 的几种实际应用,包括 text-only localized editing、global editing、fader control 和 real image editing。这里就不一一介绍了。
提一下关于 real image editing,由于此时的待编辑图片不是模型生成的,而是真实图片,因此我们无法直接获得其去噪过程的交叉注意力图。对于真实图片,我们需要先根据图片反推出扩散模型的起始噪声,确保这个其实噪声输入扩散模型能够生成出这张真实图片。然后再使用 prompt2prompt 进行图像编辑。这种根据图片反推生成模型的起始噪声的技术称为 inversion。在之前的生成模型中,GAN Inversion 是图像编辑中的一种关键技术。而在扩散模型中,比较常用的是 DDIM Inversion 及其变体。这是因为 DDIM 的去噪过程是确定的,而原始的 DDPM 去噪过程每一步还加入了额外的随机噪声,无法做 inversion。关于 inversion 技术用于图像编辑也有相当多的工作,有兴趣的读者可自行查找。
代码简读
谷歌官方的 prompt2prompt 代码 封装得非常优雅,但第一次读起来可能会被稍显复杂的继承关系搞晕,这里简单梳理一下。diffusers 中集成的 prompt2prompt 代码 基本是将官方代码的主要部分搬了过来,然后通过 diffusers 库提供的可自定义的 AttnProcessor 的方式将 prompt2prompt 对注意力的修改注册到 unet 中,这里也简单介绍一下。
AttentionControlClasses
prompt2prompt 代码将注意力图修改抽象为了四层类。其中最顶层的三个子类 AttentionReplace、AttentionRefine、AttentionReweight 分别对应了本文介绍的三种编辑方式,这三个子类中主要是实现不同的 replace_cross_attention 方法。AttentionControlEdit 是这三个子类的共同父类,该类的 forward 方法定义了注意力图的存储(由其父类 AttentionStore 的 forward 方法实现)和编辑,其中注意力图的编辑需要调用 replace_self_attention(通用) 和 replace_cross_attention(由不同子类实现)。再低一层是 AttentionStore,该类主要是在其 forward 函数中定义了注意力图的存储,并由其子类通过 super().forward() 来调用。最底层是 AttentionControl,该类在 __call__ 方法中调用 forward 方法,最终子类的 __call__ 方法一直继承到这一层。
从调用链的角度(图中左下角)来说,最顶层的子类在实例化后调用 __call__ 方法时会一直向上继承调用到 AttentionControl 的 __call__ 方法,该方法调用了 forward 方法,Attention 的 foward 方法实现了注意力图的存储,AttentionControlEdit 的 forward 方法则结合其子类的 replace_cross_attention 方法实现了注意力图的修改。图中左下角调用链中红色部分是存在实际执行逻辑的方法,其他则是继承链中的抽象方法。
Utils
这里 是包括不同 token 对齐等在内的一些 utils,不作过多介绍。
P2PAttentionProcessor
在 diffusers 继承的 prompt2prompt pipeline 中,是通过自定义 AttnProcessor 的方式来进行上述注意力图修改的。
首先根据传入的 cross_attention_kwargs 参数,通过 create_controller 方法 实例化对应的注意力图修改类。然后通过 prompt2prompt pipeline 的 register_attention_control 方法 controller 实例化对应的 P2PAttnProcessor 类 ,并遍历替换掉 UNet 各层的原 AttnProcessor 类,这里其实只需要增加一行,在此处使用 controller 对原注意力图进行修改即可。
总结
prompt2prompt 是扩散模型图像编辑领域非常有意义的一个工作,它通过替换编辑图像生图过程中的交叉注意力图,使得生成出的编辑图像与原始图像的空间布局保持一致。从而实现了仅需修改 prompt 的图像编辑。在之后的 InstructPix2Pix、OMG 等图像编辑方法中,都用到了这项技术。