Diffusion模型中时间t嵌入的方法
class PositionalEmbedding(nn.Module):def __init__(self, dim, scale=1.0):super().__init__()assert dim % 2 == 0self.dim = dimself.scale = scaledef forward(self, x):device = x.devicehalf_dim = self.dim // 2emb = math.log(10000) / half_dimemb = torch.exp(torch.arange(half_dim, device=device) * -emb)# x * self.scale和emb外积emb = torch.outer(x * self.scale, emb)emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return emb
我们用 dim=128
和 x=[10, 12, 16, 100]
来具体计算 PositionalEmbedding
的输出。
1. 设定参数
dim=128
,意味着嵌入向量的维度是 128。half_dim = dim // 2 = 64
,所以我们需要计算 64 个频率因子的正弦和余弦值。x = [10, 12, 16, 100]
是输入值。
2. 计算频率因子
emb = math.log(10000) / half_dim # 计算缩放因子
emb = torch.exp(torch.arange(half_dim) * -emb) # 生成 64 维的指数频率因子
math.log(10000) ≈ 9.2103
emb = torch.exp(torch.arange(64) * (-9.2103 / 64))
torch.arange(64)
生成[0, 1, 2, ..., 63]
,然后乘以-emb
,再计算指数exp
,得到 64 个递减的频率因子。
3. 计算外积
emb = torch.outer(x * self.scale, emb)
- 计算
x * self.scale
,如果scale=1.0
,那么x
仍然是[10, 12, 16, 100]
。 emb
是一个4 × 64
的矩阵,每一行表示x[i]
乘以emb
里的每个频率因子。
假设 emb
(频率因子)前 5 个数是:
[1.0000, 0.9120, 0.8318, 0.7586, 0.6918, ...]
那么 x=10
这一行计算结果是:
[10.0000, 9.1200, 8.3180, 7.5860, 6.9180, ...]
4. 计算正弦和余弦
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
- 先对
emb
取sin
,然后取cos
,最后拼接,得到4 × 128
的矩阵。
如果 sin(10.0000) ≈ -0.5440
,cos(10.0000) ≈ -0.8391
,那么 x=10
这一行最终变成:
[-0.5440, 0.4120, 0.9890, 0.9870, -0.9912, ... | -0.8391, 0.9111, -0.1479, 0.1603, -0.1321, ...]
其中,前 64 维是 sin
计算结果,后 64 维是 cos
计算结果。
5. 最终输出
如果 x = [10, 12, 16, 100]
,输出 emb
是 4 × 128
的矩阵:
tensor([[-0.5440, 0.4120, 0.9890, 0.9870, -0.9912, ..., -0.8391, 0.9111, -0.1479, 0.1603, -0.1321, ...],[-0.5366, 0.4576, 0.9941, 0.9891, -0.9954, ..., -0.8437, 0.9005, -0.1085, 0.1521, -0.1242, ...],[-0.5215, 0.5321, 0.9971, 0.9922, -0.9986, ..., -0.8524, 0.8804, -0.0563, 0.1423, -0.1113, ...],[-0.5064, 0.8658, -0.9813, 0.9989, 0.9924, ..., -0.8849, 0.7912, 0.1951, 0.0234, -0.9811, ...],
])
- 每一行对应输入
x
的一个数的 128 维位置编码。 - 其中前 64 维是
sin(x * 频率)
,后 64 维是cos(x * 频率)
。 x=100
时,周期性更明显,因为sin
和cos
是周期函数,大的x
会导致编码的模式周期性更强。
6. 总结
- 这个位置编码会为
x
生成一个 128 维的向量,每个维度都由sin
和cos
计算得到。 x
变大时,周期性更明显。- 适用于 Transformer 或其他模型,以在输入数据中添加位置信息,使模型能够区分不同位置的输入数据。