[DL]深度学习_扩散模型正弦时间编码

server/2024/12/2 22:21:47/

1  扩散模型时间步嵌入

1.1  时间步正弦编码

        在扩散模型按时间步 t 进行加噪去噪过程时,需要包括反映噪声水平的时间步长 t 作为噪声预测器的额外输入。但是最初与图像配套的时间步 t 是数字,需要将代表时间步 t 的数字编码为向量嵌入。嵌入时间向量的宽度dim是按照输出设定的。

def timestep_embedding(timesteps, dim, max_period=10000):"""Create sinusoidal timestep embeddings.:param timesteps: a 1-D Tensor of N indices, one per batch element.These may be fractional.:param dim: the dimension of the output.:param max_period: controls the minimum frequency of the embeddings.:return: an [N x dim] Tensor of positional embeddings."""half = dim // 2freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(device=timesteps.device)args = timesteps[:, None].float() * freqs[None]embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)if dim % 2:embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)return embedding

参数定义

  • timesteps--时间步 t ,包含batch中每个元素的时间步信息,是一个一维张量,形状为[N];
  • dim--编码后时间嵌入的最后一维大小,即输出多宽的向量;
  • max_period--控制嵌入的最小频率,值为10000。 

        假设此时的batch中有4个图像,则此时4个图像有4个对应的扩散时间步,所需编码的时间嵌入宽度是8。timesteps的形状为[4] 。

timesteps = [t1, t2, t3, t4], dim = 8

编码过程

    half = dim // 2

        计算所需输出向量的一半尺寸,以供后续将正弦嵌入和余弦嵌入按照最后维度concat拼接。

    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(device=timesteps.device)

        用了固定化编码方式,计算出的频率向量,且该频率向量的维度是由dim的一半half决定的。arange(start=0,end=half)则代表arange(0,4),若以i为代表,则此时 i = 0,1,2,3。公式如下:

freqs=e^{(-\frac{i\times log(maxperiod)}{half})} 

则此时计算出的 freqs = [freqs1,freqs2,freqs3,freqs4],是一个一维张量,形状为[4]。

    args = timesteps[:, None].float() * freqs[None]

        首先,timesteps[:, None]作用是将timesteps的形状从一维张量[4]扩展为2维张量[4x1],之前是 timesteps = [ t1, t2, t3, t4 ] 扩展为 timesteps = [ [ t1 ], [ t2 ], [ t3 ], [ t4 ] ]。

        其次将freqs的形状从一维张量[4]扩展为2维张量[1x4],之前是 freqs = [freqs1,freqs2,freqs3,freqs4] 扩展为 freqs = [ [freqs1,freqs2,freqs3,freqs4] ]。

        这样做的目的是为了将时间步 timesteps 代表的batch内每个图片的时间步与频率 freqs 做乘法,将时间步广播进去。得出args的形状是 [4x4] ,即 [N x half] ,每个时间步对应得到half列,有N个时间步。

    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)

        将得到的args求cos值和sin值,此时得到的 cos(args) 与 sin(args) 的形状依旧都是 [N x half],按照最后一维拼接 cos(args) 与 sin(args) ,则拼接后得到的embedding的形状是 [4 x 8], [N x 2half] ,即 [N x dim]。对应开始定义的 dim 代表的是编码后的最后维度大小。 

    if dim % 2:embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)return embedding

        最后加一个判断,当 dim % 2 ==1 即dim是奇数时,embedding的形状是 [N x dim-1],需要用零向量补充1个维度,变为 [N x dim] 。

1.2  馈入多层感知机 

        当对时间步长执行正弦编码之后,需要将其馈送到多层感知机中获得隐式时间嵌入。 

        time_embed_dim = model_channels * 4self.time_embed = nn.Sequential(linear(model_channels, time_embed_dim),nn.SiLU(),linear(time_embed_dim, time_embed_dim),)emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        多层感知机由两个具有倒置瓶颈结构的线性层和一个SiLU激活函数组成。其中dim = channels = 8,time_embed_dim = model_channels * 4 = 32。需要注意是 N 为 batch 大小。

  • linear(model_channels, time_embed_dim)中,输入形状为 [4 x 8] ,[N x dim] 即 [N x model_channels] ,输出形状为 [4 x 32] , [N x time_embed_dim]。
  • 激活函数不会改变形状。
  • linear(time_embed_dim, time_embed_dim)中,输入形状为 [4 x 32] ,[N x time_embed_dim],输出形状为 [4 x 32],[N x time_embed_dim]。

        多层感知机主要作用是将输入的时间步正弦编码向量变得更宽。通过激活函数引入非线性。 

        通过多层感知机中的线性层将正弦时间编码向量的宽调整成当前block所需输出图像的通道数宽,与通道数转换成一样大小是为了在UNet的每一个block中,将正弦时间编码向量加到图像的每个像素点上。


http://www.ppmy.cn/server/146853.html

相关文章

Qt MinGW环境下使用CEF

环境 Clion :2019.3.6 Qt :5.12.12(MinGW 7.3.0) CEF:cef_binary_87.1.14ga29e9a3chromium-87.0.4280.141_windows32 编译libcef_dll_wrapper 修改cmake 文件 修改根目录下的CMakeLists.txt # 修改cmake 版本 c…

matlab代码--卷积神经网络的手写数字识别

1.cnn介绍 卷积神经网络(Convolutional Neural Network, CNN)是一种深度学习的算法,在图像和视频识别、图像分类、自然语言处理等领域有着广泛的应用。CNN的基本结构包括输入层、卷积层、池化层(Pooling Layer)、全连…

Python plotly库介绍

目录 一、引言 二、plotly库的特点 三、安装plotly库 四、基本用法 五、高级功能 六、总结 一、引言 在数据可视化领域,Python提供了众多强大的库。其中,plotly是一个功能强大、交互式的可视化库,可以创建各种类型的图表,包…

详解 PyTorch 中的 Dataset:功能、实现及应用示例

详解 PyTorch 中的 Dataset:功能、实现及应用示例 在机器学习和深度学习中,Dataset 类是一个抽象类,通常用于封装对于数据集的各种操作,包括访问、处理和预处理数据。Dataset 为数据加载提供了一个标准的接口,使其能够…

uniapp 扩展picker-view实现条件查询

因为选项值过多,需要动态查询,现有组件无法实现,将picker-view扩展了一下,支持条件查询,接口调用。 实现效果 注意:直接使用,样式可能不准,根据自己的实际情况进行样式调整 参数说…

ELK超详细操作文档

ELK简介 ELK平台是一套完整的日志集中处理解决方案,将 ElasticSearch、Logstash 和 Kiabana 三个开源工具配合使用, 完成更强大的用户对日志的查询、排序、统计需求。 ElasticSearch ElasticSearch:是基于Lucene(一个全文检索引…

Oracle 11gR2 Data Guard 搭建 (一主一从)

一、环境规划 项目主库 Primary备库 Standby操作系统CentOS Linux 7.9.2009CentOS Linux 7.9.2009数据库版本11.2.0.411.2.0.4IP地址192.168.10.101192.168.10.102db_nameorclorclinstance_nameorclorcldb_unique_nameorcl_priorcl_sbytnsnameorcl_priorcl_sbyservice_names(服…

neo4j如何存储关于liquidity structure的层次和关联结构

在 Neo4j 中存储关于流动性结构(liquidity structure)的层次和关联结构非常适合,因为 Neo4j 是一个基于图的数据库,能够自然地建模和存储复杂的关系和层次结构。下面是如何在 Neo4j 中设计和实现这样的数据模型的详细步骤和示例。…