稀疏注意力:时间序列预测的局部性和Transformer的存储瓶颈

server/2024/12/22 19:32:56/

        时间序列预测是许多领域的重要问题,包括对太阳能发电厂发电量、电力消耗和交通拥堵情况的预测。在本文中,提出用Transformer来解决这类预测问题。虽然在我们的初步研究中对其性能印象深刻,但发现了它的两个主要缺点:(1)位置不可知性:规范Transformer架构中的点积自关注对局部上下文不敏感,这可能使模型在时间序列中容易出现异常;(2)内存瓶颈:正则Transformer的空间复杂度随序列长度L呈二次增长,使得直接建模长时间序列变得不可行。

        为了解决这两个问题,首先提出了卷积自注意力,通过使用因果卷积产生查询和键,以便更好地将本地上下文纳入注意机制。然后,提出了仅O(L(log L)^2)内存开销的LogSparse Transformer,在内存预算受限的情况下,提高了对细粒度、长期依赖性强的时间序列的预测精度。在合成数据和真实世界数据集上的实验表明,它比最先进的技术更有优势。        

1. 引言

        深度神经网络被提出作为另一种解决方案,其中递归神经网络(RNN)已被用于以自回归的方式对时间序列建模。然而,众所周知,RNN很难训练。由于梯度消失和爆炸问题。尽管出现了各种变体,包括LSTM和GRU(门控循环单元),问题仍然没有解决。如何对长期依赖关系进行建模成为实现良好性能的关键步骤。

        规范Transformer的空间复杂度随输入长度L呈二次增长,这对直接建模细粒度长时间序列造成了内存瓶颈。 主要贡献:①成功地将Transformer架构应用于时间序列预测,并在合成数据集和真实数据集上进行了广泛的实验,以验证Transformer在处理长期依赖关系方面比基于RNN的模型更好的潜在价值。

        ②提出卷积自注意,通过使用因果卷积在自注意层产生查询和键。感知局部上下文(如形状)的查询键匹配可以帮助模型降低训练损失,进一步提高预测精度。

        ③提出LogSparse Transformer,只有O(L(log L)^2)空间复杂度来打破内存瓶颈,不仅使细粒度的长时间序列建模可行,而且与规范Transformer相比,使用更少的内存可以产生相当甚至更好的结果。

2. 相关工作

        时间序列预测领域中不同方法的发展和挑战,并强调了几种主要的模型。首先,文章提到ARIMA模型,这是时间序列预测中非常著名的一种方法。ARIMA模型因其统计性质和Box-Jenkins方法论而备受推崇,后者是一种在模型选择过程中广泛使用的方法。因此,ARIMA模型通常是实践者在时间序列预测中首先尝试的工具。然而,ARIMA模型有一些局限性。它假设时间序列是线性的,这在处理更复杂的、非线性的时间序列时可能表现不佳。此外,ARIMA模型的扩展性有限,难以应用于大规模的预测任务,并且每个时间序列都必须独立拟合,这意味着无法在相似的时间序列之间共享信息。

        相反,有些方法尝试通过矩阵分解的方法处理相关时间序列数据,把预测问题看作矩阵分解问题。另外,还有研究提出了分层贝叶斯方法,从图模型的角度来学习多个相关的计数时间序列。接着,文章介绍了深度神经网络在时间序列预测中的应用。这些模型可以捕捉相关时间序列之间的共享信息,从而提高预测的准确性。比如,有研究将传统的自回归(AR)模型与递归神经网络(RNN)结合起来,采用编码器-解码器的方式对概率分布进行建模。另一种方法使用RNN作为编码器,使用多层感知机(MLP)作为解码器,以解决误差累积问题,并且能够进行多步并行预测。此外,还有模型使用全局RNN来直接输出线性状态空间模型(SSM)的参数,目的是用局部线性片段来近似非线性动态。也有研究通过使用局部高斯过程来处理每个时间序列中的噪声,同时使用全局RNN来建模共享模式。另一些方法则试图结合AR模型和SSM的优势,保持复杂的潜在过程以进行多步并行预测。

        Transformer在序列建模中取得了很大成功,并且已被广泛应用于翻译、语音、音乐和图像生成等领域。然而,当处理极长的序列时,注意力机制的计算复杂度会随着序列长度的增加而呈二次方增长,这在处理高粒度且具有强长期依赖性的时间序列时,成为了一个严重的问题。

3. 背景

