89. 注意力机制以及代码实现Nadaraya-Waston 核回归

news/2024/11/9 2:47:37/

1. 心理学

  • 动物需要在复杂环境下有效关注值得注意的点
  • 心理学框架:人类根据随意线索和不随意线索选择注意点

随意:随着自己的意识,有点强调主观能动性的意味。

在这里插入图片描述

2. 注意力机制

在这里插入图片描述

2. 非参注意力池化层

在这里插入图片描述

3. Nadaraya-Waston 核回归

在这里插入图片描述

4. 参数化的注意力机制

在这里插入图片描述

5. 总结

在这里插入图片描述

6. 代码实现注意力汇聚:Nadaraya-Waston 核回归

import torch
from torch import nn
from d2l import torch as d2l

6.1 生成数据集

在这里插入图片描述

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本
def f(x):return 2 * torch.sin(x) + x**0.8y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test

下面的函数将绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数 𝑓 (标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)。

def plot_kernel_reg(y_hat):d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

6.2 平均汇聚

先使用最简单的估计器来解决回归问题。 基于平均汇聚来计算所有训练样本输出值的平均值:

在这里插入图片描述

如下图所示,这个估计器确实不够聪明。 真实函数 𝑓 (“Truth”)和预测函数(“Pred”)相差很大。

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

在这里插入图片描述

6.3 非参数注意力汇聚

接下来,我们将基于这个非参数的注意力汇聚模型来绘制预测结果。 从绘制的结果会发现新的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

运行结果:

在这里插入图片描述

现在来观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

运行结果:

在这里插入图片描述

6.4 带参数注意力汇聚

1. 批量矩阵乘法

因此,假定两个张量的形状分别是 (𝑛,𝑎,𝑏) 和 (𝑛,𝑏,𝑐) , 它们的批量矩阵乘法输出的形状为 (𝑛,𝑎,𝑐)

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape

运行结果:

在这里插入图片描述
在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))

运行结果:

在这里插入图片描述

2. 定义模型

基于带参数的注意力汇聚,使用小批量矩阵乘法, 定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)self.w = nn.Parameter(torch.rand((1,), requires_grad=True))def forward(self, queries, keys, values):# queries和attention_weights的形状为(查询个数,“键-值”对个数)queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w)**2 / 2, dim=1)# values的形状为(查询个数,“键-值”对个数)return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)

3. 训练

接下来,将训练数据集变换为键和值用于训练注意力模型。 在带参数的注意力汇聚模型中, 任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算, 从而得到其对应的预测输出。

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降。

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])for epoch in range(5):trainer.zero_grad()l = loss(net(x_train, keys, values), y_train)l.sum().backward()trainer.step()print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')animator.add(epoch + 1, float(l.sum()))

运行结果:

在这里插入图片描述
如下所示,训练完带参数的注意力汇聚模型后可以发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

运行结果:

在这里插入图片描述

为什么新的模型更不平滑了呢? 下面看一下输出结果的绘制图: 与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

运行结果:

在这里插入图片描述


http://www.ppmy.cn/news/18435.html

相关文章

OpenCV级联分类器

OpenCV级联分类器 概览 OpenCV: 一个计算机视觉库, 提供了一种称级联分类器的方法检测对象级联分类器:一种基于AdaBoost算法的多级分类器, 用于在图像中检测目标对象. 它通过不断学习组合多个特征来识别目标对象. 每一级中, 级联分类器先检测出可能是目标对象的部分, 然后再这…

联合证券|港股再融资“春江水暖” 资本争购热门赛道企业

进入2023年,港股再融资商场有所回暖。到1月18日,已有27家港股上市公司发布拟配售股份(简称“配股”)再融资,募资总额164.01亿港元,较上一年同期增加148.16%。其间,微盟集团的配股再融资吸引了众…

python+selenium爬虫自动化批量下载文件

一、项目需求 在一个业务网站有可以一个个打开有相关内容的文本,需要逐个保存为TXT,数据量是以千为单位,人工操作会麻木到崩溃。 二、解决方案 目前的基础办法就是使用pythonselenium自动化来代替人工去操作,虽然效率比其他爬虫…

Java基础练习题(四)

13.求两点之间的距离 题目描述 给定A(x1, y1), B(x2, y2)两点坐标,计算它们间的距离。 输入 输入包含四个实数x1, y1, x2, y2,分别用空格隔开,含义如描述。其中0≤x1,x2,y1,y2≤100。 输出 输出占一行,包含一个实数d,表示A, B两点间的距离。结果保留两位小数。 样例输入

经典同步问题

同步问题是一个复杂的问题,但是它也有自己的方法去处理、去分析。PV操作系统的解题思路:关系分析。找出题目中描述的各个进程,分析它们之间的同步、互斥关系。(从事件的角度分析)整理思路。根据各进程的操作流程确定P、V操作的大致顺序。设置…

Cert Manager 申请SSL证书流程及相关概念-二

中英文对照表 英文英文 - K8S CRD中文备注certificatesCertificate证书certificates.cert-manager.io/v1certificate issuersIssuer证书颁发者issuers.cert-manager.ioClusterIssuer集群证书颁发者clusterissuers.cert-manager.iocertificate requestCertificateRequest证书申…

python—-下载Iwara视频

1.提示: 使用需要安装bs4库,selenium库,fake_useragent库,版本没什么要求 同时需要安装相同版本的Chrome浏览器和驱动器,注意驱动器和浏览器不一样哦 哦对了,还要自备梯子(不过某喵天天在Iwara打…

Linux网络编程套接字

文章目录一、预备知识1. IP 地址2.端口号3. TCP 协议和 UDP 协议4.网络字节序二、socket 编程接口0. socket 常见 API1. socket 系统调用2. bind 系统调用3. recvfrom 系统调用4. sendto 系统调用5. listen 系统调用6. accept 系统调用7. connect 系统调用三、简单的 UDP 网络程…