Pytorch实现扩散模型【DDPM代码解读篇2】

news/2024/9/20 7:27:25/ 标签: pytorch, 深度学习, 机器学习, 人工智能

扩散的代码实现

本文承接  Pytorch实现扩散模型【DDPM代码解读篇1】http://t.csdnimg.cn/aDK0A

主要介绍“扩散是如何实现的”。代码逻辑清晰,可快速上手学习。

# 扩散的代码实现
# 扩散过程是训练部分的模型。它打开了一个采样接口,允许我们使用已经训练好的模型生成样本。
class DiffusionModel(nn.Module):# 类变量,用于将字符串调度器名称映射到相应的调度函数SCHEDULER_MAPPING = {"linear": linear_beta_schedule,"cosine": cosine_beta_schedule,"sigmoid": sigmoid_beta_schedule,}def __init__(self,model: nn.Module,image_size: int,*,beta_scheduler: str = "linear",  # 调度器类型,默认为线性timesteps: int = 1000,schedule_fn_kwargs: dict | None = None,  # 调度函数的关键字参数,默认为 Noneauto_normalize: bool = True,) -> None:super().__init__()self.model = modelself.channels = self.model.channelsself.image_size = image_size# 从 SCHEDULER_MAPPING 字典中获取与 beta_scheduler 字符串相对应的调度函数# 如果 beta_scheduler 字符串不存在于字典中,则返回 Noneself.beta_scheduler_fn = self.SCHEDULER_MAPPING.get(beta_scheduler)# 检查获取到的调度函数是否为 None,即检查是否成功选择了β调度函数# 如果调度函数为 None,则说明指定的 beta_scheduler 字符串不在预定义的调度函数列表中,于是抛出 ValueError 异常if self.beta_scheduler_fn is None:raise ValueError(f"unknown beta schedule {beta_scheduler}")# 检查是否提供了调度函数的关键字参数。若未提供,将schedule_fn_kwargs 设置为空字典。if schedule_fn_kwargs is None:schedule_fn_kwargs = {}# 用于计算扩散模型中的β调度函数,以及与β相关的其他参数,如α和后验方差:betas = self.beta_scheduler_fn(timesteps, **schedule_fn_kwargs)  # 生成一个包含β值的张量 betasalphas = 1.0 - betas# 对α值进行累积乘积,得到一个新的张量 alphas_cumprod,其形状与 betas 相同,包含了从0到 timesteps-1 时间步的所有α值的乘积。alphas_cumprod = torch.cumprod(alphas, dim=0)alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)'''对 alphas_cumprod 进行填充操作,将其第一个元素用 1.0 填充,以确保在计算后验方差时不会出现除以零的情况。F.pad 函数用于在张量的指定维度上进行填充,这里在维度 0 上进行填充,向左填充一个元素。'''# 计算后验方差posterior_variance = (betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod))# 注册缓冲区(buffer),并将每个相关的张量转换为 torch.float32 类型register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))register_buffer("betas", betas)  # 包含 β 值的张量register_buffer("alphas_cumprod", alphas_cumprod)  # 包含 α 累积乘积的张量register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)  # 包含 α 累积乘积的前一个时间步的张量register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas))  # α 的倒数的平方根的张量register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))  # α 累积乘积的平方根的张量register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))register_buffer("posterior_variance", posterior_variance)  # 后验方差的张量timesteps, *_ = betas.shape'''这里使用了“*”操作符,它的作用是在变量解构(destructuring)中丢弃不需要的部分。因为 betas 张量是一维的,所以这里的“*”操作符实际上没有起到什么作用,只是为了让代码更具通用性。'''self.num_timesteps = int(timesteps)  # 将时间步数转换为整数self.sampling_timesteps = timesteps# 归一化# auto_normalize 为 True,则选择 normalize_to_neg_one_to_one 函数进行归一化;否则选择 identity 函数,即不进行归一化操作。self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity#  auto_normalize 为 True,则选择 unnormalize_to_zero_to_one 函数进行反归一化;否则选择 identity 函数,即不进行反归一化操作。self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity@torch.inference_mode()'''可以将下面的函数或代码块置于推断模式中。这意味着,在装饰器声明的范围内,PyTorch 将禁用梯度计算,不会跟踪梯度,也不会进行任何与梯度相关的操作。这有助于提高推断速度,并且可以确保模型在进行推断时不会意外地进行训练相关的计算。'''def p_sample(self, x: torch.Tensor, timestamp: int) -> torch.Tensor:# 使用了解构语法 *,将张量形状中的除最后一维之外的所有维度忽略掉,并将结果赋值给一个名为 _ 的临时变量,最后一个维度被赋值给 deviceb, *_, device = *x.shape, x.devicebatched_timestamps = torch.full((b,), timestamp, device=device, dtype=torch.long)# 创建了一个形状为 (b,) 的张量 batched_timestamps,用于存储批次中每个样本的时间戳。# timestamp,其数据类型为 torch.long,并且张量存储在与输入张量相同的设备上# 将输入张量 x 和时间戳张量 batched_timestamps 传递给模型 self.model,以获取预测值 predspreds = self.model(x, batched_timestamps)# 使用函数 extract 从预先计算的参数 self.betas 中提取与批次时间戳对应的β值betas_t = extract(self.betas, batched_timestamps, x.shape)sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, batched_timestamps, x.shape)  # 提取与批次时间戳对应的α倒数的平方根sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, batched_timestamps, x.shape)  # 提取与批次时间戳对应的1减去α累积乘积的平方根# 计算预测的样本均值predicted_mean = sqrt_recip_alphas_t * (x - betas_t * preds / sqrt_one_minus_alphas_cumprod_t)#如果时间戳为零,直接返回预测的样本均值;否则,计算样本的后验方差并添加噪声,然后返回结果。if timestamp == 0:return predicted_meanelse:posterior_variance = extract(self.posterior_variance, batched_timestamps, x.shape)noise = torch.randn_like(x)return predicted_mean + torch.sqrt(posterior_variance) * noise@torch.inference_mode()def p_sample_loop(self, shape: tuple, return_all_timesteps: bool = False) -> torch.Tensor:batch, device = shape[0], "mps"  # 从形状元组中获取批量大小 batch,并设置设备为 "mps"(多处理器尺寸)img = torch.randn(shape, device=device) # 函数生成一个具有指定形状的随机张量 img,其值服从标准正态分布# This cause me a RunTimeError on MPS device due to MPS back out of memory# No ideas how to resolve it at this point# imgs = [img]'''使用 tqdm 函数创建一个迭代进度条,迭代范围是从 0 到 self.num_timesteps 的逆序。每个时间步长 t 都会调用 p_sample 方法进行样本采样,并更新 img 的值。'''for t in tqdm(reversed(range(0, self.num_timesteps)), total=self.num_timesteps):img = self.p_sample(img, t)# imgs.append(img)'''将每个时间步长的采样结果添加到一个列表 imgs 中。在循环中,每次迭代会生成一个新的采样结果,并将其添加到列表中允许在函数结束后返回所有时间步长的采样结果,以便进一步分析或处理'''# 最终的采样结果ret = img # if not return_all_timesteps else torch.stack(imgs, dim=1)# 调用 unnormalize 方法将最终的采样结果反归一化,使其返回到原始数据范围内。ret = self.unnormalize(ret)return ret# return_all_timesteps指定是否返回所有时间步长的样本,默认为 False,表示只返回最终时间步长的样本def sample(self, batch_size: int = 16, return_all_timesteps: bool = False  ) -> torch.Tensor:shape = (batch_size, self.channels, self.image_size, self.image_size)return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)# 用于在给定时间步长 t 上生成样本def q_sample(self, x_start: torch.Tensor, t: int, noise: torch.Tensor = None) -> torch.Tensor:# 首先检查是否提供了噪声if noise is None:noise = torch.randn_like(x_start)# 接着根据 t 从预先计算的参数中提取相应的系数sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)# 最后根据扩散过程的定义,计算并返回生成的样本return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noisedef p_loss(self,x_start: torch.Tensor,t: int,noise: torch.Tensor = None,loss_type: str = "l2",) -> torch.Tensor:if noise is None:noise = torch.randn_like(x_start)x_noised = self.q_sample(x_start, t, noise=noise)  # 在给定时间步长 t 上生成经过噪声处理的样本 x_noised# 使用生成的 x_noised 作为输入,调用模型 self.model,并传入时间步长 t,以获取预测的噪声 predicted_noise。predicted_noise = self.model(x_noised, t)if loss_type == "l2":  # 均方误差损失函数loss = F.mse_loss(noise, predicted_noise)elif loss_type == "l1":  # 绝对值误差损失函数loss = F.l1_loss(noise, predicted_noise)else:raise ValueError(f"unknown loss type {loss_type}")return lossdef forward(self, x: torch.Tensor) -> torch.Tensor:b, c, h, w, device, img_size = *x.shape, x.device, self.image_sizeassert h == w == img_size, f"image size must be {img_size}"# 解析输入 x 的形状,并确保输入的图像是正方形且大小与 image_size 相同。# 生成一个随机的时间步长 timestamp,范围在 [0, num_timesteps) 内timestamp = torch.randint(0, self.num_timesteps, (1,)).long().to(device)x = self.normalize(x)return self.p_loss(x, timestamp)

