变分扩散模型 ELBO 重构推导详解
在变分扩散模型(Variational Diffusion Model)中,证据下界(Evidence Lower Bound, ELBO)的形式通过优化正向和逆向分布的匹配来实现数据生成。初始 ELBO (变分扩散模型中的 Evidence Lower Bound (ELBO) 详解)存在采样复杂性,尤其是过渡块中需要联合分布 ( q φ ( x t − 1 , x t + 1 ∣ x 0 ) q_φ(x_{t-1}, x_{t+1}|x_0) qφ(xt−1,xt+1∣x0) ) 的样本,这引发了重新设计的动机。后面提出了一种等价的 ELBO 形式,通过贝叶斯定理和条件调整简化了计算。本文将详细推导这一重构过程,解释这种转变,面向具备概率论和深度学习基础的读者。
参考:https://arxiv.org/pdf/2403.18103
初始 ELBO 的问题
原始 ELBO
原本定义的 ELBO 为:
ELBO φ , θ ( x ) = E q φ ( x 1 ∣ x 0 ) [ log p θ ( x 0 ∣ x 1 ) ] − E q φ ( x T − 1 ∣ x 0 ) [ D K L ( q φ ( x T ∣ x T − 1 ) ∥ p ( x T ) ) ] − ∑ t = 1 T − 1 E q φ ( x t − 1 , x t + 1 ∣ x 0 ) [ D K L ( q φ ( x t ∣ x t − 1 ) ∥ p θ ( x t ∣ x t + 1 ) ) ] \text{ELBO}_{φ,θ}(x) = \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)] - \mathbb{E}_{q_φ(x_{T-1}|x_0)} \left[ D_{KL}(q_φ(x_T|x_{T-1}) \| p(x_T)) \right] - \sum_{t=1}^{T-1} \mathbb{E}_{q_φ(x_{t-1}, x_{t+1}|x_0)} \left[ D_{KL}(q_φ(x_t|x_{t-1}) \| p_θ(x_t|x_{t+1})) \right] ELBOφ,θ(x)=Eqφ(x1∣x0)[logpθ(x0∣x1)]−Eqφ(xT−1∣x0)[DKL(qφ(xT∣xT−1)∥p(xT))]−t=1∑T−1Eqφ(xt−1,xt+1∣x0)[DKL(qφ(xt∣xt−1)∥pθ(xt∣xt+1))]
- 初始块:重构项 ( E q φ ( x 1 ∣ x 0 ) [ log p θ ( x 0 ∣ x 1 ) ] \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)] Eqφ(x1∣x0)[logpθ(x0∣x1)] )。
- 最终块:先验匹配项 ( − E q φ ( x T − 1 ∣ x 0 ) [ D K L ( q φ ( x T ∣ x T − 1 ) ∥ p ( x T ) ) ] -\mathbb{E}_{q_φ(x_{T-1}|x_0)} [D_{KL}(q_φ(x_T|x_{T-1}) \| p(x_T))] −Eqφ(xT−1∣x0)[DKL(qφ(xT∣xT−1)∥p(xT))] )。
- 过渡块:一致性项 ( − ∑ t = 1 T − 1 E q φ ( x t − 1 , x t + 1 ∣ x 0 ) [ D K L ( q φ ( x t ∣ x t − 1 ) ∥ p θ ( x t ∣ x t + 1 ) ) ] -\sum_{t=1}^{T-1} \mathbb{E}_{q_φ(x_{t-1}, x_{t+1}|x_0)} [D_{KL}(q_φ(x_t|x_{t-1}) \| p_θ(x_t|x_{t+1}))] −∑t=1T−1Eqφ(xt−1,xt+1∣x0)[DKL(qφ(xt∣xt−1)∥pθ(xt∣xt+1))] )。
问题所在
过渡块需要从联合分布 ( q φ ( x t − 1 , x t + 1 ∣ x 0 ) q_φ(x_{t-1}, x_{t+1}|x_0) qφ(xt−1,xt+1∣x0) ) 抽样,这涉及未来状态 ( x t + 1 x_{t+1} xt+1 ) 和过去状态 ( x t − 1 x_{t-1} xt−1 ) 的耦合。直接采样 ( ( x t − 1 , x t + 1 ) (x_{t-1}, x_{t+1}) (xt−1,xt+1) ) 复杂,因为 ( q φ ( x t + 1 ∣ x 0 ) q_φ(x_{t+1}|x_0) qφ(xt+1∣x0) ) 依赖多步正向过程,且正向 ( q φ ( x t ∣ x t − 1 ) q_φ(x_t|x_{t-1}) qφ(xt∣xt−1) ) 和逆向 ( p θ ( x t ∣ x t + 1 ) p_θ(x_t|x_{t+1}) pθ(xt∣xt+1) ) 方向相反,增加了计算负担。
重构动机与贝叶斯调整
一致性项的挑战
- ( q φ ( x t ∣ x t − 1 ) q_φ(x_t|x_{t-1}) qφ(xt∣xt−1) ) 是正向过渡,( p θ ( x t ∣ x t + 1 ) p_θ(x_t|x_{t+1}) pθ(xt∣xt+1) ) 是逆向过渡,两者方向相反,导致需要同时处理 ( x t − 1 x_{t-1} xt−1 ) 和 ( x t + 1 x_{t+1} xt+1 ) 的样本。
- 目标是简化一致性检查,避免“反向”依赖。
贝叶斯定理的引入
通过贝叶斯定理调整条件分布:
q ( x t ∣ x t − 1 ) = q ( x t − 1 ∣ x t ) q ( x t ) q ( x t − 1 ) q(x_t|x_{t-1}) = \frac{q(x_{t-1}|x_t) q(x_t)}{q(x_{t-1})} q(xt∣xt−1)=q(xt−1)q(xt−1∣xt)q(xt)
条件于 ( x 0 x_0 x0 ):
q ( x t ∣ x t − 1 , x 0 ) = q ( x t − 1 ∣ x t , x 0 ) q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) q(x_t|x_{t-1}, x_0) = \frac{q(x_{t-1}|x_t, x_0) q(x_t|x_0)}{q(x_{t-1}|x_0)} q(xt∣xt−1,x0)=q(xt−1∣x0)q(xt−1∣xt,x0)q(xt∣x0)
- 这一变换将正向 ( q ( x t ∣ x t − 1 , x 0 ) q(x_t|x_{t-1}, x_0) q(xt∣xt−1,x0) ) 转化为逆向形式的 ( q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) q(xt−1∣xt,x0) ),方向与 ( p θ ( x t − 1 ∣ x t ) p_θ(x_{t-1}|x_t) pθ(xt−1∣xt) ) 一致。
- ( x 0 x_0 x0 ) 的条件确保分布依赖初始状态,避免无限制采样。
重构 ELBO 的推导
步骤 1:从 Jensen 不等式开始
从之前的基础推导(变分扩散模型 ELBO 的推导过程详解)出发:
log p ( x ) ≥ E q φ ( x 1 : T ∣ x 0 ) [ log p ( x 0 : T ) q φ ( x 1 : T ∣ x 0 ) ] \log p(x) \geq \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)} \right] logp(x)≥Eqφ(x1:T∣x0)[logqφ(x1:T∣x0)p(x0:T)]
代入联合分布:
p ( x 0 : T ) = p ( x T ) p ( x 0 ∣ x 1 ) ∏ t = 2 T p ( x t − 1 ∣ x t ) p(x_{0:T}) = p(x_T) p(x_0|x_1) \prod_{t=2}^T p(x_{t-1}|x_t) p(x0:T)=p(xT)p(x0∣x1)t=2∏Tp(xt−1∣xt)
q φ ( x 1 : T ∣ x 0 ) = q φ ( x 1 ∣ x 0 ) ∏ t = 2 T q φ ( x t ∣ x t − 1 , x 0 ) q_φ(x_{1:T}|x_0) = q_φ(x_1|x_0) \prod_{t=2}^T q_φ(x_t|x_{t-1}, x_0) qφ(x1:T∣x0)=qφ(x1∣x0)t=2∏Tqφ(xt∣xt−1,x0)
(注意:这里 ( q φ ( x t ∣ x t − 1 , x 0 ) q_φ(x_t|x_{t-1}, x_0) qφ(xt∣xt−1,x0) ) 因马尔可夫性简化为 ( q φ ( x t ∣ x t − 1 ) q_φ(x_t|x_{t-1}) qφ(xt∣xt−1)),但为一致性保留条件。)
步骤 2:展开对数项
log p ( x 0 : T ) q φ ( x 1 : T ∣ x 0 ) = log p ( x T ) p ( x 0 ∣ x 1 ) ∏ t = 2 T p ( x t − 1 ∣ x t ) q φ ( x 1 ∣ x 0 ) ∏ t = 2 T q φ ( x t ∣ x t − 1 , x 0 ) \log \frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)} = \log \frac{p(x_T) p(x_0|x_1) \prod_{t=2}^T p(x_{t-1}|x_t)}{q_φ(x_1|x_0) \prod_{t=2}^T q_φ(x_t|x_{t-1}, x_0)} logqφ(x1:T∣x0)p(x0:T)=logqφ(x1∣x0)∏t=2Tqφ(xt∣xt−1,x0)p(xT)p(x0∣x1)∏t=2Tp(xt−1∣xt)
分离:
= log p ( x T ) p ( x 0 ∣ x 1 ) q φ ( x 1 ∣ x 0 ) + log ∏ t = 2 T p ( x t − 1 ∣ x t ) ∏ t = 2 T q φ ( x t ∣ x t − 1 , x 0 ) = \log \frac{p(x_T) p(x_0|x_1)}{q_φ(x_1|x_0)} + \log \frac{\prod_{t=2}^T p(x_{t-1}|x_t)}{\prod_{t=2}^T q_φ(x_t|x_{t-1}, x_0)} =logqφ(x1∣x0)p(xT)p(x0∣x1)+log∏t=2Tqφ(xt∣xt−1,x0)∏t=2Tp(xt−1∣xt)
步骤 3:应用贝叶斯调整
对第二项,使用贝叶斯定理:
p ( x t − 1 ∣ x t ) q φ ( x t ∣ x t − 1 , x 0 ) = p ( x t − 1 ∣ x t ) q φ ( x t − 1 ∣ x t , x 0 ) q φ ( x t ∣ x 0 ) q φ ( x t − 1 ∣ x 0 ) \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} = \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0) \frac{q_φ(x_t|x_0)}{q_φ(x_{t-1}|x_0)}} qφ(xt∣xt−1,x0)p(xt−1∣xt)=qφ(xt−1∣xt,x0)qφ(xt−1∣x0)qφ(xt∣x0)p(xt−1∣xt)
= q φ ( x t − 1 ∣ x t , x 0 ) q φ ( x t ∣ x 0 ) q φ ( x t − 1 ∣ x 0 ) ⋅ p ( x t − 1 ∣ x t ) q φ ( x t − 1 ∣ x t , x 0 ) = \frac{q_φ(x_{t-1}|x_t, x_0) q_φ(x_t|x_0)}{q_φ(x_{t-1}|x_0)} \cdot \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)} =qφ(xt−1∣x0)qφ(xt−1∣xt,x0)qφ(xt∣x0)⋅qφ(xt−1∣xt,x0)p(xt−1∣xt)
整理乘积:
∏ t = 2 T p ( x t − 1 ∣ x t ) q φ ( x t ∣ x t − 1 , x 0 ) = ∏ t = 2 T p ( x t − 1 ∣ x t ) q φ ( x t − 1 ∣ x t , x 0 ) ⋅ q φ ( x t − 1 ∣ x 0 ) q φ ( x t ∣ x 0 ) \prod_{t=2}^T \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} = \prod_{t=2}^T \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)} \cdot \frac{q_φ(x_{t-1}|x_0)}{q_φ(x_t|x_0)} t=2∏Tqφ(xt∣xt−1,x0)p(xt−1∣xt)=t=2∏Tqφ(xt−1∣xt,x0)p(xt−1∣xt)⋅qφ(xt∣x0)qφ(xt−1∣x0)
步骤 4:期望分离
E q φ ( x 1 : T ∣ x 0 ) [ log p ( x T ) p ( x 0 ∣ x 1 ) q φ ( x 1 ∣ x 0 ) + log ∏ t = 2 T p ( x t − 1 ∣ x t ) q φ ( x t ∣ x t − 1 , x 0 ) ] \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_T) p(x_0|x_1)}{q_φ(x_1|x_0)} + \log \prod_{t=2}^T \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} \right] Eqφ(x1:T∣x0)[logqφ(x1∣x0)p(xT)p(x0∣x1)+logt=2∏Tqφ(xt∣xt−1,x0)p(xt−1∣xt)]
- 第一项:
E q φ ( x 1 : T ∣ x 0 ) [ log p ( x T ) p ( x 0 ∣ x 1 ) q φ ( x 1 ∣ x 0 ) ] \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_T) p(x_0|x_1)}{q_φ(x_1|x_0)} \right] Eqφ(x1:T∣x0)[logqφ(x1∣x0)p(xT)p(x0∣x1)]
= E q φ ( x 1 ∣ x 0 ) [ log p θ ( x 0 ∣ x 1 ) ] + E q φ ( x 1 : T ∣ x 0 ) [ log p ( x T ) q φ ( x T ∣ x 0 ) ] = \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)] + \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_T)}{q_φ(x_T|x_0)} \right] =Eqφ(x1∣x0)[logpθ(x0∣x1)]+Eqφ(x1:T∣x0)[logqφ(xT∣x0)p(xT)]
- 第二项:
E q φ ( x 1 : T ∣ x 0 ) [ log ∏ t = 2 T p ( x t − 1 ∣ x t ) q φ ( x t ∣ x t − 1 , x 0 ) ] = ∑ t = 2 T E q φ ( x 1 : T ∣ x 0 ) [ log p ( x t − 1 ∣ x t ) q φ ( x t ∣ x t − 1 , x 0 ) ] \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \prod_{t=2}^T \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} \right] = \sum_{t=2}^T \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} \right] Eqφ(x1:T∣x0)[logt=2∏Tqφ(xt∣xt−1,x0)p(xt−1∣xt)]=t=2∑TEqφ(x1:T∣x0)[logqφ(xt∣xt−1,x0)p(xt−1∣xt)]
使用贝叶斯调整:
p ( x t − 1 ∣ x t ) q φ ( x t ∣ x t − 1 , x 0 ) = p ( x t − 1 ∣ x t ) q φ ( x t − 1 ∣ x 0 ) q φ ( x t − 1 ∣ x t , x 0 ) q φ ( x t ∣ x 0 ) \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} = \frac{p(x_{t-1}|x_t) q_φ(x_{t-1}|x_0)}{q_φ(x_{t-1}|x_t, x_0) q_φ(x_t|x_0)} qφ(xt∣xt−1,x0)p(xt−1∣xt)=qφ(xt−1∣xt,x0)qφ(xt∣x0)p(xt−1∣xt)qφ(xt−1∣x0)
log p ( x t − 1 ∣ x t ) q φ ( x t ∣ x t − 1 , x 0 ) = log p ( x t − 1 ∣ x t ) q φ ( x t − 1 ∣ x t , x 0 ) + log q φ ( x t − 1 ∣ x 0 ) q φ ( x t ∣ x 0 ) \log \frac{p(x_{t-1}|x_t)}{q_φ(x_t|x_{t-1}, x_0)} = \log \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)} + \log \frac{q_φ(x_{t-1}|x_0)}{q_φ(x_t|x_0)} logqφ(xt∣xt−1,x0)p(xt−1∣xt)=logqφ(xt−1∣xt,x0)p(xt−1∣xt)+logqφ(xt∣x0)qφ(xt−1∣x0)
步骤 5:简化期望
- 重构项:
E q φ ( x 1 ∣ x 0 ) [ log p θ ( x 0 ∣ x 1 ) ] \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)] Eqφ(x1∣x0)[logpθ(x0∣x1)]
- 先验匹配项:
E q φ ( x 1 : T ∣ x 0 ) [ log p ( x T ) q φ ( x T ∣ x 0 ) ] = − D K L ( q φ ( x T ∣ x 0 ) ∥ p ( x T ) ) \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_T)}{q_φ(x_T|x_0)} \right] = -D_{KL}(q_φ(x_T|x_0) \| p(x_T)) Eqφ(x1:T∣x0)[logqφ(xT∣x0)p(xT)]=−DKL(qφ(xT∣x0)∥p(xT))
- 一致性项:
∑ t = 2 T E q φ ( x 1 : T ∣ x 0 ) [ log p ( x t − 1 ∣ x t ) q φ ( x t − 1 ∣ x t , x 0 ) + log q φ ( x t − 1 ∣ x 0 ) q φ ( x t ∣ x 0 ) ] \sum_{t=2}^T \mathbb{E}_{q_φ(x_{1:T}|x_0)} \left[ \log \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)} + \log \frac{q_φ(x_{t-1}|x_0)}{q_φ(x_t|x_0)} \right] t=2∑TEqφ(x1:T∣x0)[logqφ(xt−1∣xt,x0)p(xt−1∣xt)+logqφ(xt∣x0)qφ(xt−1∣x0)]
第二项的和为:
log q φ ( x 1 ∣ x 0 ) q φ ( x T ∣ x 0 ) = log q φ ( x 1 ∣ x 0 ) − log q φ ( x T ∣ x 0 ) \log \frac{q_φ(x_1|x_0)}{q_φ(x_T|x_0)} = \log q_φ(x_1|x_0) - \log q_φ(x_T|x_0) logqφ(xT∣x0)qφ(x1∣x0)=logqφ(x1∣x0)−logqφ(xT∣x0)
但重点是第一项:
E q φ ( x t − 1 , x t ∣ x 0 ) [ log p ( x t − 1 ∣ x t ) q φ ( x t − 1 ∣ x t , x 0 ) ] = − E q φ ( x t ∣ x 0 ) [ D K L ( q φ ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) ] \mathbb{E}_{q_φ(x_{t-1}, x_t|x_0)} \left[ \log \frac{p(x_{t-1}|x_t)}{q_φ(x_{t-1}|x_t, x_0)} \right] = -\mathbb{E}_{q_φ(x_t|x_0)} \left[ D_{KL}(q_φ(x_{t-1}|x_t, x_0) \| p_θ(x_{t-1}|x_t)) \right] Eqφ(xt−1,xt∣x0)[logqφ(xt−1∣xt,x0)p(xt−1∣xt)]=−Eqφ(xt∣x0)[DKL(qφ(xt−1∣xt,x0)∥pθ(xt−1∣xt))]
步骤 6:范围调整
从 ( t = 2 t=2 t=2 ) 到 ( t = T t=T t=T ) 对应 ( x t − 1 x_{t-1} xt−1 ) 从 ( x 1 x_1 x1 ) 到 ( x T − 1 x_{T-1} xT−1 ),与过渡块 ( t = 1 t=1 t=1 ) 到 ( T − 1 T-1 T−1 ) 一致,调整索引。
最终 ELBO
ELBO φ , θ ( x ) = E q φ ( x 1 ∣ x 0 ) [ log p θ ( x 0 ∣ x 1 ) ] − D K L ( q φ ( x T ∣ x 0 ) ∥ p ( x T ) ) − ∑ t = 2 T E q φ ( x t ∣ x 0 ) [ D K L ( q φ ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) ] \text{ELBO}_{φ,θ}(x) = \mathbb{E}_{q_φ(x_1|x_0)} [\log p_θ(x_0|x_1)] - D_{KL}(q_φ(x_T|x_0) \| p(x_T)) - \sum_{t=2}^T \mathbb{E}_{q_φ(x_t|x_0)} \left[ D_{KL}(q_φ(x_{t-1}|x_t, x_0) \| p_θ(x_{t-1}|x_t)) \right] ELBOφ,θ(x)=Eqφ(x1∣x0)[logpθ(x0∣x1)]−DKL(qφ(xT∣x0)∥p(xT))−t=2∑TEqφ(xt∣x0)[DKL(qφ(xt−1∣xt,x0)∥pθ(xt−1∣xt))]
推导总结
- 贝叶斯定理将 ( q φ ( x t ∣ x t − 1 , x 0 ) q_φ(x_t|x_{t-1}, x_0) qφ(xt∣xt−1,x0) ) 转化为 ( q φ ( x t − 1 ∣ x t , x 0 ) q_φ(x_{t-1}|x_t, x_0) qφ(xt−1∣xt,x0) ),与 ( p θ ( x t − 1 ∣ x t ) p_θ(x_{t-1}|x_t) pθ(xt−1∣xt) ) 方向一致。
- 期望从联合分布简化为单变量,消除了 ( x t + 1 x_{t+1} xt+1 ) 的依赖。
- 新的 ELBO 保持优化目标,简化了采样复杂性。
代码实现片段(伪代码)
def elbo_loss_new(x0, model, T, alpha_schedule):elbo = 0.0x1 = forward_transition(x0, alpha_schedule[1])elbo += torch.mean(model.log_prob_x0_given_x1(x0, x1)) # ReconstructionxT = forward_multi_step(x0, alpha_schedule)kl_prior = kl_divergence(xT, torch.zeros_like(xT), torch.ones_like(xT))elbo -= kl_prior # Prior matchingfor t in range(2, T + 1):xt = forward_step(x0, t, alpha_schedule)xt_minus_1 = forward_step(x0, t - 1, alpha_schedule)kl_cons = kl_divergence(xt_minus_1, model.reverse_mean(xt, t), model.reverse_cov(xt, t))elbo -= torch.mean(kl_cons) # Consistencyreturn elbo
总结
重构后的 ELBO 通过贝叶斯调整消除了联合采样的复杂性,保持了模型的优化能力。这一设计体现了扩散模型的灵活性,为高效训练提供了可能。
希望这篇推导帮助你理解!
后记
2025年3月5日18点17分于上海,在grok 3大模型辅助下完成。