Pytorch实现扩散模型【DDPM代码解读篇2】

embedded/2024/11/24 12:09:35/

扩散的代码实现

本文承接  Pytorch实现扩散模型【DDPM代码解读篇1】http://t.csdnimg.cn/aDK0A

主要介绍“扩散是如何实现的”。代码逻辑清晰,可快速上手学习。

# 扩散的代码实现
# 扩散过程是训练部分的模型。它打开了一个采样接口,允许我们使用已经训练好的模型生成样本。
class DiffusionModel(nn.Module):# 类变量,用于将字符串调度器名称映射到相应的调度函数SCHEDULER_MAPPING = {"linear": linear_beta_schedule,"cosine": cosine_beta_schedule,"sigmoid": sigmoid_beta_schedule,}def __init__(self,model: nn.Module,image_size: int,*,beta_scheduler: str = "linear",  # 调度器类型,默认为线性timesteps: int = 1000,schedule_fn_kwargs: dict | None = None,  # 调度函数的关键字参数,默认为 Noneauto_normalize: bool = True,) -> None:super().__init__()self.model = modelself.channels = self.model.channelsself.image_size = image_size# 从 SCHEDULER_MAPPING 字典中获取与 beta_scheduler 字符串相对应的调度函数# 如果 beta_scheduler 字符串不存在于字典中,则返回 Noneself.beta_scheduler_fn = self.SCHEDULER_MAPPING.get(beta_scheduler)# 检查获取到的调度函数是否为 None,即检查是否成功选择了β调度函数# 如果调度函数为 None,则说明指定的 beta_scheduler 字符串不在预定义的调度函数列表中,于是抛出 ValueError 异常if self.beta_scheduler_fn is None:raise ValueError(f"unknown beta schedule {beta_scheduler}")# 检查是否提供了调度函数的关键字参数。若未提供,将schedule_fn_kwargs 设置为空字典。if schedule_fn_kwargs is None:schedule_fn_kwargs = {}# 用于计算扩散模型中的β调度函数,以及与β相关的其他参数,如α和后验方差:betas = self.beta_scheduler_fn(timesteps, **schedule_fn_kwargs)  # 生成一个包含β值的张量 betasalphas = 1.0 - betas# 对α值进行累积乘积,得到一个新的张量 alphas_cumprod,其形状与 betas 相同,包含了从0到 timesteps-1 时间步的所有α值的乘积。alphas_cumprod = torch.cumprod(alphas, dim=0)alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)'''对 alphas_cumprod 进行填充操作,将其第一个元素用 1.0 填充,以确保在计算后验方差时不会出现除以零的情况。F.pad 函数用于在张量的指定维度上进行填充,这里在维度 0 上进行填充,向左填充一个元素。'''# 计算后验方差posterior_variance = (betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod))# 注册缓冲区(buffer),并将每个相关的张量转换为 torch.float32 类型register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))register_buffer("betas", betas)  # 包含 β 值的张量register_buffer("alphas_cumprod", alphas_cumprod)  # 包含 α 累积乘积的张量register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)  # 包含 α 累积乘积的前一个时间步的张量register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas))  # α 的倒数的平方根的张量register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))  # α 累积乘积的平方根的张量register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))register_buffer("posterior_variance", posterior_variance)  # 后验方差的张量timesteps, *_ = betas.shape'''这里使用了“*”操作符,它的作用是在变量解构(destructuring)中丢弃不需要的部分。因为 betas 张量是一维的,所以这里的“*”操作符实际上没有起到什么作用,只是为了让代码更具通用性。'''self.num_timesteps = int(timesteps)  # 将时间步数转换为整数self.sampling_timesteps = timesteps# 归一化# auto_normalize 为 True,则选择 normalize_to_neg_one_to_one 函数进行归一化;否则选择 identity 函数,即不进行归一化操作。self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity#  auto_normalize 为 True,则选择 unnormalize_to_zero_to_one 函数进行反归一化;否则选择 identity 函数,即不进行反归一化操作。self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity@torch.inference_mode()'''可以将下面的函数或代码块置于推断模式中。这意味着,在装饰器声明的范围内,PyTorch 将禁用梯度计算,不会跟踪梯度,也不会进行任何与梯度相关的操作。这有助于提高推断速度,并且可以确保模型在进行推断时不会意外地进行训练相关的计算。'''def p_sample(self, x: torch.Tensor, timestamp: int) -> torch.Tensor:# 使用了解构语法 *,将张量形状中的除最后一维之外的所有维度忽略掉,并将结果赋值给一个名为 _ 的临时变量,最后一个维度被赋值给 deviceb, *_, device = *x.shape, x.devicebatched_timestamps = torch.full((b,), timestamp, device=device, dtype=torch.long)# 创建了一个形状为 (b,) 的张量 batched_timestamps,用于存储批次中每个样本的时间戳。# timestamp,其数据类型为 torch.long,并且张量存储在与输入张量相同的设备上# 将输入张量 x 和时间戳张量 batched_timestamps 传递给模型 self.model,以获取预测值 predspreds = self.model(x, batched_timestamps)# 使用函数 extract 从预先计算的参数 self.betas 中提取与批次时间戳对应的β值betas_t = extract(self.betas, batched_timestamps, x.shape)sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, batched_timestamps, x.shape)  # 提取与批次时间戳对应的α倒数的平方根sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, batched_timestamps, x.shape)  # 提取与批次时间戳对应的1减去α累积乘积的平方根# 计算预测的样本均值predicted_mean = sqrt_recip_alphas_t * (x - betas_t * preds / sqrt_one_minus_alphas_cumprod_t)#如果时间戳为零,直接返回预测的样本均值;否则,计算样本的后验方差并添加噪声,然后返回结果。if timestamp == 0:return predicted_meanelse:posterior_variance = extract(self.posterior_variance, batched_timestamps, x.shape)noise = torch.randn_like(x)return predicted_mean + torch.sqrt(posterior_variance) * noise@torch.inference_mode()def p_sample_loop(self, shape: tuple, return_all_timesteps: bool = False) -> torch.Tensor:batch, device = shape[0], "mps"  # 从形状元组中获取批量大小 batch,并设置设备为 "mps"(多处理器尺寸)img = torch.randn(shape, device=device) # 函数生成一个具有指定形状的随机张量 img,其值服从标准正态分布# This cause me a RunTimeError on MPS device due to MPS back out of memory# No ideas how to resolve it at this point# imgs = [img]'''使用 tqdm 函数创建一个迭代进度条,迭代范围是从 0 到 self.num_timesteps 的逆序。每个时间步长 t 都会调用 p_sample 方法进行样本采样,并更新 img 的值。'''for t in tqdm(reversed(range(0, self.num_timesteps)), total=self.num_timesteps):img = self.p_sample(img, t)# imgs.append(img)'''将每个时间步长的采样结果添加到一个列表 imgs 中。在循环中,每次迭代会生成一个新的采样结果,并将其添加到列表中允许在函数结束后返回所有时间步长的采样结果,以便进一步分析或处理'''# 最终的采样结果ret = img # if not return_all_timesteps else torch.stack(imgs, dim=1)# 调用 unnormalize 方法将最终的采样结果反归一化,使其返回到原始数据范围内。ret = self.unnormalize(ret)return ret# return_all_timesteps指定是否返回所有时间步长的样本,默认为 False,表示只返回最终时间步长的样本def sample(self, batch_size: int = 16, return_all_timesteps: bool = False  ) -> torch.Tensor:shape = (batch_size, self.channels, self.image_size, self.image_size)return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)# 用于在给定时间步长 t 上生成样本def q_sample(self, x_start: torch.Tensor, t: int, noise: torch.Tensor = None) -> torch.Tensor:# 首先检查是否提供了噪声if noise is None:noise = torch.randn_like(x_start)# 接着根据 t 从预先计算的参数中提取相应的系数sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)# 最后根据扩散过程的定义,计算并返回生成的样本return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noisedef p_loss(self,x_start: torch.Tensor,t: int,noise: torch.Tensor = None,loss_type: str = "l2",) -> torch.Tensor:if noise is None:noise = torch.randn_like(x_start)x_noised = self.q_sample(x_start, t, noise=noise)  # 在给定时间步长 t 上生成经过噪声处理的样本 x_noised# 使用生成的 x_noised 作为输入,调用模型 self.model,并传入时间步长 t,以获取预测的噪声 predicted_noise。predicted_noise = self.model(x_noised, t)if loss_type == "l2":  # 均方误差损失函数loss = F.mse_loss(noise, predicted_noise)elif loss_type == "l1":  # 绝对值误差损失函数loss = F.l1_loss(noise, predicted_noise)else:raise ValueError(f"unknown loss type {loss_type}")return lossdef forward(self, x: torch.Tensor) -> torch.Tensor:b, c, h, w, device, img_size = *x.shape, x.device, self.image_sizeassert h == w == img_size, f"image size must be {img_size}"# 解析输入 x 的形状,并确保输入的图像是正方形且大小与 image_size 相同。# 生成一个随机的时间步长 timestamp,范围在 [0, num_timesteps) 内timestamp = torch.randint(0, self.num_timesteps, (1,)).long().to(device)x = self.normalize(x)return self.p_loss(x, timestamp)

