1 扩散模型时间步嵌入
1.1 时间步正弦编码
在扩散模型按时间步 t 进行加噪去噪过程时,需要包括反映噪声水平的时间步长 t 作为噪声预测器的额外输入。但是最初与图像配套的时间步 t 是数字,需要将代表时间步 t 的数字编码为向量嵌入。嵌入时间向量的宽度dim是按照输出设定的。
def timestep_embedding(timesteps, dim, max_period=10000):"""Create sinusoidal timestep embeddings.:param timesteps: a 1-D Tensor of N indices, one per batch element.These may be fractional.:param dim: the dimension of the output.:param max_period: controls the minimum frequency of the embeddings.:return: an [N x dim] Tensor of positional embeddings."""half = dim // 2freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(device=timesteps.device)args = timesteps[:, None].float() * freqs[None]embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)if dim % 2:embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)return embedding
参数定义
- timesteps--时间步 t ,包含batch中每个元素的时间步信息,是一个一维张量,形状为[N];
- dim--编码后时间嵌入的最后一维大小,即输出多宽的向量;
- max_period--控制嵌入的最小频率,值为10000。
假设此时的batch中有4个图像,则此时4个图像有4个对应的扩散时间步,所需编码的时间嵌入宽度是8。timesteps的形状为[4] 。
timesteps = [t1, t2, t3, t4], dim = 8
编码过程
half = dim // 2
计算所需输出向量的一半尺寸,以供后续将正弦嵌入和余弦嵌入按照最后维度concat拼接。
freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(device=timesteps.device)
用了固定化编码方式,计算出的频率向量,且该频率向量的维度是由dim的一半half决定的。arange(start=0,end=half)则代表arange(0,4),若以i为代表,则此时 i = 0,1,2,3。公式如下:
则此时计算出的 freqs = [freqs1,freqs2,freqs3,freqs4],是一个一维张量,形状为[4]。
args = timesteps[:, None].float() * freqs[None]
首先,timesteps[:, None]作用是将timesteps的形状从一维张量[4]扩展为2维张量[4x1],之前是 timesteps = [ t1, t2, t3, t4 ] 扩展为 timesteps = [ [ t1 ], [ t2 ], [ t3 ], [ t4 ] ]。
其次将freqs的形状从一维张量[4]扩展为2维张量[1x4],之前是 freqs = [freqs1,freqs2,freqs3,freqs4] 扩展为 freqs = [ [freqs1,freqs2,freqs3,freqs4] ]。
这样做的目的是为了将时间步 timesteps 代表的batch内每个图片的时间步与频率 freqs 做乘法,将时间步广播进去。得出args的形状是 [4x4] ,即 [N x half] ,每个时间步对应得到half列,有N个时间步。
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
将得到的args求cos值和sin值,此时得到的 cos(args) 与 sin(args) 的形状依旧都是 [N x half],按照最后一维拼接 cos(args) 与 sin(args) ,则拼接后得到的embedding的形状是 [4 x 8], [N x 2half] ,即 [N x dim]。对应开始定义的 dim 代表的是编码后的最后维度大小。
if dim % 2:embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)return embedding
最后加一个判断,当 dim % 2 ==1 即dim是奇数时,embedding的形状是 [N x dim-1],需要用零向量补充1个维度,变为 [N x dim] 。
1.2 馈入多层感知机
当对时间步长执行正弦编码之后,需要将其馈送到多层感知机中获得隐式时间嵌入。
time_embed_dim = model_channels * 4self.time_embed = nn.Sequential(linear(model_channels, time_embed_dim),nn.SiLU(),linear(time_embed_dim, time_embed_dim),)emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
多层感知机由两个具有倒置瓶颈结构的线性层和一个SiLU激活函数组成。其中dim = channels = 8,time_embed_dim = model_channels * 4 = 32。需要注意是 N 为 batch 大小。
- linear(model_channels, time_embed_dim)中,输入形状为 [4 x 8] ,[N x dim] 即 [N x model_channels] ,输出形状为 [4 x 32] , [N x time_embed_dim]。
- 激活函数不会改变形状。
- linear(time_embed_dim, time_embed_dim)中,输入形状为 [4 x 32] ,[N x time_embed_dim],输出形状为 [4 x 32],[N x time_embed_dim]。
多层感知机主要作用是将输入的时间步正弦编码向量变得更宽。通过激活函数引入非线性。
通过多层感知机中的线性层将正弦时间编码向量的宽调整成当前block所需输出图像的通道数宽,与通道数转换成一样大小是为了在UNet的每一个block中,将正弦时间编码向量加到图像的每个像素点上。