Life is a journey. We pursue love and light with purity.

你的 “三连” 是小曦持续更新的动力!
下期将推出
扩散的代码实现,零距离解读扩散是如何实现的。


http://www.ppmy.cn/news/1453495.html

相关文章

EPAI手绘建模APP资源管理和模型编辑器3

t) 立方体 图 42 模型编辑器-立方体 i. 修改立方体底部中心位置。 ii. 修改立方体的长、宽、高。 u) 圆柱体 图 43 模型编辑器-圆柱体 i. 修改圆柱体底部中心位置。 ii. 修改圆柱体半径。 iii. 修改圆柱体高度。 iv. 修改圆柱体角度。角度决定了圆柱体沿着圆周方向有效区域…

硬件原理图评审主要关注点

一、规格与需求符合性 在进行硬件原理图评审时,首先需要确保原理图的设计符合项目规格书和技术需求。评审人员应核对原理图中的各项参数,如工作电压、电流、频率等,确保它们与项目要求一致。同时,需要确认原理图是否满足产品的功能需求,避免出现设计缺陷或遗漏。 二、元…

JDBC连接openGauss6.0和PostgreSQL16.2性能对比

JDBC在Linux终端直接编译运行JAVA程序连接PG🆚OG数据库 前置准备Hello World连接数据库(PostgreSQL)连接数据库(openGauss)PG 🆚 OG 总结 看腻了就来听听视频演示吧:https://www.bilibili.com/video/BV1CH4y1N7xL/ 前置准备 安装JDK&#x…

