Deep Learning - Lecture 10 Gated Recurrent Networks
- 门控循环网络
- 门控循环单元(Gated Recurrent Unit,GRU)
- 长短期记忆单元(LSTM recurrent unit)
- 梯度消失和爆炸
- 软件中的门循环网络
- 总结
- 引用
本节目标:
- 解释门控循环网络的关键要素。
- 设计并实现门控循环网络。
门控循环网络
训练简单的循环网络可能需要在多个时间步上进行求导计算,而梯度往往会出现消失或爆炸的情况!
这是因为在反向传播过程中,梯度是通过链式法则从后往前计算的。对于循环网络,随着时间步的增加,梯度在反向传播时会不断乘以权重矩阵。
- 梯度消失:如果权重矩阵的元素值较小,经过多个时间步的连乘后,梯度会越来越小,趋近于零。这就导致较早期时间步的参数更新非常缓慢,甚至几乎无法更新,使得网络难以学习到长序列中的依赖关系。
- 梯度爆炸:与梯度消失相反,如果权重矩阵的元素值较大,经过多个时间步的连乘后,梯度会越来越大,变得非常不稳定,甚至可能导致参数值变得极大而无法收敛。在这种情况下,网络的训练过程会变得难以控制,无法正常学习。
所以本节的重点是门控循环网络,这类网络能够缓解梯度消失和梯度爆炸的问题。
回顾:简单循环网络(Simple Recurrent Networks,RNNs)的训练相关内容
当循环网络在时间维度上展开时,它类似于非常深的前馈网络,每个时间步就相当于前馈网络的一层。
(还记得下面这个图不?)
状态和输出方程
- 状态更新方程: x t = tanh ( A x t − 1 + B u t ) x_t = \tanh(Ax_{t - 1} + Bu_t) xt=tanh(Axt−1+But)。其中, x t x_t xt是 t t t时刻的隐藏状态; x t − 1 x_{t - 1} xt−1是上一时刻( t − 1 t-1 t−1时刻)的隐藏状态; A A A和 B B B是权重矩阵; u t u_t ut是 t t t时刻的输入; tanh \tanh tanh是激活函数,用于给隐藏状态引入非线性。
- 输出方程: y ^ t = g ( C x t ) \hat{y}_t = g(Cx_t) y^t=g(Cxt)。这里 y ^ t \hat{y}_t y^t是 t t t时刻的输出, C C C是权重矩阵, g g g为激活函数,其形式会根据具体任务(如回归或分类)而不同 。
- 二分类问题: g可以是sigmoid等。
- 多分类问题:g可以是softmax等。
- 回归问题,如果线性的话可以直接用恒等函数,如果需要引入非线性,那么可以用Relu。
损失函数
- 公式为 J ( θ ) = ∑ t = 1 N L t ( y ^ t , y t ) J(\theta) = \sum_{t = 1}^{N} L_t(\hat{y}_t, y_t) J(θ)=∑t=1NLt(y^t,yt) 。其中, J ( θ ) J(\theta) J(θ)表示总损失, θ \theta θ代表网络中的所有参数; N N N是时间步的总数; L t ( y ^ t , y t ) L_t(\hat{y}_t, y_t) Lt(y^t,yt)是 t t t时刻模型预测值 y ^ t \hat{y}_t y^t与真实值 y t y_t yt之间的损失。
- 损失函数可以根据需求选择均方误差(用于回归任务)或分类交叉熵(用于分类任务),并且总损失是将各个时间步上的损失进行求和得到。(具体内容可以看上一节。)
门控循环单元(Gated Recurrent Unit,GRU)
GRU对简单循环模型进行了改进,引入了可学习的“门”机制,这些门在状态方程中控制信号的流动,以此缓解简单循环网络中梯度消失或爆炸的问题,增强网络处理长序列数据的能力。
GRU的状态和输出方程
- 状态更新方程: x t = z t ⊙ x t − 1 + ( 1 − z t ) ⊙ tanh ( A x ( r t ⊙ x t − 1 ) + B x u t ) x_t = z_t \odot x_{t - 1} + (1 - z_t) \odot \tanh(A_x(r_t \odot x_{t - 1}) + B_xu_t) xt=zt⊙xt−1+(1−zt)⊙tanh(Ax(rt⊙xt−1)+Bxut)
其中, x t x_t xt是 t t t时刻的隐藏状态, x t − 1 x_{t - 1} xt−1是上一时刻的隐藏状态, u t u_t ut是 t t t时刻的输入; z t z_t zt是更新门(Update gate), r t r_t rt是重置门(Reset gate); ⊙ \odot ⊙表示元素相乘; A x A_x Ax、 B x B_x Bx、是权重矩阵。(下面的 A z A_z Az、 B z B_z Bz、 A r A_r Ar、 B r B_r Br同理) - 门方程
- 更新门方程: z t = σ ( A z x t − 1 + B z u t ) z_t = \sigma(A_zx_{t - 1} + B_zu_t) zt=σ(Azxt−1+Bzut) ,更新门决定了上一时刻的隐藏状态有多少信息要保留到当前时刻。
- 重置门方程: r t = σ ( A r x t − 1 + B r u t ) r_t = \sigma(A_rx_{t - 1} + B_ru_t) rt=σ(Arxt−1+Brut) ,重置门决定了对上一时刻隐藏状态的遗忘程度。
- 这里的门函数 σ \sigma σ是sigmoid函数,其输出值范围在0到1之间。当 σ = 0 \sigma = 0 σ=0时,门关闭,信号无法通过;当 σ = 1 \sigma = 1 σ=1时,门打开,信号可以顺利通过。
- 与简单RNN状态方程 x t = tanh ( A x t − 1 + B u t ) x_t = \tanh(Ax_{t - 1} + Bu_t) xt=tanh(Axt−1+But)相比,GRU的状态更新更加复杂且灵活。
- 输出方程: y ^ t = g ( C x t ) \hat{y}_t = g(Cx_t) y^t=g(Cxt),和简单RNN类似。
门控循环单元中的“门”
在GRU中,门函数 σ \sigma σ表示的是逻辑斯蒂sigmoid激活函数。因为是sigmoid函数,所以每个门的值都在0到1这个范围内。其中, σ = 0 \sigma = 0 σ=0表示门完全关闭, σ = 1 \sigma = 1 σ=1表示门完全打开 。
如下是逻辑斯蒂sigmoid函数的图像,横坐标表示输入值(图中示例为 A z x t − 1 + B z u t A_zx_{t - 1} + B_zu_t Azxt−1+Bzut ),纵坐标表示函数的输出值(即门的值)。当输入值趋于负无穷时,函数输出趋近于0,对应“Gate closed”(门关闭);当输入值趋于正无穷时,函数输出趋近于1,对应“Gate open”(门打开) 。
门控循环单元(GRU)保持状态值的能力
GRU可以学习在内存中长时间保存任意时间步数的状态值 x t x_t xt 。
因为GRU的状态更新方程 x t = z t ⊙ x t − 1 + ( 1 − z t ) ⊙ tanh ( A x ( r t ⊙ x t − 1 ) + B x u t ) x_t = z_t \odot x_{t - 1} + (1 - z_t) \odot \tanh(A_x(r_t \odot x_{t - 1}) + B_xu_t) xt=zt⊙xt−1+(1−zt)⊙tanh(Ax(rt⊙xt−1)+Bxut)和输出方程 y ^ t = g ( C x t ) \hat{y}_t = g(Cx_t) y^t=g(Cxt) 。
所以:
- 当 z t = 1 z_t = 1 zt=1时, x t = x t − 1 x_t = x_{t - 1} xt=xt−1 ,此时GRU能将上一时刻的状态值完整保留到当前时刻,体现了GRU长时间记忆状态的能力。
- 当 r t = 1 r_t = 1 rt=1且 z t = 0 z_t = 0 zt=0时, x t = tanh ( A x x t − 1 + B x u t ) x_t = \tanh(A_xx_{t - 1} + B_xu_t) xt=tanh(Axxt−1+Bxut) ,此时模型会综合考虑上一时刻隐藏状态和当前输入来更新状态。
- 当 r t = 0 r_t = 0 rt=0且 z t = 0 z_t = 0 zt=0时, x t = tanh ( B x u t ) x_t = \tanh(B_xu_t) xt=tanh(Bxut) ,此时模型基本忽略上一时刻隐藏状态,主要依据当前输入更新状态 。
门控循环单元的运行方式与简单循环单元类似 。
示意图如下:
长短期记忆单元(LSTM recurrent unit)
LSTM单元是GRU(门控循环单元)的一种替代方案,LSTM具有内部细胞状态 s t s_t st,这个状态能够在内存中长时间保存信息。
LSTM 单元示意图
- LSTM状态和输出方程:
- 状态更新方程 x t = o t ⊙ tanh ( s t ) x_t = o_t \odot \tanh(s_t) xt=ot⊙tanh(st),其中 x t x_t xt是 t t t时刻的隐藏状态, o t o_t ot是输出门, s t s_t st是内部细胞状态, ⊙ \odot ⊙表示元素相乘, tanh \tanh tanh是激活函数。
- 输出方程 y ^ t = g ( C x t ) \hat{y}_t = g(Cx_t) y^t=g(Cxt), y ^ t \hat{y}_t y^t是 t t t时刻的输出, C C C是权重矩阵, g g g为激活函数,根据具体任务而定。
- 内部细胞状态方程: s t = f t ⊙ s t − 1 + i t ⊙ tanh ( A s x t − 1 + B s u t ) s_t = f_t \odot s_{t - 1} + i_t \odot \tanh(A_sx_{t - 1} + B_su_t) st=ft⊙st−1+it⊙tanh(Asxt−1+Bsut) 。这里, s t s_t st和 s t − 1 s_{t - 1} st−1分别是当前时刻和上一时刻的内部细胞状态, f t f_t ft是遗忘门, i t i_t it是输入门, A s A_s As、 B s B_s Bs是权重矩阵, u t u_t ut是 t t t时刻的输入。
- 门方程:
- 遗忘门方程 f t = σ ( A f x t − 1 + B f u t ) f_t = \sigma(A_fx_{t - 1} + B_fu_t) ft=σ(Afxt−1+Bfut) ,用于控制细胞状态 s t s_t st是被保留还是被遗忘,决定从上一时刻细胞状态中丢弃哪些信息, σ \sigma σ是sigmoid激活函数。
- 输入门方程 i t = σ ( A i x t − 1 + B i u t ) i_t = \sigma(A_ix_{t - 1} + B_iu_t) it=σ(Aixt−1+Biut) ,控制新的输入是否会影响细胞状态 s t s_t st,即当前输入信息有多少被存入细胞状态。
- 输出门方程 o t = σ ( A o x t − 1 + B o u t ) o_t = \sigma(A_ox_{t - 1} + B_ou_t) ot=σ(Aoxt−1+Bout) ,决定细胞状态是否以及多少被作为当前隐藏状态输出。
所以,根据如上门方程,显而易见的:
- 当 f t = 1 f_t = 1 ft=1 且 i t = 0 i_t = 0 it=0 时, s t = s t − 1 s_t = s_{t - 1} st=st−1 ,此时LSTM单元可以将上一时刻的内部细胞状态完整保留到当前时刻,体现了LSTM长时间记忆内部细胞状态的能力。
- 当 f t = 0 f_t = 0 ft=0 且 i t = 1 i_t = 1 it=1 时, s t = tanh ( A s x t − 1 + B s u t ) s_t = \tanh(A_sx_{t - 1} + B_su_t) st=tanh(Asxt−1+Bsut) ,此时LSTM单元基本丢弃上一时刻的内部细胞状态,主要依据当前输入更新内部细胞状态。
LSMT的特点
刚刚我们也说了:
- LSTM具有内部细胞状态 s t s_t st。
- 这个状态能够在内存中长时间保存信息。
- 所以LSTM单元能够学习在内存中长时间保存内部状态值 s t s_t st,并且可以保持任意时间步数。
那GRU也有遗忘门啊,也能一直保留信息啊,他们有什么区别?
因为与 LSTM 的对比,他们的结构复杂度不同:
GRU 只有两个门,结构相对简单,计算效率更高。但在处理复杂的长期依赖关系时,其门控机制的灵活性不如 LSTM。
因为GRU 没有专门的遗忘门来精确控制信息的遗忘,他在大多数情况下无法像 LSTM 那样精细地管理内部状态。
(所以干脆不要内部状态了,反正也管理不了。)
长短期记忆网络单元(LSTM)和门控循环单元(GRU)的应用
长短期记忆网络单元(LSTM)和门控循环单元(GRU)可以直接替代简单循环网络的状态方程,而无需对外部连接进行修改。
(还记得之前的时序展开图吗?这回终于明白哪些小绿块里面是什么样子的了。都是GRU或者LSTM。)
梯度消失和爆炸
在简单循环神经网络中进行自动求导的反向传播时,梯度可能会出现消失或爆炸的情况。
- 损失函数的导数: ∂ L t ∂ A = ∑ k = 1 t ∂ L t ∂ y t ∂ y t ∂ x t ∂ x t ∂ x k ∂ x k ∂ A \frac{\partial L_t}{\partial A} = \sum_{k = 1}^{t} \frac{\partial L_t}{\partial y_t} \frac{\partial y_t}{\partial x_t} \frac{\partial x_t}{\partial x_k} \frac{\partial x_k}{\partial A} ∂A∂Lt=k=1∑t∂yt∂Lt∂xt∂yt∂xk∂xt∂A∂xk此公式用于计算损失函数 L t L_t Lt关于权重矩阵 A A A的导数。其中 x t = A x t − 1 + B u t x_t = Ax_{t - 1} + Bu_t xt=Axt−1+But ,该式表明在反向传播过程中,梯度是多个偏导数项的乘积和,体现了误差从当前时间步向之前时间步传递的过程。
- 关键问题项: ∂ x t ∂ x k = ∂ x t ∂ x t − 1 ∂ x t − 1 ∂ x t − 2 ⋯ ∂ x k + 1 ∂ x k = ∏ t ≥ i > k ∂ x i ∂ x i − 1 = ∏ t ≥ i > k A \frac{\partial x_t}{\partial x_k} = \frac{\partial x_t}{\partial x_{t - 1}} \frac{\partial x_{t - 1}}{\partial x_{t - 2}} \cdots \frac{\partial x_{k + 1}}{\partial x_k} = \prod_{t \geq i > k} \frac{\partial x_i}{\partial x_{i - 1}} = \prod_{t \geq i > k} A ∂xk∂xt=∂xt−1∂xt∂xt−2∂xt−1⋯∂xk∂xk+1=t≥i>k∏∂xi−1∂xi=t≥i>k∏A 这一项展示了隐藏状态的变化率在不同时间步之间的连乘关系,是导致梯度问题的关键。因为每次计算都涉及到权重矩阵 A A A,其连乘结果对梯度有重要影响。
- 矩阵 A A A的特征分解: A A A进行特征分解为 A = Q Λ Q T A = Q\Lambda Q^T A=QΛQT( Λ \Lambda Λ包含特征值),那么 ∏ t ≥ i > k A = ∏ t ≥ i > k Q Λ Q T = Q ( ∏ t ≥ i > k Λ ) Q T = Q Λ ( t − k ) Q T \prod_{t \geq i > k} A = \prod_{t \geq i > k} Q\Lambda Q^T = Q(\prod_{t \geq i > k} \Lambda)Q^T = Q\Lambda^{(t - k)}Q^T t≥i>k∏A=t≥i>k∏QΛQT=Q(t≥i>k∏Λ)QT=QΛ(t−k)QT通过特征分解,将连乘问题转化为特征值的连乘问题,便于分析梯度情况。
梯度问题说明
- 梯度消失:如果特征值矩阵 Λ \Lambda Λ中的元素都小于1,在时间步 t t t和 k k k间隔较大时,经过多次连乘, Λ ( t − k ) \Lambda^{(t - k)} Λ(t−k)会趋近于0,导致梯度在反向传播过程中越来越小,最终趋近于0,即发生梯度消失问题。
- 梯度爆炸:反之,如果 Λ \Lambda Λ中的元素都大于1,随着时间步的增加, Λ ( t − k ) \Lambda^{(t - k)} Λ(t−k)会迅速增大,使得梯度在反向传播中变得非常大,从而出现梯度爆炸问题。
门控单元能够防止梯度消失和梯度爆炸
门控单元可以学习保持状态不变,这是解决梯度问题的关键。
- 损失函数的导数:与上面类似: ∂ L t ∂ A = ∑ k = 1 t ∂ L t ∂ y t ∂ y t ∂ x t ∂ x t ∂ x k ∂ x k ∂ A \frac{\partial L_t}{\partial A} = \sum_{k = 1}^{t} \frac{\partial L_t}{\partial y_t} \frac{\partial y_t}{\partial x_t} \frac{\partial x_t}{\partial x_k} \frac{\partial x_k}{\partial A} ∂A∂Lt=k=1∑t∂yt∂Lt∂xt∂yt∂xk∂xt∂A∂xk
还是损失函数 L t L_t Lt关于权重矩阵 A A A的导数 - 简单RNN的问题项:对于简单循环神经网络(RNN) ∂ x t ∂ x k = ∂ x t ∂ x t − 1 ∂ x t − 1 ∂ x t − 2 ⋯ ∂ x k + 1 ∂ x k = ∏ t ≥ i > k ∂ x i ∂ x i − 1 = ∏ t ≥ i > k A \frac{\partial x_t}{\partial x_k} = \frac{\partial x_t}{\partial x_{t - 1}} \frac{\partial x_{t - 1}}{\partial x_{t - 2}} \cdots \frac{\partial x_{k + 1}}{\partial x_k} = \prod_{t \geq i > k} \frac{\partial x_i}{\partial x_{i - 1}} = \prod_{t \geq i > k} A ∂xk∂xt=∂xt−1∂xt∂xt−2∂xt−1⋯∂xk∂xk+1=t≥i>k∏∂xi−1∂xi=t≥i>k∏A 这个连乘项 ∏ t ≥ i > k A \prod_{t \geq i > k} A ∏t≥i>kA是导致梯度消失或爆炸的关键,因为权重矩阵 A A A连乘的结果会使梯度出现异常变化。
- GRU的情况:当门控循环单元(GRU)中的更新门 z t = 1 z_t = 1 zt=1时,状态 x t = x t − 1 x_t = x_{t - 1} xt=xt−1 ,此时 ∂ x t ∂ x k = ∏ t ≥ i > k ∂ x i ∂ x i − 1 = ∏ t ≥ i > k I \frac{\partial x_t}{\partial x_k} = \prod_{t \geq i > k} \frac{\partial x_i}{\partial x_{i - 1}} = \prod_{t \geq i > k} I ∂xk∂xt=t≥i>k∏∂xi−1∂xi=t≥i>k∏I 这里 I I I是单位矩阵。由于单位矩阵连乘结果仍为单位矩阵,使得梯度在反向传播中不会因为连乘而出现消失或爆炸的情况。
- 在实际的GRU中,更新门 z t z_t zt 并不总是等于1 。GRU通过门控机制动态地控制信息的流动和状态的更新, z t z_t zt 的取值范围是0到1之间的实数。当 z t z_t zt 不等于1时,GRU的状态更新会综合考虑上一时刻的状态和当前输入的信息,此时 ∂ x t ∂ x k \frac{\partial x_t}{\partial x_k} ∂xk∂xt 的计算就不是简单的单位矩阵连乘。
不过,即便 z t z_t zt 不等于1 ,GRU依然能缓解梯度消失和爆炸问题,原因如下:
- 门控机制的调节作用:除了更新门 z t z_t zt ,GRU还有重置门 r t r_t rt ,它们协同工作来控制状态更新。更新门 z t z_t zt 决定上一时刻状态被保留的程度以及当前输入对状态的影响程度;重置门 r t r_t rt 控制上一时刻状态被遗忘的程度。这种精细的控制使得GRU在不同时间步之间传递梯度时,不会像简单RNN那样出现权重矩阵多次连乘导致梯度异常变化的情况。
- 非线性变换与信息整合:GRU在状态更新过程中会进行一系列非线性变换和信息整合操作。通过门控机制对输入和历史状态信息进行筛选和融合,使得信息的传递更加稳定。即使在更新门 z t z_t zt 不等于1的情况下,这些操作也能在一定程度上保持梯度的合理范围,避免梯度消失或爆炸。
软件中的门循环网络
MATLAB示例代码
%MATLAB 有专门的长短期记忆网络(LSTM)层和门控循环单元(GRU)层。
lstmLayer(numHiddenUnits,'OutputMode','last') % output mode last or sequence
gruLayer(numHiddenUnits,'OutputMode','last') % GRU layer% Define LSTM network for sequence-to-label classification
numFeatures = 12; % number of input features
numHiddenUnits = 100; % number of hidden units
numClasses = 9; % number of classes
layers = [ ...sequenceInputLayer(numFeatures) % Sequence input layer lstmLayer(numHiddenUnits) % Define LSTM layer fullyConnectedLayer(numClasses) % Fully connected layer softmaxLayer % Softmax layerclassificationLayer]; % Classification output layer
Python示例代码
Keras - TensorFlow 有专门用于长短期记忆网络(LSTM)和门控循环单元(GRU)的层。
我们可以在Keras - TensorFlow中使用顺序结构来添加长短期记忆网络(LSTM)层。
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers.embeddings import Embedding
model = Sequential()
model.add(LSTM(128))
model.add(Dense(1, activation='sigmoid'))
# 对于GRU层
from keras.layers import GRU
model.add(GRU(128))# 模型训练和之前一样。
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, Y_train, epochs=10, batch_size=32)
总结
- 简单循环神经网络容易出现梯度消失和梯度爆炸问题。
- 基于长短期记忆网络(LSTM)和门控循环单元(GRU)的门控循环网络能够缓解梯度消失/爆炸问题。
- 门控循环单元(GRU)和长短期记忆网络(LSTM)被广泛应用,而简单循环神经网络(RNN)由于梯度消失/爆炸问题,使用频率较低 。
引用
- (门控循环单元)Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078
- (长短期记忆单元)Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural Computation, 9(8), 1735-1780.
- (梯度消失或梯度爆炸)Pascanu, R., Mikolov, T., & Bengio, Y. (2013). On the difficulty of training recurrent neural networks. In Proc ICML (pp. 1310-1318).