渐进蒸馏和v-prediction

embedded/2024/10/10 22:08:58/

渐进蒸馏和v-prediction

TL;DR:比较早期的用蒸馏的思想来做扩散模型采样加速的方法,通过渐进地对预训练的扩散模型进行蒸馏,学生模型一步学习教师模型两步的去噪结果,不断降低采样步数。并提出一种新的参数化形式 v \mathbf{v} v-prediction 来解决渐进蒸馏过程中信噪比太低时误差影响较大的问题。

在这里插入图片描述

渐进蒸馏

在一开始,我们有一个预训练的原始扩散模型作为初始教师模型。我们首先将学生模型初始化为一个结构、参数都与教师模型一模一样的扩散模型。然后,不断采样干净图像数据,加噪声,训练学生模型的去噪能力。由于我们要进行蒸馏,所以这里学生模型的预测目标不是干净的图片 x \mathbf{x} x,而是要学生模型单步(DDIM)预测出教师模型两步(DDIM)的去噪结果 x ~ \tilde{\mathbf{x}} x~

具体来说,我们这里考虑的是连续时间步 t ∈ [ 0 , 1 ] t\in[0,1] t[0,1],目标步数(即学生模型的步数)为 N N N,从而步长是 1 / N 1/N 1/N,在时刻 t t t 是要去噪从 z t \mathbf{z}_t zt z t − 1 / N \mathbf{z}_{t-1/N} zt1/N。这样教师模型的步数是 2 N 2N 2N,每一步是从 z t \mathbf{z}_{t} zt z t − 0.5 / N \mathbf{z}_{t-0.5/N} zt0.5/N。我们这里连续运行教师模型两步,即从 z t \mathbf{z}_t zt z t − 0.5 / N \mathbf{z}_{t-0.5/N} zt0.5/N 再到 z t − 1 / N \mathbf{z}_{t-1/N} zt1/N,我们的学生模型训练目标就是要一步直接从 z t \mathbf{z}_t zt 预测出教师模型的两步去噪的结果 z t − 1 / N \mathbf{z}_{t-1/N} zt1/N

在收敛之后,我们将当前的学生模型作为下一轮的噪声模型,再将自身进行拷贝重新初始化一个新的学生模型,重复上述步骤。循环往复,即可通过渐进蒸馏不断降低模型的采样步数。

下面是渐进蒸馏的算法流程,对比了标准的扩散模型训练流程,主要就是将模型的预测目标从上一步的加噪结果改换成了教师模型的两步去噪结果,并渐进式地迭代这一过程。

在这里插入图片描述

参数化形式和训练损失

自从 DDPM 以来,扩散模型的参数化形式一般都是 ϵ \epsilon ϵ-prediction,即预测噪声,再根据噪声计算出数据 x \mathbf{x} x。相当于间接地预测 x \mathbf{x} x x ^ θ ( z t ) = 1 α t ( z t − σ t ϵ ^ θ ( z t ) ) \hat{\mathbf{x}}_\theta(\mathbf{z}_t)=\frac{1}{\alpha_t}(\mathbf{z}_t-\sigma_t\hat\epsilon_\theta(\mathbf{z}_t)) x^θ(zt)=αt1(ztσtϵ^θ(zt))

在常规的扩散模型训练以及渐进蒸馏训练的早期(步数还比较多时),噪声预测的参数化形式工作得很好。因为这时信噪比 α t 2 / σ t 2 \alpha_t^2/\sigma_t^2 αt2/σt2 在一个比较宽的范围内。当随着渐进蒸馏的进行,步数越来越少,信噪比越来越低以至于接近于 0,此时 α t \alpha_t αt 接近于 0。根据上式, α t \alpha_t αt 在间接预测 x ^ θ ( z t ) \hat{\mathbf{x}}_\theta(\mathbf{z}_t) x^θ(zt) 公式的分母上,因此此时网络输出预测噪声 ϵ ^ θ ( z t ) \hat{\epsilon}_\theta(\mathbf{z}_t) ϵ^θ(zt) 都会噪声 x \mathbf{x} x 的巨大变化,从而导致训练不稳定。并且渐进蒸馏后期步数较少,无法通过后面的步数进行修正。