3.1 问题定义

        首先定义了一个时间序列预测的问题。在这个问题中,我们有一个包含N个相关单变量时间序列的集合,每个时间序列记为 z_{i,1:t_0},表示从时间1到时间 t_0​ 的观测值。目标是预测这些时间序列未来的 \tau 个时间步的值,即 z_{i,t_0+1:t_0+\tau}​。此外,假设有一个与时间相关的协变量集合 x_{i,1:t_0+\tau},其维度为d,这些协变量可能包括诸如星期几、一天中的小时等已知信息。我们需要建模条件分布 p(z_{i,t_0+1:t_0+\tau}|z_{i,1:t_0},x_{i,1:t_0+\tau};\omega),其中 \omega 是所有时间序列共享的可学习参数。

        接着,问题被简化为学习一个一步预测模型,即 p(z_t | z_{1:t-1}, x_{1:t}; \omega),其中 \omega 表示模型的可学习参数。为了充分利用观测值和协变量,作者将它们连接起来,形成一个扩展矩阵 y_t = [z_{t-1}, x_t],然后通过 Y_t = [y_1, \cdots, y_t]^T 表示所有的观测数据和协变量的集合。接下来,研究探索了一个合适的模型 z_t \sim f(Y_t),用于预测给定 Y_t​ 时 z_t​ 的分布。

        然后,文章介绍了Transformer模型,并提出将其作为函数 f 的实例,因为Transformer通过多头自注意力机制能够捕捉到时间序列中的长短期依赖性。不同的注意力头可以专注于不同的时间模式,这使得Transformer在时间序列预测中成为一个很有潜力的候选模型。

        在自注意力层中,多头自注意力子层同时将 Y 转换为H个不同的查询矩阵 Q_h、键矩阵 K_h​ 和值矩阵 V_h​,其中 h = 1, \cdots, H。这些矩阵通过线性投影获得,它们的学习参数分别为 W_Q^h​、 W_K^h​ 和 W_V^h​。在这些线性投影之后,缩放点积注意力机制计算出一系列的向量输出 O_h​,这些输出是通过公式 O_h=\mathrm{Attention}(Q_h,K_h,V_h)=\mathrm{softmax}\left(\frac{Q_hK_h^T}{\sqrt{d_k}}\cdot M\right)V_h 计算得到的。这里,掩码矩阵 MMM 被应用于过滤右侧的注意力,以避免未来信息泄露。然后,所有的 O_h 被连接起来并再次进行线性投影。最后,在注意力输出上叠加了一个由两层全连接网络和中间ReLU激活层组成的位置前馈子层。

4. 方法论

4.1 增强Transformer的局部性

        时间序列中的模式可能由于各种事件(如假期和极端天气)随时间发生显著变化的现象。因此,判断一个观测点是异常点、变更点还是模式的一部分,很大程度上依赖于其周围的上下文。然而,在经典Transformer的自注意力层中,查询和键之间的相似性是基于它们逐点值来计算的,未能充分利用局部上下文信息(如形状)。这种对局部上下文不敏感的查询-键匹配可能会导致自注意力模块混淆观测值的性质,从而引发潜在的优化问题。

        为了解决这个问题,提出了卷积自注意力机制。图1(c)和(d)展示了这种卷积自注意力的架构。不同于使用核大小为1且步幅为1的卷积(即矩阵乘法),采用核大小为k且步幅为1的因果卷积,将输入(经过适当的填充)转换为查询和键。因果卷积确保当前位置不会访问未来信息。通过使用因果卷积,生成的查询和键能够更加感知局部上下文,从而基于局部上下文信息(如局部形状)来计算相似性,而不是简单的逐点取值,这有助于提高预测的准确性。值得注意的是,当k=1时,卷积自注意力将退化为经典自注意力,因此它可以看作经典自注意力的一种广义形式。

        因果卷积:具体来说,假设输入序列的长度为 T,卷积核的大小为 k。在因果卷积中,输入会在前面添加 k-1个零填充,这样卷积运算就只会考虑当前和之前的时间步,而不会涉及未来的时间步。这种方式保证了模型在训练和推理时遵循时间顺序,从而保持因果性。 

        图1展示了经典自注意力层和卷积自注意力层的比较。图1(a)显示了经典自注意力可能错误地逐点匹配输入的情况,图1(b)则展示了经典自注意力在Transformer中的应用。而图1(c)和(d)则展示了卷积自注意力如何通过形状匹配来正确匹配最相关的特征。