Life is a journey. We pursue love and light with purity.

你的 “三连” 是小曦持续更新的动力!
下期将推出
扩散的代码实现,零距离解读扩散是如何实现的。


http://www.ppmy.cn/embedded/34215.html

相关文章

leetCode73. 矩阵置零

leetCode73. 矩阵置零 题目思路&#xff1a; 代码 class Solution { public:void setZeroes(vector<vector<int>>& matrix) {// 力扣特色&#xff1a;先判断是否为空if(matrix.empty() || matrix[0].empty()) return;int n matrix.size(), m matrix[0].siz…

python数据分析——业务数据描述

业务数据描述 前言一、数据收集数据信息来源 二、公司内部数据&#xff08;1&#xff09;客户资料数据&#xff08;2&#xff09;销售明细数据&#xff08;3&#xff09;营销活动数据 三、市场调查数据1 观察法2 提问法3 实验法 四、公共数据五、第三方数据六、数据预处理七、数…

步态识别论文(6)GaitDAN: Cross-view Gait Recognition via Adversarial Domain Adaptation

摘要: 视角变化导致步态外观存在显着差异。因此&#xff0c;识别跨视图场景中的步态是非常具有挑战性的。最近的方法要么在进行识别之前将步态从原始视图转换为目标视图&#xff0c;要么通过蛮力学习或解耦学习提取与相机视图无关的步态特征。然而&#xff0c;这些方法有许多约…

