扩散模型 DDPM 核心代码梳理

news/2024/11/16 19:29:57/

参考内容:

大白话AI | 图像生成模型DDPM | 扩散模型 | 生成模型 | 概率扩散去噪生成模型
AIGC 基础,从VAE到DDPM 原理、代码详解
全网最简单的扩散模型DDPM教程
The Annotated Diffusion Model
LaTeX公式编辑器

备注: 具体公式的推导请查看参考链接,本文只记录核心步骤的几个核心公式。

什么是扩散模型?

与Normalizing Flows、GAN或VAEs等生成模型一样,它们都将噪声从一些简单分布转换为数据样本。这也是使用神经网络学习从纯噪声开始逐渐去噪进行内容生成的过程。扩散模型主要包括以下两个过程:

  • 前向加噪: 前向加噪过程是一个固定的、预定义的过程,通过逐步的往一张真实图像上添加高斯噪声,最终得到一个完全的高斯噪声图像
  • 反向去噪: 反向去噪过程通过训练学习一个神经网络模型,模型的输入是一张带有噪声的图像,模型的输出是预测得到的噪声,逐步减去预测的噪声,最终得到一张真实的图像
    在这里插入图片描述

加噪、去噪、训练、推理阶段相关的数学公式

  • 前向加噪

在前向加噪过程中,逐步的往真实图片上添加高斯噪声,每一步添加高斯噪声的公式表示如下:
x t = 1 − β t x t − 1 + β t ϵ t \begin{equation}x_{t} = \sqrt{1-\beta_{t}}x_{t-1} + \sqrt{\beta_{t}}\epsilon_{t}\end{equation} xt=1βt xt1+βt ϵt
其中, 0 < β 1 < β 2 < ⋯ < β T < 1 0 < \beta_{1} < \beta_{2} < \dots < \beta_{T} < 1 0<β1<β2<<βT<1 ϵ ∼ N ( 0 , 1 ) \epsilon \sim N(0,1) ϵN(0,1) β t \beta_{t} βt的取值可以想神经网络的学习率衰减那样,使用线性的、余弦变化的。由于正态分布的均值和方差具有可加性,从[1, T]时刻逐步添加噪声的过程可以通过一步得到:
x t = α t ˉ x 0 + 1 − α t ˉ ϵ \begin{equation}x_{t} = \sqrt{\bar{\alpha_{t}}}x_{0} + \sqrt{1 - \bar{\alpha_{t}}}\epsilon\end{equation} xt=αtˉ x0+1αtˉ ϵ
其中, α t = 1 − β t \alpha_{t} = 1 - \beta_{t} αt=1βt α t ˉ = α t α t − 1 … α 1 \bar{\alpha_{t}} = \alpha_{t}\alpha_{t-1}\dots\alpha_{1} αtˉ=αtαt1α1

  • 模型训练

在模型训练阶段,对于一个真实的图像数据,随机生成[1, T]之前的整数,表示往真实图片数据中添加噪声的次数,然后将添加噪声后的图片输入到神经网络模型中,预测添加的噪声,基于神经网络预测的噪声和真实添加的噪声,计算损失:
L o s s = ∣ ∣ ϵ − ϵ θ ( α t ˉ x 0 + 1 − α t ˉ ∗ ϵ , t ) ∣ ∣ 2 \begin{equation}Loss = ||\epsilon -\epsilon_{\theta}(\sqrt{\bar{\alpha_{t}}}x_{0} + \sqrt{1 - \bar{\alpha_{t}}}*\epsilon, t)||^{2}\end{equation} Loss=∣∣ϵϵθ(αtˉ x0+1αtˉ ϵ,t)2
其中, ϵ \epsilon ϵ表示在前向加噪过程中,使用公式(2)往真实图片中添加的随机噪声, ϵ θ \epsilon_{\theta} ϵθ表示一个神经网络模型,输入一个带有噪声的图像,以及对应添加噪声的时间步数,输出预测的噪声, x 0 x_{0} x0表示原始的真实图像, t t t表示时间步数。
在这里插入图片描述

  • 反向去噪

在反向去噪过程中,使用神经网络预测输出一个和输入图像一样大小的噪声数据,从输入图像中减去噪声数据,实现去噪。
x t − 1 = 1 α t ( x t − β t β t ˉ ∗ ϵ θ ( x t , t ) ) + δ t ∗ z \begin{equation}x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}(x_{t} - \frac{\beta_{t}}{\sqrt{\bar{\beta_{t}}}}*\epsilon _{\theta }(x_{t},t)) + \delta_{t}*z\end{equation} xt1=αt 1(xtβtˉ βtϵθ(xt,t))+δtz
其中, ϵ θ \epsilon _{\theta} ϵθ是一个神经网络模型, ϵ θ ( x t , t ) \epsilon _{\theta }(x_{t},t) ϵθ(xt,t)是神经网络模型预测输出的噪声, β t ˉ = 1 − α t ˉ \bar{\beta_{t}} = 1 - \bar{\alpha_{t}} βtˉ=1αtˉ

  • 模型推理