element-ui show-summary合计放第一行

element-ui show-summary合计放第一行 <style scoped> /* /deep/ 为深度操作符&#xff0c;可以穿透到子组件 */ /deep/ .el-table {display: flex;flex-direction: column; }/* order默认值为0&#xff0c;只需将表体order置为1即可移到最后&#xff0c;这样总计行就…

【城市】2023香港身份与生活定居相关政策(IANG,优才/高才/专才,受养人/单程证)

【城市】2023香港身份与生活定居相关政策&#xff08;IANG&#xff0c;优才/高才/专才&#xff0c;受养人/单程证&#xff09; 文章目录 一、如何获得香港身份1、7年计划2、旅游签 二、港澳相关的证件类别1、HK证件2、CN证件 三、香港生活对比内地 本文仅代表2023年查阅相关资料…

CP,FT,WAT有什么区别?

‍ 知 识星球&#xff08;星球名&#xff1a; 芯片制造与封测社区&#xff0c;星球号&#xff1a; 63559049&#xff09;里的学员问&#xff1a; CP,FT,WAT都是与 芯片的测试有关&#xff0c;他们有什么区别呢&#xff1f; 如何区‍分&#xff1f; ‍ ‍ CP,FT,WAT分别…

Hive大数据任务调度和业务介绍

目录 一、Zookeeper 1.zookeeper介绍 2.数据模型 3.操作使用 4.运行机制 5.一致性 二、Dolphinscheduler 1.Dolphinscheduler介绍 架构 2.架构说明 该服务内主要包含: 该服务包含&#xff1a; 3.FinalShell主虚拟机启动服务 4.Web网页登录 5.使用 5-1 安全中心…

Java面试题:解释Executor框架和其在并发编程中的作用

Executor框架是Java提供的一个用于管理线程的框架&#xff0c;它在Java 5中引入&#xff0c;用于简化多线程编程。Executor框架的主要目的是将任务的提交与任务的执行解耦&#xff0c;从而提供了一种更灵活和强大的方式来管理线程和任务。 在Executor框架中&#xff0c;有几个…

sql 中having和where区别

where 是用于筛选表中满足条件的行&#xff0c;不可以和聚类函数一起使用 having 是用于筛选满足条件的组 &#xff0c;可与聚合函数一起使用 所以having语句中不能使用select中定义的名字

【Python】回溯法解全排列问题

题目 给定一个不含重复数字的数组&#xff0c;返回其所有可能的全排列。 分析 要实现全排列&#xff0c;就有一个长度与原数组相等的数组&#xff0c;数组的第一位可能是原数组中的任意一位&#xff0c;第二位是除了第一位的原数组的任意一位&#xff0c;第三位则是除了前两位…

上位机图像处理和嵌入式模块部署(树莓派4b和qt应用全屏占有)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 我们都知道&#xff0c;嵌入式应用一般都是为了某一个特定应用而存在的。也就是说&#xff0c;和pc不同&#xff0c;这个嵌入式板子一般都是为了解…

【Docker】docker compose服务编排

docker compose 简介 Dockerfile模板文件可以定义一个单独的应用容器&#xff0c;如果需要定义多个容器就需要服务编排。 docker swarm&#xff08;管理跨节点&#xff09; Dockerfile可以让用户管理一个单独的应用容器&#xff1b;而Compose则允许用户在一个模板&#xff08…

基于SpringBoot的考务管理系统 - 源码免费(私信领取)

1. 研究目的 本项目旨在设计并实现一个基于Spring Boot的考务管理系统&#xff0c;以提高考试管理的效率&#xff0c;简化考试流程&#xff0c;确保考试的顺利进行。 2. 研究要求 a. 需求分析 通过深入了解考务管理流程和需求&#xff0c;分析用户对考试管理系统的需求&…

2021-10-21 51单片机两位数码管显示0-99循环

缘由单片机两位数码管显示0-99循环-编程语言-CSDN问答 #include "REG52.h" #include<intrins.h> sbit K1 P3^0; sbit K2 P3^1; sbit K3 P3^2; sbit K4 P3^3; sbit bpP3^4; bit k1,wk10,wk20; unsigned char code SmZiFu[]{63,6,91,79,102,109,125,7,127,1…

【Python】 逻辑回归:从训练到预测的完整案例

我把我唱给你听 把你纯真无邪的笑容给我吧 我们应该有快乐的 幸福的晴朗的时光 我把我唱给你听 用我炙热的感情感动你好吗 岁月是值得怀念的留恋的 害羞的红色脸庞 谁能够代替你呀 趁年轻尽情的爱吧 最最亲爱的人啊 路途遥远我们在一起吧 &#x1f3b5; 叶…

02 - 步骤 Kafka consumer

简介 Kafka consumer 步骤&#xff0c;用于连接和消费 Apache Kafka 中的数据,它可以作为数据管道的一部分&#xff0c;将 Kafka 中的数据提取到 Kettle 中进行进一步处理、转换和加载&#xff0c;或者将其直接传输到目标系统中。 使用 场景 我需要订阅一个Kafka的数据&…

ASP.NET网上车辆档案管理系统

摘 要 本文采用基于Web的Asp.net技术&#xff0c;并与sql server 2000数据库相结合&#xff0c;研发了一套车辆档案管理系统。该系统扩展性好&#xff0c;易于维护。简化了车辆档案设计流程&#xff0c;去除了冗余信息。汽车销售企业可以通过本系统完成整个销售及售后所有档案…

动态规划专训6——回文串系列

动态规划题目中&#xff0c;常出现回文串相关问题&#xff0c;这里单独挑出来训练 1.回文子串 LCR 020. 回文子串 给定一个字符串 s &#xff0c;请计算这个字符串中有多少个回文子字符串。 具有不同开始位置或结束位置的子串&#xff0c;即使是由相同的字符组成&#xff0…

Webshell绕过技巧分析之-base64/HEX/Reverse/Html/Inflate/Rot13

在网络安全运营&#xff0c;护网HVV&#xff0c;重保等活动的过程中&#xff0c;webshell是一个无法绕过的话题。通常出现的webshell都不是以明文的形式出现&#xff0c;而是针对webshell关键的内容进行混淆&#xff0c;编码来绕过网络安全产品&#xff08;IDS&#xff0c;WAF&…

typescript 不是特别常用,容易忘的知识点

1、花括号对象通过方括号字符串形式取值 let obj { name: asd, age: 21, salary: 400, desc: "asdasd", op: [asd, as, qwe] };for (let i in obj) {console.log(obj[i as keyof typeof obj]); }let key name; console.log(obj[key as name]); console.log(obj[ke…