51单片机入门:DS1302时钟

51单片机内部含有晶振&#xff0c;可以实现定时/计数功能。但是其缺点有&#xff1a;精度往往不高、不能掉电使用等。 我们可以通过DS1302时钟芯片来解决以上的缺点。 DS1302时钟芯片 功能&#xff1a;DS1302是一种低功耗实时时钟芯片&#xff0c;内部有自动的计时功能&#x…

IDEA访问不到静态资源

背景 我在resources下创建static文件夹&#xff0c;再创建front文件夹放前端资源&#xff0c;里面有index.html&#xff0c;游览器输入localhost:8011/front没反应。&#xff08;resources/static/front/index.html&#xff09; 解决办法 重启idea&#xff0c;清楚idea缓存&am…

springboot 学习路线

Spring Boot 是一个开源的 Java 基础框架&#xff0c;它提供了快速开发、配置简单的特性&#xff0c;帮助开发者轻松创建独立的、生产级别的基于 Spring 框架的应用。以下是一条推荐的 Spring Boot 学习路线&#xff1a; 1. Java 基础知识 Java SE&#xff1a;掌握 Java 标准…

tomcat+maven+java+mysql图书管理系统1-配置项目环境

目录 一、软件版本 二、具体步骤 一、软件版本 idea2022.2.1 maven是idea自带不用另外下载 tomcat8.5.99 Javajdk17 二、具体步骤 1.新建项目 稍等一会&#xff0c;创建成功如下图所示&#xff0c;主要看左方目录相同不。 给maven配置国外镜像 在左上…

Py深度学习基础|关于reshape()函数

在代码中经常能看到reshape((1, -1))或者reshape((-1, 1))的用法&#xff0c;这里予以记录&#xff0c;如有错误还请大佬指正。 reshape函数用于改变数组或系列的形状。当使用-1作为参数时&#xff0c;它是一种灵活的方式来告诉函数自动帮助计算出应该有的行数或列数&#xff0…