RNN中远距离时间步梯度消失问题及解决办法

devtools/2025/2/22 10:46:09/

RNN中远距离时间步梯度消失问题及解决办法

  • RNN 远距离时间步梯度消失问题
  • LSTM如何解决远距离时间步梯度消失问题


RNN 远距离时间步梯度消失问题

经典的RNN结构如下图所示:
在这里插入图片描述
假设我们的时间序列只有三段, S 0 S_{0} S0 为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下:

S 1 = W x X 1 + W s S 0 + b 1 , O 1 = W 0 S 1 + b 2 S_{1} = W_{x} X_{1} + W_{s}S_{0} + b_{1},O_{1} = W_{0} S_{1} + b_{2} S1=WxX1+WsS0+b1O1=W0S1+b2

S 2 = W x X 2 + W s S 1 + b 1 , O 2 = W 0 S 2 + b 2 S_{2} = W_{x} X_{2} + W_{s}S_{1} + b_{1},O_{2} = W_{0} S_{2} + b_{2} S2=WxX2+WsS1+b1O2=W0S2+b2

S 3 = W x X 3 + W s S 2 + b 1 , O 3 = W 0 S 3 + b 2 S_{3} = W_{x} X_{3} + W_{s}S_{2} + b_{1},O_{3} = W_{0} S_{3} + b_{2} S3=WxX3+WsS2+b1O3=W0S3+b2

假设在 t = 3 t=3 t=3时刻,损失函数为 L 3 = 1 2 ( Y 3 − O 3 ) 2 L_3 = \frac{1}{2}(Y_3 - O_3)^2 L3=21(Y3O3)2 。则对于一次训练任务的损失函数为 L = ∑ t = 0 T L t L = \sum_{t=0}^{T} L_t L=t=0TLt ,即每一时刻损失值的累加。

使用随机梯度下降法训练RNN其实就是对 W x W_x Wx W s W_s Ws W o W_o Wo 以及 b 1 、 b 2 b_1 、 b_2 b1b2 求偏导,并不断调整它们以使 L L L尽可能达到最小的过程。

现在假设我们我们的时间序列只有三段:t1,t2,t3。我们只对 t 3 t3 t3时刻的 W x W_x Wx W s W_s Ws W o W_o Wo 求偏导(其他时刻类似):

∂ L 3 ∂ W 0 = ∂ L 3 ∂ O 3 ∂ O 3 ∂ W o = ∂ L 3 ∂ O 3 S 3 \frac{\partial L_3}{\partial W_0} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial W_o} = \frac{\partial L_3}{\partial O_3} S_3 W0L3=O3L3WoO3=O3L3S3

∂ L 3 ∂ W x = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ W x + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ W x + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W x = ∂ L 3 ∂ O 3 W 0 ( X 3 + S 2 W s + S 1 W s 2 ) \frac{\partial L_3}{\partial W_x} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial W_x} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial W_x} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_x} = \frac{\partial L_3}{\partial O_3} W_0 (X_3 + S_2 W_s + S_1 W_s^2) WxL3=O3L3S3O3WxS3+O3L3S3O3S2S3WxS2+O3L3S3O3S2S3S1S2WxS1=O3L3W0(X3+S2Ws+S1Ws2)

∂ L 3 ∂ W s = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ W s + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ W s + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W s = ∂ L 3 ∂ O 3 W 0 ( S 2 + S 1 W s + S 0 W s 2 ) \frac{\partial L_3}{\partial W_s} = \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial W_s} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial W_s} + \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_s} = \frac{\partial L_3}{\partial O_3} W_0 (S_2 + S_1 W_s + S_0 W_s^2) WsL3=O3L3S3O3WsS3+O3L3S3O3S2S3WsS2+O3L3S3O3S2S3S1S2WsS1=O3L3W0(S2+S1Ws+S0Ws2)

关于上面这个多元复合函数链式求导过程,通过如下对变量层级树的遍历可以更加直观理解这一点:
在这里插入图片描述
可以看出对于 W o W_o Wo 求偏导并没有长期依赖,但是对于 W x W_x Wx W s W_s Ws 求偏导,会随着时间序列产生长期依赖。因为 S t S_t St 随着时间序列向前传播,而 S t S_t St 又是 W x W_x Wx W s W_s Ws 的函数。