4.2 突破Transformer的内存瓶颈

        首先对经典Transformer在traffic-f数据集上学习到的注意力模式进行了定性评估。traffic-f数据集包含旧金山湾区963条车道的占用率数据,每20分钟记录一次。在traffic-f数据集上训练了一个10层的经典Transformer,并对学习到的注意力模式进行了可视化。引入某种形式的稀疏性而不会显著影响性能。更重要的是,对于长度为 L 的序列,计算每对单元之间的注意力分数会导致 O(L^2) 的内存使用量,使得对具有精细粒度和强长期依赖的长时间序列进行建模变得非常困难。

        为了解决这个问题,提出了LogSparse Transformer,这种方法只需要计算每个单元在每层中的 O(\log L) 个点积。此外,只需要堆叠最多O(\log L)层,模型就能够访问每个单元的信息。因此,总的内存使用成本仅为 O(L(\log L)^2)。我们将 I_k^l​ 定义为在第 k 层到第 k+1 层计算过程中,单元 l 表示可以访问的单元索引集。在标准的Transformer自注意力中, I_k^l = \{j : j \leq l\},这意味着每个单元都可以访问其所有过去的单元及其自身,如图3(a)所示。

        然而,这种算法在输入长度增加时会导致空间复杂度的二次增长。为了解决这个问题,提出选择 I_k^l 的一个子集 I_k^l \subseteq \{j : j \leq l\},使得 \|I_k^l\| 不会随着 l 的增加而增长得太快。选择索引的一个有效方法是 |I_k^l| \propto \log L

        需要注意的是,单元 l 是在第 k 层中通过加权组合索引为 I_k^l​ 的单元生成的,并且可以将这些信息传递给下一层的后续单元。令 S_k^l 为包含所有到第 k 层为止传递给单元 l 的单元索引的集合。为了确保每个单元接收到所有之前的单元及其自身的信息,堆叠的层数 \tilde{k}_l 应满足 S_{\tilde{k}_l}^l = \{j : j \leq l\},即对于每个 lj \leq l,存在一个具有 \tilde{k}_l 条边的路径 P_{jl} = (j, p_1, p_2, \dots, l),其中 j \in I_1^{p_1}, p_1 \in I_2^{p_2}, \dots, p_{\tilde{k}_l-1} \in I_{\tilde{k}_l}^l

        通过允许每个单元仅以指数步长访问其之前的单元及其自身来提出LogSparse自注意力。即对于所有的 k 和 lI_k^l = \{l \% 2^{\lfloor \log_2 l \rfloor}, l \% 2^{\lfloor \log_2 l \rfloor - 1}, \dots, l \% 2^0, l\},其中 \lfloor \cdot \rfloor 表示向下取整运算,如图3(b)所示。

        定理1表明,尽管每层的内存使用量从 O(L^2) 减少到 O(L\log^2 L),但信息仍然可以从任意单元流向另一个单元,只需稍微“加深”模型——将层数设为 \lfloor \log_2 L \rfloor+1。这意味着总体内存使用量为 O(L(\log^2 L)^2),解决了Transformer在GPU内存限制下的扩展性瓶颈。此外,随着两个单元之间的距离增大,路径的数量会以 log_2(l - j) 的超指数速率增加,这表明LogSparse Transformer在建模精细的长期依赖关系时能够实现丰富的信息流动。

4.3 Logparase注意力

        LogSparse注意力是一种针对长序列时间复杂度和内存使用量进行优化的自注意力机制。它是对经典Transformer模型中自注意力机制的一种改进,旨在解决处理长序列时计算资源消耗过大的问题。

        在传统的Transformer中,自注意力机制的计算复杂度是二次方的,即对于长度为 L 的序列,每个元素需要与其他 L-1个元素进行相互计算,导致整个序列的计算量和内存使用量为 O(L^2)。这种复杂度在处理长序列时会变得非常昂贵,尤其是在需要处理大量数据的情况下。

LogSparse注意力 通过以下方式优化了这一过程:

  1. 选择性注意力(Selective Attention):LogSparse注意力并不计算每个序列元素与所有其他元素之间的注意力得分,而是引入了一种稀疏化策略。具体来说,它允许每个元素只关注一小部分与其相关的元素,而不是所有元素。这些相关元素的选择是基于指数步长的,即每个元素只与其之前的少量元素进行注意力计算,这些元素之间的距离按对数规律增长。这意味着,如果当前元素是第 l 个,那么它只会与之前的一些元素进行计算,而这些元素的索引为 l,l-2^1, l-2^2,\dots,l-2^{\left \lfloor log_2l \right \rfloor} 等。

  2. 对数级别的复杂度(Logarithmic Complexity):这种选择性注意力策略将原本的 O(L^2) 复杂度降低到了 O(L \log^2 L)。因为每个元素只需计算 O(\log L) 个注意力分数,而整个序列需要堆叠 O(\log L) 层,以确保所有元素都能互相通信。

        通过这种方法,LogSparse注意力在处理长序列时能够显著减少内存使用和计算时间,同时保留Transformer模型的强大建模能力,特别是对于长时间依赖关系的建模非常有效。