在模型推理阶段,也就是模型训练完之后进行图像的生成阶段,设置好迭代生成的时间步数 t t t,通过一个随机噪声 x t x_{t} xt,不断执行下面的步骤,直到公式(5)中的 t = 1 t = 1 t=1,实现图像的生成:
x t − 1 = 1 α t ( x t − β t β t ˉ ∗ ϵ θ ( x t , t ) ) + δ t ∗ z \begin{equation}x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}(x_{t} - \frac{\beta_{t}}{\sqrt{\bar{\beta_{t}}}}*\epsilon _{\theta }(x_{t},t)) + \delta_{t}*z\end{equation} xt1=αt 1(xtβtˉ βtϵθ(xt,t))+δtz
x t = x t − 1 \begin{equation}x_{t} = x_{t-1}\end{equation} xt=xt1
t = t − 1 \begin{equation}t = t-1\end{equation} t=t1

当公式(5)中的 t = 1 t = 1 t=1时,也就是最后一轮去噪,不加 δ t ∗ z \delta_{t}*z δtz,最后得到的 x 0 x_{0} x0就是生成的图像内容。
在这里插入图片描述

UNet网络结构

UNet神经网络在特定的时间步 t t t 接收噪声图像并返回预测的噪声。预测的噪声是一个与输入图像具有相同的大小/分辨率的张量。从技术上讲,网络输入和输出相同形状的张量。在DDPM中采用UNet架构的神经网络,UNet网络中主要包括以下部分:
在这里插入图片描述

  • 下采样:使用卷积 + 池化的方式实现图像分辨率的下采样
  • 上采样:使用转置卷积或者线性插值的方式,提升特征图的分辨率
  • Short-cut连接:将下采样和上采样得到的分辨率相同额特征图在通道维度上进行融合,有利于捕捉细粒度的图像特征
  • 注意力机制:使用注意力机制计算特征图上每个位置之间的注意力关系
  • time-embedding:由于DDPM是逐步生成图像的,所以需要一个特征能够标记当前执行到哪个时间步了

DDPM核心代码解释

  1. 基础代码:构造 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数
  • 使用不同的策略构建 β \beta β 序列
def linear_beta_schedule(timesteps):"""在0.0001到0.02之间,均匀采样timesteps个数值,构造成beta序列"""beta_start = 0.0001beta_end = 0.02return torch.linspace(beta_start, beta_end, timesteps)def cosine_beta_schedule(timesteps, s=0.008):"""cosine schedule as proposed in https://arxiv.org/abs/2102.09672"""steps = timesteps + 1x = torch.linspace(0, timesteps, steps)alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2alphas_cumprod = alphas_cumprod / alphas_cumprod[0]betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])return torch.clip(betas, 0.0001, 0.9999)def quadratic_beta_schedule(timesteps):beta_start = 0.0001beta_end = 0.02return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2def sigmoid_beta_schedule(timesteps):beta_start = 0.0001beta_end = 0.02betas = torch.linspace(-6, 6, timesteps)return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
  • 根据生成的 β \beta β 序列,生成 α , α ˉ , β ˉ \alpha,\bar{\alpha},\bar{\beta} α,αˉ,βˉ等, α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数的序列长度等于最大的迭代步长timesteps
timesteps = 300# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)# define alphas 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
  • 备注
    • betas对应 β \beta β
    • alphas对应 α = 1 − β \alpha = 1 - \beta α=1β
    • alphas_cumprod对应 α ˉ \bar{\alpha} αˉ
    • sqrt_recip_alphas对应 1 α \frac{1}{\sqrt{\alpha}} α 1
    • sqrt_alphas_cumprod对应 1 α ˉ \frac{1}{\sqrt{\bar{\alpha}}} αˉ 1
    • sqrt_one_minus_alphas_cumprod对应 1 − α ˉ \sqrt{1 - \bar{\alpha}} 1αˉ
  • 在训练阶段对于batch中的每个样本,加噪的迭代次数是从[0, T]中进行随机采样的,所以训练阶段每个样本的加噪次数 t ∈ [ 0 , T ] t \in [0, T] t[0,T] 是不同的,使用gather函数获取到每个样本的t对应的 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数,对应的代码如下:
def extract(a, t, x_shape):batch_size = t.shape[0]out = a.gather(-1, t.cpu())return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
  1. 前向加噪:根据上一步计算得到的 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数,将一张真实图像 x 0 x_{0} x0 使用公式(2)进行多次加噪,得到加噪后的图像,对应代码如下:
def q_sample(x_start, t, noise=None):if noise is None:noise = torch.randn_like(x_start)# x_start就是前面讲的最原始图像 x_0,根据 t 获取到对应的alpha,beta等参数sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)# 使用公式(2)对图像进行前向加噪return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
  1. UNet模型:将加噪后的样本以及每个样本对应的加噪次数 t 输入到UNet网络模型中,UNet模型预测输出加入的噪声,将UNet的输出结果与加入到图像中的噪声使用公式(3)计算损失,训练UNet网络模型。
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):if noise is None:noise = torch.randn_like(x_start)# x_start就是前面讲的最原始图像 x_0,这一步就是往 x_0 中加入t次的噪声x_noisy = q_sample(x_start=x_start, t=t, noise=noise)# 将加入噪声的图像以及对应的时间步数 t 输入到UNet模型predicted_noise = denoise_model(x_noisy, t)# 将UNet预测的结果与加入的噪声计算损失if loss_type == 'l1':loss = F.l1_loss(noise, predicted_noise)elif loss_type == 'l2':loss = F.mse_loss(noise, predicted_noise)elif loss_type == "huber":loss = F.smooth_l1_loss(noise, predicted_noise)else:raise NotImplementedError()return loss
  1. 模型推理:当训练完UNet之后,在模型推理也就是图像生成阶段执行反向去噪过程。首先生成一张纯噪声的图像,初始时间步设置为timesteps,将噪声图像和时间步数值 t 输入到UNet模型中,预测得到输出结果,然后使用公式(4)计算得到经过去噪之后 t-1时间步的输出,如此迭代,直到 t=0为止。
def p_sample(model, x, t, t_index):betas_t = extract(betas, t, x.shape)sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)# Equation 11 in the paper# Use our model (noise predictor) to predict the meanmodel_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)if t_index == 0:return model_meanelse:posterior_variance_t = extract(posterior_variance, t, x.shape)noise = torch.randn_like(x)# Algorithm 2 line 4:return model_mean + torch.sqrt(posterior_variance_t) * noise # Algorithm 2 (including returning all images)def p_sample_loop(model, shape):device = next(model.parameters()).deviceb = shape[0]# start from pure noise (for each example in the batch)img = torch.randn(shape, device=device)imgs = []for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)imgs.append(img.cpu().numpy())return imgsdef sample(model, image_size, batch_size=16, channels=3):return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

注意事项:

  • torch.randn生成符合标准正态分布的数据,torch.rand生成符合0-1之间均匀分布的数据
  • UNet有利于细粒度的图像生成

http://www.ppmy.cn/news/1094380.html

相关文章

数字化医院信息云平台源代码 HIS系统全套成品源代码

基层医疗云HIS作为基于云计算的B/S构架的HIS系统&#xff0c;为基层医疗机构提供了标准化的、信息化的、可共享的医疗信息管理系统&#xff0c;可有效进行医疗数据共享与交换&#xff0c;解决数据重复采集及信息孤岛等问题&#xff0c;实现对基层医疗数据的分析和挖掘&#xff…

JavaScipt中如何实现函数缓存?函数缓存有哪些场景?

1、函数缓存是什么&#xff1f; 函数缓存就是将函数运行的结果进行缓存。本质上就是用空间&#xff08;缓存存储&#xff09;换时间&#xff08;计算过程&#xff09; 常用于缓存数据计算结果和缓存对象。 缓存只是一个临时的数据存储&#xff0c;它保存数据&#xff0c;以便将…

Android知识点整理

关键点 Activity Fragment 调试应用 处理应用程序配置 Intent 和 Intent 过滤器 会使用Context 后台处理指南 Android 的数据隐私 Android 网络数据安全教程 Android 中的依赖项注入 内容提供程序 Android 内存管理概览 一些重要的库 1.Glide 是一个 Android 上的…

elementPlus + table 树形懒加载 新增,删除,修改 局部刷新

#直接上代码# 1.表格数据 2.数据源 <m-table ref"cTable" v-if"Object.keys(props.tableData).length" :options"props.tableOptions" :data"props.tableData.data" :isLoading"props.tableData.loading" elementLo…

OJ练习第165题——修车的最少时间

修车的最少时间 力扣链接&#xff1a;2594. 修车的最少时间 题目描述 给你一个整数数组 ranks &#xff0c;表示一些机械工的 能力值 。ranksi 是第 i 位机械工的能力值。能力值为 r 的机械工可以在 r * n2 分钟内修好 n 辆车。 同时给你一个整数 cars &#xff0c;表示总…

K8S:kubeadm搭建K8S+Harbor 私有仓库

文章目录 一.部署规划1.主机规划2.部署流程 二.kubeadm搭建K8S1.环境准备2.安装docker3. 安装kubeadm&#xff0c;kubelet和kubectl4.部署K8S集群&#xff08;1&#xff09;初始化&#xff08;2&#xff09;部署网络插件flannel&#xff08;3&#xff09;创建 pod 资源 5.部署 …

初入行的IC工程师,如何快速提高自己的竞争力?

要想成为越来越吃香的IC工程师&#xff0c;就会先经历初期的成长阶段。今天就来聊聊初入行的ICer如何快速提升自己的竞争力&#xff08;验证篇&#xff09;。 首先希望大家在选择IC行业的时候就有清晰的认知&#xff0c;这是一个不得不深耕技术的行业。我们今天所谈论的快速提…

报错:crbug/1173575 non-js module files deprecated

环境&#xff1a; vue3 &#xff0c; visual studio code, bulma 背景&#xff1a; 在代码中&#xff0c;使用标签来进行导航栏跳转。 如&#xff1a; <div class"navbar-start"><a href"/groups">产品</router-link> </div>执…