根据上述求偏导的过程,我们可以得出任意时刻对 W x W_x Wx W s W_s Ws 求偏导的公式:

∂ L t ∂ W x = ∑ k = 0 t ∂ L t ∂ O t ∂ O t ∂ S t ( ∏ j = k + 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ W x \frac{\partial L_t}{\partial W_x} = \sum_{k=0}^{t} \frac{\partial L_t}{\partial O_t} \frac{\partial O_t}{\partial S_t} \left(\prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}}\right) \frac{\partial S_k}{\partial W_x} WxLt=k=0tOtLtStOt j=k+1tSj1Sj WxSk

任意时刻对 W s W_s Ws 求偏导的公式同上。

如果加上激活函数: S j = tanh ⁡ ( W x X j + W s S j − 1 + b 1 ) S_j = \tanh(W_x X_j + W_s S_{j-1} + b_1) Sj=tanh(WxXj+WsSj1+b1)

∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t tanh ⁡ ′ W s \prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}} = \prod_{j=k+1}^{t} \tanh' W_s j=k+1tSj1Sj=j=k+1ttanhWs

加上激活函数tanh复合后的多元链式求导过程如下图所示:

在这里插入图片描述

激活函数tanh和它的导数图像如下。

在这里插入图片描述

由上图可以看出 tanh ⁡ ′ ≤ 1 \tanh' \leq 1 tanh1,对于训练过程大部分情况下tanh的导数是小于1的,因为很少情况下会出现 W x X j + W s S j − 1 + b 1 = 0 W_x X_j + W_s S_{j-1} + b_1 = 0 WxXj+WsSj1+b1=0,如果 W s W_s Ws 也是一个大于0小于1的值,则当t很大时 ∏ j = k + 1 t tanh ⁡ ′ W s \prod_{j=k+1}^{t} \tanh' W_s j=k+1ttanhWs,就会趋近于0,和 0.0 1 50 0.01^{50} 0.0150 趋近于0是一个道理。同理当 W s W_s Ws 很大时 ∏ j = k + 1 t tanh ⁡ ′ W s \prod_{j=k+1}^{t} \tanh' W_s j=k+1ttanhWs 就会趋近于无穷,这就是RNN中梯度消失和爆炸的原因。

至于怎么避免这种现象,再看看 ∂ L t ∂ W x = ∑ k = 0 t ∂ L t ∂ O t ∂ O t ∂ S t ( ∏ j = k + 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ W x \frac{\partial L_t}{\partial W_x} = \sum_{k=0}^{t} \frac{\partial L_t}{\partial O_t} \frac{\partial O_t}{\partial S_t} \left(\prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}}\right) \frac{\partial S_k}{\partial W_x} WxLt=k=0tOtLtStOt j=k+1tSj1Sj WxSk 梯度消失和爆炸的根本原因就是 ∏ j = k + 1 t ∂ S j ∂ S j − 1 \prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}} j=k+1tSj1Sj 这一坨,要消除这种情况就需要把这一坨在求偏导的过程中去掉,至于怎么去掉,一种办法就是使 ∂ S j ∂ S j − 1 ≈ 1 \frac{\partial S_j}{\partial S_{j-1}} \approx 1 Sj1Sj1 另一种办法就是使 ∂ S j ∂ S j − 1 ≈ 0 \frac{\partial S_j}{\partial S_{j-1}} \approx 0 Sj1Sj0。其实这就是LSTM做的事情。

总结:

  • RNN 的梯度计算涉及到对激活函数的导数以及权重矩阵的连乘

    • 以 sigmoid 函数为例,其导数的值域在 0 到 0.25 之间,当进行多次连乘时,这些较小的值相乘会导致梯度迅速变小。
    • 如果权重矩阵的特征值也小于 1,那么在多个时间步的传递过程中,梯度就会呈指数级下降,导致越靠前的时间步,梯度回传的值越少。
  • 由于梯度消失,靠前时间步的参数更新幅度会非常小,甚至几乎不更新。这使得模型难以学习到序列数据中长距离的依赖关系,对于较早时间步的信息利用不足,从而影响模型的整体性能和对序列数据的建模能力。

