nn.MultiheadAttention返回的注意力权重和标准的计算注意力权重的区别

devtools/2024/11/26 12:05:10/

#本篇只是随笔小记

试了一下,拿正弦位置编码直接做自注意力,如果手算注意力权重,应该是这样的:

attention_scores = torch.matmul(pos_x, pos_x.T)/512**0.5  # (seq_len, seq_len)# 应用 softmax 来获得注意力权重
attention_weights = torch.softmax(attention_scores, dim=-1)

可视化的效果是这样的:

如果用nn.MultiheadAttention返回注意力权重矩阵,应该是这样的:

q = torch.tensor(pos_x)
q = q.reshape(100,512)
q = q[:,None,:]
print(q.shape)
self_attn = nn.MultiheadAttention(num_pos_feats, num_heads=2)
# 手动设置权重和偏置为单位矩阵和零
#identity = torch.eye(512)
#self_attn.in_proj_weight.data = torch.cat([identity, identity, identity])
#self_attn.in_proj_bias.data = torch.zeros(3 * 512)
temp = self_attn(q, q, q)[1]
print(temp)
temp = temp.squeeze().detach().numpy()
#temp = temp.detach().numpy()
print(temp.shape)

可视化结果是这样的

显然不对,问题出在哪?

原因是nn.MultiheadAttention中对QKV有各自的投射层改变了向量,现在将投射向量全改成单位矩阵,偏置全置为0,代码如下

q = torch.tensor(pos_x)
q = q.reshape(100,512)
q = q[:,None,:]
print(q.shape)
self_attn = nn.MultiheadAttention(num_pos_feats, num_heads=2)
# 手动设置权重和偏置为单位矩阵和零
identity = torch.eye(512)
self_attn.in_proj_weight.data = torch.cat([identity, identity, identity])
self_attn.in_proj_bias.data = torch.zeros(3 * 512)
temp = self_attn(q, q, q)[1]
print(temp)
temp = temp.squeeze().detach().numpy()
#temp = temp.detach().numpy()
print(temp.shape)

可视化如下

ok,即为所得,总体代码如下:

import torch
import numpy as np
import mathfrom torch import nn
import matplotlib.pyplot as plttorch.manual_seed(0)
np.random.seed(0)# 正余弦位置编码
num_pos_feats = 512
temperature = 10000
normalize = False
scale = 2 * math.pi#圆周率a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
a = torch.tensor(a)
#mask = [[False, False, False, True], [False, False, False, True], [False, False, False, True], [True, True, True, True]]
#mask = [[False, False, False, False], [False, False, False, False], [False, False, False, False], [False, False, False, False]]
mask = [[False for _ in range(100)] for _ in range(100)]
mask = torch.tensor(mask)
#print(mask)
assert mask is not None
not_mask = ~mask
#print(not_mask)
y_embed = not_mask.cumsum(0, dtype=torch.float32)
x_embed = not_mask.cumsum(1, dtype=torch.float32)
#print(y_embed)
print(x_embed)if normalize:eps = 1e-6# b = a[i:j:s]表示:i,j与上面的一样,但s表示步进,缺省为1.# 所以a[i:j:1]相当于a[i:j]# 当s<0时,i缺省时,默认为-1. j缺省时,默认为-len(a)-1# 所以a[::-1]相当于 a[-1:-len(a)-1:-1],也就是从最后一个元素到第一个元素复制一遍,即倒序。# 对于X[:,:,m:n]是取三维矩阵中第m维到第n-1维的所有数据# 归一化y_embed = y_embed / (y_embed[-1:, :] + eps) * scale  # y_embed[:, -1:, :]代表取三维数据中的最后一行数据x_embed = x_embed / (x_embed[:, -1:] + eps) * scale  # x_embed[:, :, -1:]代表取三维数据中的最后一列数据#print(y_embed)print(x_embed)
dim_t1 = torch.arange(num_pos_feats, dtype=torch.float32, device=a.device)
#print(dim_t1)
dim_t = temperature ** (2 * (dim_t1 // 2) / num_pos_feats)  # i=dim_t1 // 2
#print(dim_t)
pos_x = x_embed[:, :, None] / dim_t
pos_x = pos_x[0]
pos_y = y_embed[:, :, None] / dim_t
print(pos_x)
#print(pos_y)
pos_x = torch.stack((pos_x[ :, 0::2].sin(), pos_x[ :, 1::2].cos()), dim=2).flatten(1)  # 不降维
pos_y = torch.stack((pos_y[ :, 0::2].sin(), pos_y[ :, 1::2].cos()), dim=2).flatten(1)  # 不降维
print(pos_x)
#print(pos_y)
#pos = torch.cat((pos_y, pos_x), dim=2)
#print(pos)
#print(pos.shape)attention_scores = torch.matmul(pos_x, pos_x.T)/512**0.5  # (seq_len, seq_len)# 应用 softmax 来获得注意力权重
attention_weights = torch.softmax(attention_scores, dim=-1)# 可视化注意力权重
plt.figure(figsize=(10, 8))
plt.imshow(attention_weights.detach().numpy(), cmap='viridis', aspect='auto')
plt.colorbar(label='Attention Weight')
plt.title('Attention Weights from Sinusoidal Position Encoding')
plt.xlabel('Position')
plt.ylabel('Position')
plt.show()q = torch.tensor(pos_x)
q = q.reshape(100,512)
q = q[:,None,:]
print(q.shape)
self_attn = nn.MultiheadAttention(num_pos_feats, num_heads=2)
# 手动设置权重和偏置为单位矩阵和零
identity = torch.eye(512)
self_attn.in_proj_weight.data = torch.cat([identity, identity, identity])
self_attn.in_proj_bias.data = torch.zeros(3 * 512)
temp = self_attn(q, q, q)[1]
print(temp)
temp = temp.squeeze().detach().numpy()
#temp = temp.detach().numpy()
print(temp.shape)# 绘制热图
plt.figure(figsize=(8, 6))plt.imshow(temp, cmap='viridis', aspect='auto')
plt.colorbar(label='Value')  # 添加颜色条
plt.title('Heatmap of Tensor Values')
plt.xlabel('Column Index')
plt.ylabel('Row Index')# 显示图形
plt.show()


http://www.ppmy.cn/devtools/137111.html

相关文章

RabbitMQ的交换机总结

1.direct交换机 2.fanout交换机

数据库的联合查询

数据库的联合查询 简介为什么要使⽤联合查询多表联合查询时MYSQL内部是如何进⾏计算的构造练习案例数据案例&#xff1a;⼀个完整的联合查询的过程 内连接语法⽰例 外连接语法 ⽰例⾃连接应⽤场景示例表连接练习 ⼦查询语法单⾏⼦查询多⾏⼦查询多列⼦查询在from⼦句中使⽤⼦查…

更高效的Java 23开发,IntelliJ IDEA助力全面升级

IntelliJ IDEA&#xff0c;是java编程语言开发的集成环境。IntelliJ在业界被公认为最好的java开发工具&#xff0c;尤其在智能代码助手、代码自动提示、重构、JavaEE支持、各类版本工具(git、svn等)、JUnit、CVS整合、代码分析、 创新的GUI设计等方面的功能可以说是超常的。 随…

uiautomator案例

test下新建类 public class ButtonClickTest {private UiDevice device;Beforepublic void setUp() {// 初始化 UiDevice 实例device UiDevice.getInstance(InstrumentationRegistry.getInstrumentation());try {device.executeShellCommand("am start -n com.yy.test/.…

C++ 优先算法 —— 长度最小的子数组(滑动窗口)

目录 题目&#xff1a;长度最小的子数组 1. 题目解析 2. 算法原理 Ⅰ. 暴力枚举 Ⅱ. 滑动窗口&#xff08;同向双指针&#xff09; 滑动窗口正确性 3. 代码实现 Ⅰ. 暴力枚举(会超时&#xff09; Ⅱ. 滑动窗口&#xff08;同向双指针&#xff09; 题目&#xff1a;长…

【GPT】睡觉时,大脑在做什么

睡觉时&#xff0c;大脑并不是完全“关闭”的&#xff0c;而是处于高度活跃的状态&#xff0c;进行许多重要的功能。以下是大脑在不同睡眠阶段的主要活动&#xff1a; 1. 修复与恢复 神经元修复&#xff1a;大脑细胞会修复白天受到的损伤&#xff0c;同时清除代谢废物&#xf…

【ArcGISPro】使用AI提取要素-土地分类(sentinel2)

Sentinel2数据处理 【ArcGISPro】Sentinel-2数据处理-CSDN博客 土地覆盖类型分类 处理结果

51单片机快速入门之中断的应用 2024/11/23 串口中断

51单片机快速入门之中断的应用 基本函数: void T0(void) interrupt 1 using 1 { 这里放入中断后需要做的操作 } void T0(void)&#xff1a; 这是一个函数声明&#xff0c;表明函数 T0 不接受任何参数&#xff0c;并且不返回任何值。 interrupt 1&#xff1a; 这是关键字和参…