最终,如果我们将模型蒸馏到只剩下一个采样步,那么模型的输入就只是纯噪声 ϵ \epsilon ϵ,此时信噪比为零,即 α t = 0 , σ t = 1 \alpha_t = 0, \sigma_t = 1 αt=0,σt=1。在这种极端情况下, ϵ \epsilon ϵ 预测和 x \mathbf{x} x 预测之间的联系完全中断:观测数据 z t = ϵ z_t = \epsilon zt=ϵ 不再包含 x \mathbf{x} x 的信息,并且 ϵ \epsilon ϵ 的预测 ϵ ^ θ ( z t ) \hat{\epsilon}_{\theta}(\mathbf{z}_t) ϵ^θ(zt) 也无法再间接地预测 x \mathbf{x} x。在损失函数中,加权函数 w ( λ t ) w(\lambda_t) w(λt) 在此时的权重也成了 0。

为了解决这一问题,作者尝试了直接预测 x \mathbf{x} x、同时分别预测 x \mathbf{x} x ϵ \epsilon ϵ 后合并出 x ^ \hat{\mathbf{x}} x^,还提出了一种新的参数化形式 v \mathbf{v} v-prediction:
v ≡ α t ϵ − σ t x \mathbf{v}\equiv \alpha_t\epsilon-\sigma_t\mathbf{x} vαtϵσtx
从而:
x ^ = α t z t − σ t v ^ θ ( z t ) \hat{\mathbf{x}}=\alpha_t\mathbf{z}_t-\sigma_t\hat{\mathbf{v}}_\theta(\mathbf{z}_t) x^=αtztσtv^θ(zt)
实验显示,这三种方式在渐进蒸馏训练中都表现得不错,并在在常规扩散模型的训练中效果也很好。

下面对作者设计的 v \mathbf{v} v-prediction 进行推导:

DDPM 的加噪公式:
z t = α t x + σ t ϵ \mathbf{z}_t=\alpha_t\mathbf{x}+\sigma_t\epsilon zt=αtx+σtϵ
ϕ t = arctan ⁡ ( σ t / α t ) \phi_t=\arctan(\sigma_t/\alpha_t) ϕt=arctan(σt/αt),则有 α t = cos ⁡ ( ϕ ) , σ t = sin ⁡ ( ϕ ) \alpha_t=\cos(\phi),\sigma_t=\sin(\phi) αt=cos(ϕ),σt=sin(ϕ),从而:
z ϕ = cos ⁡ ( ϕ ) x + sin ⁡ ( ϕ ) ϵ \mathbf{z}_\phi=\cos(\phi)\mathbf{x}+\sin(\phi)\epsilon zϕ=cos(ϕ)x+sin(ϕ)ϵ
定义 z ϕ z_\phi zϕ 的 “速度” 为其关于 ϕ \phi ϕ 的导数:
v ϕ ≡ d z ϕ d ϕ = d cos ⁡ ( ϕ ) d ϕ x + d sin ⁡ ϕ d ϕ ϵ = sin ⁡ ( ϕ ) x − cos ⁡ ( ϕ ) ϵ \mathbf{v}_\phi\equiv\frac{d\mathbf{z}_\phi}{d\phi}=\frac{d\cos(\phi)}{d\phi}\mathbf{x}+\frac{d\sin{\phi}}{d\phi}\epsilon=\sin(\phi)\mathbf{x}-\cos(\phi)\epsilon vϕdϕdzϕ=dϕdcos(ϕ)x+dϕdsinϕϵ=sin(ϕ)xcos(ϕ)ϵ
这里就是上面 v \mathbf{v} v 的定义 v ≡ α t ϵ − σ t x \mathbf{v}\equiv \alpha_t\epsilon-\sigma_t\mathbf{x} vαtϵσtx。稍微进行变换,有:
sin ⁡ ( ϕ ) = cos ⁡ ( ϕ ) ϵ − v ϕ = cos ⁡ ( ϕ ) sin ⁡ ( ϕ ) ( z − cos ⁡ ( ϕ ) x ) − v ϕ sin ⁡ 2 ( ϕ ) x = cos ⁡ ( ϕ ) z − cos ⁡ 2 ( ϕ ) x − sin ⁡ ( ϕ ) v ϕ sin ⁡ 2 ( ϕ ) x + cos ⁡ 2 ( ϕ ) x = cos ⁡ ( ϕ ) z − sin ⁡ ( ϕ ) v ϕ x = cos ⁡ ( ϕ ) z − sin ⁡ ( ϕ ) v ϕ \begin{align} \sin(\phi)&=\cos(\phi)\epsilon-\mathbf{v}_\phi\\ &=\frac{\cos(\phi)}{\sin(\phi)}(\mathbf{z}-\cos(\phi)\mathbf{x})-\mathbf{v}_\phi\\ \sin^2(\phi)\mathbf{x}&=\cos(\phi)\mathbf{z}-\cos^2(\phi)\mathbf{x}-\sin(\phi)\mathbf{v}_\phi\\ \sin^2(\phi)\mathbf{x}+\cos^2(\phi)\mathbf{x}&=\cos(\phi)\mathbf{z}-\sin(\phi)\mathbf{v}_\phi\\ \mathbf{x}&=\cos(\phi)\mathbf{z}-\sin(\phi)\mathbf{v}_\phi \end{align} sin(ϕ)sin2(ϕ)xsin2(ϕ)x+cos2(ϕ)xx=cos(ϕ)ϵvϕ=sin(ϕ)cos(ϕ)(zcos(ϕ)x)vϕ=cos(ϕ)zcos2(ϕ)xsin(ϕ)vϕ=cos(ϕ)zsin(ϕ)vϕ=cos(ϕ)zsin(ϕ)vϕ
这里就是上面的第二个公式 x ^ = α t z t − σ t v ^ θ ( z t ) \hat{\mathbf{x}}=\alpha_t\mathbf{z}_t-\sigma_t\hat{\mathbf{v}}_\theta(\mathbf{z}_t) x^=αtztσtv^θ(zt)。这个推导过程可以参考下图来理解。