注意 : 注意: 注意:

RNN梯度爆炸好理解,就是 ∂ L t ∂ W x \frac{\partial L_t}{\partial W_x} WxLt梯度数值发散,甚至慢慢就NaN了;

那梯度消失就是 ∂ L t ∂ W x \frac{\partial L_t}{\partial W_x} WxLt梯度变成零吗?

并不是,我们刚刚说梯度消失是 ∣ ∂ S j ∂ S j − 1 ∣ \left|\frac{\partial S_j}{\partial S_{j-1}}\right| Sj1Sj 一直小于1,历史梯度不断衰减,但不意味着总的梯度就为0了。RNN中梯度消失的含义是距离当前时间步越长,那么其反馈的梯度信号越不显著,最后可能完全没有起作用,这就意味着RNN对长距离语义的捕捉能力失效了

说白了,你优化过程都跟长距离的反馈没关系,怎么能保证学习出来的模型能有效捕捉长距离呢?

再次通俗解释一下RNN梯度消失,其指的不是 ∂ L t ∂ W x \frac{\partial L_t}{\partial W_x} WxLt梯度值接近于0,而是靠前时间步的梯度 ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W x \frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_x} O3L3S3O3S2S3S1S2WxS1值算出来很小,也就是靠前时间步计算出来的结果对序列最后一个预测词的生成影响很小,也就是常说的RNN难以去建模长距离的依赖关系的原因;这并不是因为序列靠前的词对最后一个词的预测输出不重要,而是由于损失函数在把有用的梯度更新信息反向回传的过程中,被若干小于0的偏导连乘给一点点削减掉了。


LSTM如何解决远距离时间步梯度消失问题

在这里插入图片描述

LSTM的更新公式比较复杂,它是:

f t = σ ( W f x t + U f h t − 1 + b f ) f_t = \sigma (W_f x_t + U_f h_{t-1} + b_f) ft=σ(Wfxt+Ufht1+bf)
i t = σ ( W i x t + U i h t − 1 + b i ) i_t = \sigma (W_i x_t + U_i h_{t-1} + b_i) it=σ(Wixt+Uiht1+bi)
o t = σ ( W o x t + U o h t − 1 + b o ) o_t = \sigma (W_o x_t + U_o h_{t-1} + b_o) ot=σ(Woxt+Uoht1+bo)
c ^ t = tanh ⁡ ( W c x t + U c h t − 1 + b c ) \hat{c}_t = \tanh (W_c x_t + U_c h_{t-1} + b_c) c^t=tanh(Wcxt+Ucht1+bc)
c t = f t ∘ c t − 1 + i t ∘ c ^ t c_t = f_t \circ c_{t-1} + i_t \circ \hat{c}_t ct=ftct1+itc^t
h t = o t ∘ tanh ⁡ ( c t ) h_t = o_t \circ \tanh(c_t) \qquad ht=ottanh(ct)

我们可以像上面一样计算 ∂ h t ∂ h t − 1 \frac{\partial h_t}{\partial h_{t-1}} ht1ht,但从 h t = o t ∘ tanh ⁡ ( c t ) h_t = o_t \circ \tanh(c_t) ht=ottanh(ct) 可以看出分析 c t c_t ct 就等价于分析 h t h_t ht,而计算 ∂ c t ∂ c t − 1 \frac{\partial c_t}{\partial c_{t-1}} ct1ct 显得更加简单一些,因此我们往这个方向走。

同样地,我们先只关心1维的情形,这时候根据求导公式,我们有

∂ c t ∂ c t − 1 = f t + c t − 1 ∂ f t ∂ c t − 1 + c ^ t ∂ i t ∂ c t − 1 + i t ∂ c ^ t ∂ c t − 1 \frac{\partial c_t}{\partial c_{t-1}} = f_t + c_{t-1} \frac{\partial f_t}{\partial c_{t-1}} + \hat{c}_t \frac{\partial i_t}{\partial c_{t-1}} + i_t \frac{\partial \hat{c}_t}{\partial c_{t-1}} \qquad ct1ct=ft+ct1ct1ft+c^tct1it+itct1c^t

