Diffusion模型中时间t嵌入的方法

ops/2025/3/6 10:04:56/

Diffusion模型中时间t嵌入的方法

class PositionalEmbedding(nn.Module):def __init__(self, dim, scale=1.0):super().__init__()assert dim % 2 == 0self.dim = dimself.scale = scaledef forward(self, x):device = x.devicehalf_dim = self.dim // 2emb = math.log(10000) / half_dimemb = torch.exp(torch.arange(half_dim, device=device) * -emb)# x * self.scale和emb外积emb = torch.outer(x * self.scale, emb)emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return emb

我们用 dim=128x=[10, 12, 16, 100] 来具体计算 PositionalEmbedding 的输出。


1. 设定参数

  • dim=128,意味着嵌入向量的维度是 128。
  • half_dim = dim // 2 = 64,所以我们需要计算 64 个频率因子的正弦和余弦值。
  • x = [10, 12, 16, 100] 是输入值。

2. 计算频率因子

emb = math.log(10000) / half_dim  # 计算缩放因子
emb = torch.exp(torch.arange(half_dim) * -emb)  # 生成 64 维的指数频率因子
  • math.log(10000) ≈ 9.2103
  • emb = torch.exp(torch.arange(64) * (-9.2103 / 64))
  • torch.arange(64) 生成 [0, 1, 2, ..., 63],然后乘以 -emb,再计算指数 exp,得到 64 个递减的频率因子。

3. 计算外积

emb = torch.outer(x * self.scale, emb)
  • 计算 x * self.scale,如果 scale=1.0,那么 x 仍然是 [10, 12, 16, 100]
  • emb 是一个 4 × 64 的矩阵,每一行表示 x[i] 乘以 emb 里的每个频率因子。

假设 emb(频率因子)前 5 个数是:

[1.0000, 0.9120, 0.8318, 0.7586, 0.6918, ...]

那么 x=10 这一行计算结果是:

[10.0000, 9.1200, 8.3180, 7.5860, 6.9180, ...]

4. 计算正弦和余弦

emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
  • 先对 embsin,然后取 cos,最后拼接,得到 4 × 128 的矩阵。

如果 sin(10.0000) ≈ -0.5440cos(10.0000) ≈ -0.8391,那么 x=10 这一行最终变成:

[-0.5440, 0.4120, 0.9890, 0.9870, -0.9912, ... | -0.8391, 0.9111, -0.1479, 0.1603, -0.1321, ...]

其中,前 64 维是 sin 计算结果,后 64 维是 cos 计算结果。


5. 最终输出

如果 x = [10, 12, 16, 100],输出 emb4 × 128 的矩阵:

tensor([[-0.5440,  0.4120,  0.9890,  0.9870, -0.9912, ..., -0.8391,  0.9111, -0.1479,  0.1603, -0.1321, ...],[-0.5366,  0.4576,  0.9941,  0.9891, -0.9954, ..., -0.8437,  0.9005, -0.1085,  0.1521, -0.1242, ...],[-0.5215,  0.5321,  0.9971,  0.9922, -0.9986, ..., -0.8524,  0.8804, -0.0563,  0.1423, -0.1113, ...],[-0.5064,  0.8658, -0.9813,  0.9989,  0.9924, ..., -0.8849,  0.7912,  0.1951,  0.0234, -0.9811, ...],
])
  • 每一行对应输入 x 的一个数的 128 维位置编码。
  • 其中前 64 维是 sin(x * 频率),后 64 维是 cos(x * 频率)
  • x=100 时,周期性更明显,因为 sincos 是周期函数,大的 x 会导致编码的模式周期性更强。

6. 总结

  • 这个位置编码会为 x 生成一个 128 维的向量,每个维度都由 sincos 计算得到。
  • x 变大时,周期性更明显。
  • 适用于 Transformer 或其他模型,以在输入数据中添加位置信息,使模型能够区分不同位置的输入数据。

http://www.ppmy.cn/ops/163571.html

相关文章

FastGPT 源码:混合检索调用链路

文章目录 FastGPT 源码:混合检索调用链路1. 入口函数2. 核心搜索函数3. RRF合并函数4. Rerank重排序函数5. 完整流程 FastGPT 源码:混合检索调用链路 主要调用链路如下: 1. 入口函数 在 dispatchDatasetSearch(packages/servic…

Mac OS升级后变慢了,如何恢复老系统?

我的一台Mac Air闲置很久了,原因是某次系统升级后用着会卡,有差不多10年没用了。今天想试着恢复一下出厂系统,目前看这条路可以走通。记录如下: 1、去哪里下载旧版系统? https://support.apple.com/zh-cn/102662 2、…

游戏引擎学习第136天

回顾 今天,我们的工作重点是继续探索之前搭建的资产系统,目的是最终定义我们的资产包文件格式。通过这个工作,我们希望能够创建一个符合我们要求的资产包文件。这样,我可以在直播之外的时间完成它,并为我们提供一个符…

JAVA入门——网络编程简介

自己学习时的笔记,可能有点水( 以后可能还会补充(大概率不会) 一、基本概念 网络编程三要素: IP 设备在网络中的唯一标识 端口号 应用软件在设备中的唯一标识两个字节表示的整数,0~1023用于知名的网络…

基于 Kubernetes 搭建 DevOps 持续集成环境

环境准备 在部署 Kubernetes(K8s)以及相关 DevOps 工具(如 Jenkins、Kuboard、Harbor)时,我们需要确保服务器和软件环境符合要求。 服务器及软件环境 服务器配置:2 核 4G 及以上(推荐至少 2 …

【Vue3】实现一个超过高度后可控制显示隐藏的组件

组件效果图 未达到最大高度 达到设置的最大高度 进行展开 实现代码 组件代码 备注&#xff1a;通过tailwindcss设置的样式&#xff0c;通过element-plus/icons-vue设置的图标&#xff0c;可根据情况进行替换 <template><!-- 限制高度组件 --><div ref"…

smart代理原生住宅IP是如何避免跨境电商店铺被DDoS攻击的?

随着跨境电商的迅猛发展&#xff0c;越来越多的商家开始把业务拓展到国际市场&#xff0c;然而&#xff0c;随之而来的是网络安全问题的威胁&#xff0c;其中最常见的是DDoS攻击。 这种攻击方式会让商家的网站或应用程序停止运作&#xff0c;从而影响销售和声誉。 幸运的是&…

CES Asia 2025增设未来办公教育板块,科技变革再掀高潮

作为亚洲消费电子领域一年一度的行业盛会&#xff0c;CES Asia 2025&#xff08;第七届亚洲消费电子技术贸易展&#xff09;即将盛大启幕。今年展会规模再度升级&#xff0c;预计将吸引超过500家全球展商参展&#xff0c;专业观众人数有望突破10万。除了聚焦人工智能、物联网、…