在这里插入图片描述

总结

早期提出的渐进蒸馏是一种比较直觉的扩散模型步数蒸馏方法,其提出的 v-prediction 在后来也有广泛的应用。


http://www.ppmy.cn/embedded/125565.html

相关文章

设计模式之原型模式(通俗易懂--代码辅助理解【Java版】)

文章目录 设计模式概述1、原型模式2、原型模式的使用场景3、优点4、缺点5、主要角色6、代码示例7、总结题外话关于使用序列化实现深拷贝 设计模式概述 创建型模式:工厂方法、抽象方法、建造者、原型、单例。 结构型模式有:适配器、桥接、组合、装饰器、…

kubelet 运行机制、功能 全面分析

Kubelet 在Kubernetes集群中,在每个Node(又称为Minion)上都会启动一个Kubelet服务进程。该进程用于处理Master下发到本节点的任务,管理Pod及Pod中的容器。每个Kubelet进程都会在API Server上注册节点自身的信息,定期向…

微信搜一搜又升级啦!

据“微信派”微信公号10月8日消息,微信搜一搜上线“搜索直达”功能,相关话题登上微博热搜。网友表示这个更新可以省去翻设置功能的步骤,更直接、快速地解决问题。 搜“拍一拍设置” 搜“深夜模式设置” 搜“来电铃声设置” 搜“关怀模式”设置…

王道408考研数据结构-图-第六章

6.1 图的基本概念 6.1.1 图的定义 图G由顶点集V和边集 E组成,记为G(V,E),其中 V(G)表示图G中顶点的有限非空集;E(G)表示图G中顶点之间的关系(边)集合。若V{v?,v?,…,vn},则用|M表示图G中顶第6章 点的个数,E{(u,v) | uεV,vεV},用|E|表示图…

SQL自学:什么是SQL的聚集函数,如何利用它们汇总表的数据

在 SQL(Structured Query Language,结构化查询语言)中,聚集函数也称为聚合函数,是对一组值进行计算并返回单一值的函数。 一、常见的聚集函数及功能 1. AVG():用于计算某一列的平均值。 例如,…

自动驾驶汽车横向控制方法研究综述

【摘要】 为实现精确、稳定的横向控制,提高车辆自主行驶的安全性和保障乘坐舒适性,综述了近年来自动驾驶汽车横向控制方法的最新进展,包括经典控制方法和基于深度学习的方法,讨论了各类方法的性能特点及在应用中的优缺点&#xff…

开发常用编辑器,你知道几个?

以下是 Python 最受欢迎的 10 个编辑器: pyCharm 由捷克公司 JetBrains 开发,是使用最广泛的 Python IDE 之一。它分为社区版和专业版,社区版免费且功能足够满足日常基本需求,专业版功能更强大但需付费。 优势在于智能代码补全、…

2024年最佳平替电容笔对比:西圣、摩米士、倍思,哪款更适合你?

作为一位专注于数码产品的博主,我深知近年来平替电容笔在消费者中的热度不断攀升。这种电容笔以其亲民的价格和卓越的书写体验引起了广泛关注,尤其适合那些需要用iPad学习和办公的无纸化爱好者。 西圣这款自带充电仓的电容笔备受关注,尤其因…