DDPM | A Detailed Walkthrough of Diffusion Model Code [quite detailed and thorough!!!]


Table of Contents

    • 1. UNet network structure
      • 1.1 Details of the residual and attention networks
      • 1.2 The role of t
      • 1.3 How Positional Embedding is used in DDPM
      • 1.4 Positional Embedding code in DDPM
      • 1.5 Residual block
      • 1.6 Attention block
      • 1.7 UNet structure
    • 2. Command-line argument parsing
    • 3. Data loading and preprocessing
    • 4. Training framework
    • References

1. UNet network structure

The overall framework of the UNet is shown below: the right side is the UNet as a whole, and the left side shows the residual and attention networks.

[image]

Below is a detailed structural diagram of the UNet. The left half applies residual blocks, downsampling, and attention in a regular pattern; the right half likewise applies residual blocks, upsampling, and attention. The relevant code is annotated in the figure.

[image]

1.1 Details of the residual and attention networks

Readers familiar with CNNs should be able to follow most of the diagram below. Here t is a random value between 0 and 1000; suppose it is 888. Passing it through the Positional Embedding produces a vector of length 128, which then goes through fully connected layers, SiLU activations, and so on. The Positional Embedding, residual network, and attention network are explained in detail below.

[image]

1.2 The role of t

1. Together with the original image, it determines the noised image at step t: $x_t=\sqrt{1-\overline{\alpha_t}}\,\epsilon+\sqrt{\overline{\alpha_t}}\,x_0$ (a sketch follows the figure below)
2. t is encoded, and the encoding is added into the model so that the model learns which timestep it is currently at

[image]
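To make point 1 concrete, here is a minimal sketch (not code from this repo) of sampling $x_t$ directly from $x_0$; the linear beta schedule values and the helper name q_sample are assumptions for illustration:

import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)           # linear schedule (common DDPM defaults)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)  # cumulative product: alpha_bar_t

def q_sample(x0, t, eps=None):
    # x_t = sqrt(1 - alpha_bar_t) * eps + sqrt(alpha_bar_t) * x_0
    if eps is None:
        eps = torch.randn_like(x0)
    ab = alphas_bar.to(x0.device)[t].view(-1, 1, 1, 1)  # one alpha_bar per image, broadcast over (N, C, H, W)
    return torch.sqrt(1.0 - ab) * eps + torch.sqrt(ab) * x0

Note how a larger t gives a smaller alpha_bar_t, so x_t contains more noise and less of the original image.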

1.3 How Positional Embedding is used in DDPM

The left figure shows the Transformer's Positional Embedding: the row index indicates which word, and each row holds that word's feature vector. The right figure shows DDPM's Positional Embedding. The two differ in two ways. First, DDPM's Positional Embedding does not encode the position of every word; it only needs to pick one row out of the 1000 rows, indexed by the randomly sampled timestep. Second, DDPM's Positional Embedding does not interleave sine and cosine values at odd and even positions; instead it concatenates the sine half and the cosine half front to back. Although the concatenation scheme differs, the end effect is the same, as shown in the figure below: a positional encoding only needs to guarantee that every row is unique and that each row's relationship to the other rows is preserved.

[image]

1.4 Positional Embedding code in DDPM

Code:

import math
import torch
import torch.nn as nn

class PositionalEmbedding(nn.Module):
    __doc__ = r"""..."""

    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim      # length of the feature vector
        self.scale = scale  # 1.0 leaves the sine/cosine periods unscaled

    def forward(self, x):
        # x is t, a value drawn from 0-1000; with batch_size=2, e.g. x = tensor([645, 958])
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
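A quick shape check, matching the batch_size = 2 example in the comment above (a sketch using the class exactly as defined here):

pe = PositionalEmbedding(dim=128)
t = torch.tensor([645.0, 958.0])   # two timesteps, one per image in the batch
print(pe(t).shape)                 # torch.Size([2, 128]): 64 sine values then 64 cosine values per row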

Code explanation:

The 2 in $emb_{2\times 64}$ in the figure below indicates that the batch size is 2.

[image]

Where it is used:

[image]

1.5 Residual block

Original code:

import torch
import torch.nn as nn
import torch.nn.functional as F
# get_norm and AttentionBlock are defined elsewhere in this repo

class ResidualBlock(nn.Module):
    __doc__ = r"""Applies two conv blocks with residual connection. Adds time and class
    conditioning by adding bias after first convolution.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning
        y: classes tensor of shape (N) or None if the block doesn't use class conditioning
    Output:
        tensor of shape (N, out_channels, H, W)
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
        num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
        activation (function): activation function. Default: torch.nn.functional.relu
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
        use_attention (bool): if True applies AttentionBlock to the output. Default: False
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout,
        time_emb_dim=None,
        num_classes=None,
        activation=F.relu,
        norm="gn",
        num_groups=32,
        use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)

    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        out = self.conv_1(out)

        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")
            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        out = self.conv_2(out) + self.residual_connection(x)
        out = self.attention(out)

        return out
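A hypothetical shape check (the channel sizes below are made up for illustration; get_norm and AttentionBlock must be in scope from this repo):

block = ResidualBlock(in_channels=128, out_channels=256, dropout=0.1, time_emb_dim=512)
x = torch.randn(2, 128, 16, 16)
t_emb = torch.randn(2, 512)        # output of the time MLP
print(block(x, t_emb).shape)       # torch.Size([2, 256, 16, 16])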

Code explanation:

[image]
[image]

1.6 Attention block

The UNet contains 5 attention blocks in total. Every attention block takes an input of size 256×16×16, and its output size equals its input size.

Original code:

class AttentionBlock(nn.Module):
    __doc__ = r"""Applies QKV self-attention with a residual connection.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
    Output:
        tensor of shape (N, in_channels, H, W)
    Args:
        in_channels (int): number of input channels
    """

    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()

        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention = torch.softmax(dot_products, dim=-1)
        out = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)

        out = out.view(b, h, w, c).permute(0, 3, 1, 2)
        return self.to_out(out) + x
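A quick sanity check against the 256×16×16 shape mentioned above (a sketch; get_norm must be in scope from this repo):

attn = AttentionBlock(256)
x = torch.randn(2, 256, 16, 16)
print(attn(x).shape)  # torch.Size([2, 256, 16, 16]): the residual self-attention preserves the shape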

Code explanation:

[image]

1.7 UNet structure

The UNet takes two inputs. One is the time embedding introduced earlier, which needs to be added into every residual block; the other is the noised data $x_t$, whose spatial size is unchanged by the noising.

Original UNet code:

# PositionalEmbedding, ResidualBlock, Downsample, Upsample, and get_norm are defined elsewhere in this repo

class UNet(nn.Module):
    __doc__ = """UNet model used to estimate noise.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning
        y: classes tensor of shape (N) or None if the block doesn't use class conditioning
    Output:
        tensor of shape (N, out_channels, H, W)
    Args:
        img_channels (int): number of image channels
        base_channels (int): number of base channels (after first convolution)
        channel_mults (tuple): tuple of channel multipliers. Default: (1, 2, 4, 8)
        time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
        time_emb_scale (float): linear scale to be applied to timesteps. Default: 1.0
        num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
        activation (function): activation function. Default: torch.nn.functional.relu
        dropout (float): dropout rate at the end of each residual block
        attention_resolutions (tuple): list of relative resolutions at which to apply attention. Default: ()
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
        initial_pad (int): initial padding applied to image. Should be used if height or width is not a power of 2. Default: 0
    """

    def __init__(
        self,
        img_channels,
        base_channels,
        channel_mults=(1, 2, 4, 8),
        num_res_blocks=2,
        time_emb_dim=None,
        time_emb_scale=1.0,
        num_classes=None,
        activation=F.relu,
        dropout=0.1,
        attention_resolutions=(),
        norm="gn",
        num_groups=32,
        initial_pad=0,
    ):
        super().__init__()

        self.activation = activation
        self.initial_pad = initial_pad
        self.num_classes = num_classes

        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None

        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        channels = [base_channels]
        now_channels = base_channels

        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks):
                self.downs.append(ResidualBlock(
                    now_channels,
                    out_channels,
                    dropout,
                    time_emb_dim=time_emb_dim,
                    num_classes=num_classes,
                    activation=activation,
                    norm=norm,
                    num_groups=num_groups,
                    use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels
                channels.append(now_channels)

            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        self.mid = nn.ModuleList([
            ResidualBlock(
                now_channels,
                now_channels,
                dropout,
                time_emb_dim=time_emb_dim,
                num_classes=num_classes,
                activation=activation,
                norm=norm,
                num_groups=num_groups,
                use_attention=True,
            ),
            ResidualBlock(
                now_channels,
                now_channels,
                dropout,
                time_emb_dim=time_emb_dim,
                num_classes=num_classes,
                activation=activation,
                norm=norm,
                num_groups=num_groups,
                use_attention=False,
            ),
        ])

        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels,
                    out_channels,
                    dropout,
                    time_emb_dim=time_emb_dim,
                    num_classes=num_classes,
                    activation=activation,
                    norm=norm,
                    num_groups=num_groups,
                    use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels

            if i != 0:
                self.ups.append(Upsample(now_channels))

        assert len(channels) == 0

        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

    def forward(self, x, time=None, y=None):
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but time is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None

        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")

        x = self.init_conv(x)
        skips = [x]

        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)

        for layer in self.mid:
            x = layer(x, time_emb, y)

        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)

        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x
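A hypothetical CIFAR-10-sized instantiation (32×32 RGB images; these argument values are assumptions for illustration, chosen so that attention fires at the 16×16 level with 256 channels, matching the shape noted in 1.6):

model = UNet(
    img_channels=3,
    base_channels=128,
    channel_mults=(1, 2, 2, 2),
    time_emb_dim=512,
    attention_resolutions=(1,),   # apply attention at the second level: 16x16, 256 channels
)
x = torch.randn(2, 3, 32, 32)                # a noised batch x_t
t = torch.randint(0, 1000, (2,))             # one timestep per image
print(model(x, time=t).shape)                # torch.Size([2, 3, 32, 32]): the predicted noise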

Code explanation:

The overall structure is explained as follows:

[image]

This is the explanation of the time encoding. The input to self.time_mlp is t, a random value from 0-1000:

[image]

This is the explanation of the downsampling part:

[image]

This is the explanation of the middle part:

[image]

This is the explanation of the up part:

[image]

2. Command-line argument parsing

[image]

[image]

Original code:

'''
code from https://github.com/abarankab/DDPM/tree/main
'''
import argparse
import datetime
import torch
import wandb

from torch.utils.data import DataLoader
from torchvision import datasets
from ddpm import script_utils


def main():
    args = create_argparser().parse_args()
    device = args.device

    try:
        diffusion = script_utils.get_diffusion_from_args(args).to(device)
        optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)

        if args.model_checkpoint is not None:
            diffusion.load_state_dict(torch.load(args.model_checkpoint))
        if args.optim_checkpoint is not None:
            optimizer.load_state_dict(torch.load(args.optim_checkpoint))

        if args.log_to_wandb:
            if args.project_name is None:
                raise ValueError("args.log_to_wandb set to True but args.project_name is None")
            run = wandb.init(
                project=args.project_name,
                # entity='treaptofun',  # specifies the team or organization the run belongs to
                config=vars(args),
                name=args.run_name,
            )
            wandb.watch(diffusion)

        batch_size = args.batch_size

        train_dataset = datasets.CIFAR10(
            root='./cifar_train',
            train=True,
            download=True,
            transform=script_utils.get_transform(),
        )
        test_dataset = datasets.CIFAR10(
            root='./cifar_test',
            train=False,
            download=True,
            transform=script_utils.get_transform(),
        )

        train_loader = script_utils.cycle(DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=0,
        ))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=0)

        acc_train_loss = 0

        for iteration in range(1, args.iterations + 1):
            diffusion.train()

            x, y = next(train_loader)
            x = x.to(device)
            y = y.to(device)

            if args.use_labels:
                loss = diffusion(x, y)
            else:
                loss = diffusion(x)

            acc_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            diffusion.update_ema()

            if iteration % args.log_rate == 0:
                test_loss = 0
                with torch.no_grad():
                    diffusion.eval()
                    for x, y in test_loader:
                        x = x.to(device)
                        y = y.to(device)

                        if args.use_labels:
                            loss = diffusion(x, y)
                        else:
                            loss = diffusion(x)

                        test_loss += loss.item()

                if args.use_labels:
                    samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
                else:
                    samples = diffusion.sample(10, device)

                samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()

                test_loss /= len(test_loader)
                acc_train_loss /= args.log_rate

                wandb.log({
                    "test_loss": test_loss,
                    "train_loss": acc_train_loss,
                    "samples": [wandb.Image(sample) for sample in samples],
                })

                acc_train_loss = 0

            if iteration % args.checkpoint_rate == 0:
                model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
                optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"
                torch.save(diffusion.state_dict(), model_filename)
                torch.save(optimizer.state_dict(), optim_filename)

        if args.log_to_wandb:
            run.finish()
    except KeyboardInterrupt:
        if args.log_to_wandb:
            run.finish()
        print("Keyboard interrupt, run finished early")


def create_argparser():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=2,
        iterations=800000,
        log_to_wandb=True,
        log_rate=1000,
        checkpoint_rate=1000,
        log_dir="~/ddpm_logs",
        project_name="Enzo_ddpm",
        run_name=run_name,
        model_checkpoint=None,
        optim_checkpoint=None,
        schedule_low=1e-4,
        schedule_high=0.02,
        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
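A small sketch of exercising the parser (this assumes add_dict_to_argparser registers one --<key> flag per entry of the defaults dict, as similar diffusion repos do; that behavior is not shown in this file):

parser = create_argparser()
args = parser.parse_args(["--batch_size", "128", "--log_rate", "500"])
print(args.batch_size, args.learning_rate)  # 128 0.0002 (overridden value vs. untouched default)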

3. Data loading and preprocessing

Why map the pixel values of the image $x_0$ into [-1, 1]?

  • Because the noise added later follows a distribution with mean 0 and variance 1, and the noising step takes a weighted sum of the original pixel values and the noise values, the original image also needs to be mapped to a zero-mean distribution (see the preprocessing sketch after the figure below).

[image]
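As a sketch of this preprocessing (the repo's script_utils.get_transform presumably does something equivalent, but this exact Compose is an assumption):

from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),                                            # uint8 [0, 255] -> float [0, 1]
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
])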

4. Training framework

The figure below shows the code that generates the betas, as well as the overall framework of the training code (a minimal sketch of one training step follows the figure).

[image]
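To tie the pieces together, here is a minimal sketch of one training step; it is not the repo's actual GaussianDiffusion code, but reuses the hypothetical q_sample helper from section 1.2 and assumes a noise-predicting model with the UNet signature above:

def training_step(model, optimizer, x0, T=1000):
    # draw one timestep per image, noise the batch, and regress the model output onto the noise
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)
    eps = torch.randn_like(x0)
    x_t = q_sample(x0, t, eps)                                   # forward-noised batch
    loss = torch.nn.functional.mse_loss(model(x_t, time=t), eps)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()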

References:

1. Bilibili video
2. https://github.com/Enzo-MiMan/cv_related_collections/tree/main/diffusion