右端第一项 f t f_t ft,也就是我们所说的“遗忘门”,从下面的论述我们可以知道一般情况下其余三项都是次要项,因此 f t f_t ft 是“主项”,由于 f t f_t ft 在0~1之间,因此就意味着梯度爆炸的风险将会很小,至于会不会梯度消失,取决于 f t f_t ft 是否接近于1。但非常碰巧的是,这里有个相当自洽的结论:如果我们的任务比较依赖于历史信息,那么 f t f_t ft 就会接近于1,这时候历史的梯度信息也正好不容易消失;如果 f t f_t ft 很接近于0,那么就说明我们的任务不依赖于历史信息,这时候就算梯度消失也无妨了

所以,现在的关键就是看“其余三项都是次要项”这个结论能否成立。后面的三项都是“一项乘以另一项的偏导”的形式,而且求偏导的项都是 σ \sigma σ tanh ⁡ \tanh tanh激活, σ \sigma σ tanh ⁡ \tanh tanh的偏导公式基本上是等价的,它们的导数均可以用它们自身来表示:

tanh ⁡ x = 2 σ ( 2 x ) − 1 \tanh x = 2\sigma(2x) - 1 tanhx=2σ(2x)1
σ ( x ) = 1 2 ( tanh ⁡ x 2 + 1 ) \sigma(x) = \frac{1}{2} \left( \tanh \frac{x}{2} + 1 \right) \qquad σ(x)=21(tanh2x+1)
( tanh ⁡ x ) ′ = 1 − tanh ⁡ 2 x (\tanh x)' = 1 - \tanh^2 x (tanhx)=1tanh2x
σ ′ ( x ) = σ ( x ) ( 1 − σ ( x ) ) \sigma'(x) = \sigma(x) (1 - \sigma(x)) σ(x)=σ(x)(1σ(x))

其中 σ ( x ) = 1 / ( 1 + e − x ) \sigma(x) = 1/(1 + e^{-x}) σ(x)=1/(1+ex) 是sigmoid函数。

因此后面三项是类似的,分析了其中一项就相当于分析了其余两项。以第二项为例,代入 h t − 1 = o t − 1 tanh ⁡ ( c t − 1 ) h_{t-1} = o_{t-1} \tanh(c_{t-1}) ht1=ot1tanh(ct1),可以算得

c t − 1 ∂ f t ∂ c t − 1 = f t ( 1 − f t ) o t − 1 ( 1 − tanh ⁡ 2 c t − 1 ) c t − 1 U f c_{t-1} \frac{\partial f_t}{\partial c_{t-1}} = f_t (1 - f_t) o_{t-1} (1 - \tanh^2 c_{t-1}) c_{t-1} U_f \qquad ct1ct1ft=ft(1ft)ot1(1tanh2ct1)ct1Uf

注意到 f t , 1 − f t , o t − 1 f_t, 1 - f_t, o_{t-1} ft,1ft,ot1都是在0~1之间,也可以证明 ∣ ( 1 − tanh ⁡ 2 c t − 1 ) c t − 1 ∣ < 0.45 |(1 - \tanh^2 c_{t-1}) c_{t-1}| < 0.45 (1tanh2ct1)ct1<0.45,因此它也在-1~1之间。所以 c t − 1 ∂ f t ∂ c t − 1 c_{t-1} \frac{\partial f_t}{\partial c_{t-1}} ct1ct1ft就相当于1个 U f U_f Uf乘上4个门,结果会变得更加小,所以只要初始化不是很糟糕,那么它都会被压缩得相当小,因此占不到主导作用。

剩下两项的结论也是类似的:

c ^ t ∂ i t ∂ c t − 1 = i t ( 1 − i t ) o t − 1 ( 1 − tanh ⁡ 2 c t − 1 ) c ^ t U i \hat{c}_t \frac{\partial i_t}{\partial c_{t-1}} = i_t (1 - i_t) o_{t-1} (1 - \tanh^2 c_{t-1}) \hat{c}_t U_i \qquad c^tct1it=it(1it)ot1(1tanh2ct1)c^tUi