4.3.1 增强模型性能 

如何在LogSparse自注意力机制的基础上进一步增强模型的性能,同时保持计算复杂度的控制?

  1. 局部注意力(Local Attention):在LogSparse Transformer中,虽然每个单元只需要访问先前的一些关键单元,但为了更好地捕捉局部信息(如趋势),可以让每个单元密集地关注其左侧邻近的单元,窗口大小为 O(\log^2 L)。这样,每个单元可以利用更多的局部信息来进行当前步的预测。在这种局部窗口之外,仍然可以继续采用LogSparse注意力策略,如图3(c)所示。

  2. 重启注意力(Restart Attention):这个策略是将整个输入序列(长度为 L)分成多个子序列,每个子序列的长度为 L_{sub},其中 L_{sub} \approx L。对每个子序列分别应用LogSparse注意力策略,类似于重新开始注意力计算的过程。这样可以减少模型处理每个子序列时的复杂性,并且每个子序列可以独立地进行信息处理,如图3(d)所示。

  3. 结合局部注意力和重启注意力:使用局部注意力和重启注意力不会改变LogSparse自注意力策略的计算复杂度,但会增加更多的信息路径,并减少路径中所需的边数。这意味着可以在不增加计算成本的情况下,提高模型对局部信息的捕捉能力和对长序列的处理能力。通过结合局部注意力和重启注意力,模型能够更有效地捕捉序列中的各种模式和趋势。


http://www.ppmy.cn/server/100895.html

相关文章

Windows有哪些免费好用的PDF编辑器推荐?

不是所有PDF编辑器都免费,但我推荐的这3个一定免费简单好用!! 1、转转大师PDF编辑器 点击直达链接>>pdftoword.55.la 转转大师PDF编辑器是一款专业的PDF编辑工具,功能丰富,操作简单,作为微软office…

基于微信小程序的电子配件销售系统设计与实现 ---附源码15161

摘 要 随着移动互联网的快速发展,电子商务已成为传统零售业的重要补充。本论文针对电子配件销售领域,在分析市场需求的基础上,提出了一种基于微信小程序的电子配件销售系统设计方案。首先,利用市场调研和用户需求分析,…

mysql8.0使用binlog2sql恢复误删除数据踩坑记录

一、操作环境 操作系统版本 Windows 11 家庭中文版 Python版本 3.12.4 mysql版本 8.0.36二、mysql的binlog 确认binlog已开启 SHOW VARIABLES LIKE log_bin; #查询结果 #Variable_name,Value # log_bin,ON#Value为ON,binlog开启,未开始无法通过本方法恢…

陶晶池串口屏主动解析模式与被动解析模式的底层逻辑

实际上屏幕的每个页面都是一个main.c文件,在这个main中的操作代码会在打开该页面一开始就执行。 例如:你在该页面写打开一个定时器,但是该定时器只在该页面会被打开,离开该页面就恢复为停止 主动解析模式与被动解析模式在recmod…

c++多态以及模版

#include <iostream> using namespace std;class Animal { private:string name; public://纯虚函数virtual void perform()0; }; class Lion:public Animal { public:void perform(){cout << "舞狮" << endl;} }; class Elephant:public Animal {…

怎么防止源代码泄露?十种方法杜绝源代码泄密风险

源代码是软件开发的核心资产之一&#xff0c;保护其不被泄露对企业的安全至关重要。源代码泄露不仅可能导致知识产权的丧失&#xff0c;还可能给企业带来经济损失和品牌形象的损害。以下是十种有效的方法&#xff0c;可以帮助企业杜绝源代码泄密的风险。 1. 代码加密 对源代码…

【C++学习笔记 18】C++中的隐式构造函数

举个例子 #include <iostream> #include <string>using String std::string;class Entity{ private:String m_Name;int m_Age; public:Entity(const String& name):m_Name(name), m_Age(-1) {}Entity(int age) : m_Name("UnKnown"), m_Age(age) {}…

网络安全知识渗透测试

渗透测试是一种模拟网络攻击&#xff0c;用于识别漏洞并制定规避防御措施的策略。及早发现缺陷使安全团队能够修复任何漏洞&#xff0c;从而防止数据泄露&#xff0c;否则可能会造成数十亿美元的损失。笔测试还有助于评估组织的合规性、提高员工对安全协议的认识、评估事件响应…