i t ∂ c ^ t ∂ c t − 1 = ( 1 − c ^ t 2 ) o t − 1 ( 1 − tanh ⁡ 2 c t − 1 ) i t U c i_t \frac{\partial \hat{c}_t}{\partial c_{t-1}} = (1 - \hat{c}_t^2) o_{t-1} (1 - \tanh^2 c_{t-1}) i_t U_c itct1c^t=(1c^t2)ot1(1tanh2ct1)itUc

所以,后面三项的梯度带有更多的“门”,一般而言乘起来后会被压缩的更厉害,因此占主导的项还是 f t f_t ft f t f_t ft 在0~1之间这个特性决定了它梯度爆炸的风险很小,同时 f t f_t ft 表明了模型对历史信息的依赖性,也正好是历史梯度的保留程度,两者相互自洽,所以LSTM也能较好地缓解梯度消失问题。因此,LSTM同时较好地缓解了梯度消失/爆炸问题,现在我们训练LSTM时,多数情况下只需要直接调用Adam等自适应学习率优化器,不需要人为对梯度做什么调整了。



http://www.ppmy.cn/devtools/160924.html

相关文章

DDD架构实战:用Java实现一个电商订单系统,快速掌握领域驱动设计

引言 你是否曾为复杂的业务逻辑感到头疼&#xff1f;是否在面对需求变更时感到无力&#xff1f;今天&#xff0c;我们将带你深入**领域驱动设计&#xff08;DDD&#xff09;**的世界&#xff0c;通过一个简单的电商订单系统实战项目&#xff0c;快速掌握DDD的核心思想与实现方…

vue-element-admin 打包部署到SpringBoot

更改vue里面vue.config.js 运行build命令 npm run build:prod 生成dist文件夹 打开你的springboot项目 复制static文件夹到 src/main/resources/ 并将index.html移动到templates(使用template) 更改index.html文件中导入地址 在colltroller层写一个控制器返回index.html i…

QT SQL框架及QSqlDatabase类

1、概述 本文对QT的SQL模块进行了整理&#xff0c;可供新同事参考&#xff0c;Qt SQL模块提供数据库编程的支持&#xff0c;MySQL、Oracle、MS SQL Server、SQlite等&#xff0c;作者未来的工作的其中一个接口将是QT接口。 Qt SQL模块包含多个类&#xff0c;实现数据库的连接…

C++17中的std::scoped_lock:简化多锁管理的利器

文章目录 1. 为什么需要std::scoped_lock1.1 死锁问题1.2 异常安全性1.3 锁的管理复杂性 2. std::scoped_lock的使用方法2.1 基本语法2.2 支持多种互斥锁类型2.3 自动处理异常 3. std::scoped_lock的优势3.1 避免死锁3.2 简化代码3.3 提供异常安全保证 4. 实际应用场景4.1 数据…

【Python爬虫(37)】解锁分布式爬虫:原理与架构全解析

【Python爬虫】专栏简介&#xff1a;本专栏是 Python 爬虫领域的集大成之作&#xff0c;共 100 章节。从 Python 基础语法、爬虫入门知识讲起&#xff0c;深入探讨反爬虫、多线程、分布式等进阶技术。以大量实例为支撑&#xff0c;覆盖网页、图片、音频等各类数据爬取&#xff…

计算机视觉算法实战——图像合成(主页有源码)

✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连✨ ​ ✨✨1. 图像合成领域简介✨✨ 图像合成是计算机视觉中的一个重要研究方向&#xff0c;旨在通过算法生成或修改图像内容。图像合成技术广泛应…

DOS网络安全

ping -t 不间断地ping目标主机&#xff0c;直到用户用ctrlc键强行终止。经常用来排除网络故障 -l 定制ping信息包的容量,最大上限是65500字节 -n 向远程主机发送的数据 包个数&#xff0c;默认是4。 语法&#xff1a; ping 参数 IP地址 netstat -a 显示所有连接…