RNN基本原理及梯度消失、梯度爆炸的问题原因及解决方法

您所在的位置:网站首页 ggg是什么意思呢 RNN基本原理及梯度消失、梯度爆炸的问题原因及解决方法

RNN基本原理及梯度消失、梯度爆炸的问题原因及解决方法

2023-09-22 22:35| 来源: 网络整理| 查看: 265

一、RNN基本结构

在这里插入图片描述

1、隐层状态 s t s_t st​

s t = σ ( U x t + W s t − 1 + b 1 ) s_t=\sigma(Ux_t+Ws_{t-1}+b_1) st​=σ(Uxt​+Wst−1​+b1​) σ \sigma σ()是激活函数,通常选用Tanh、ReLU。

2、输出状态 o t o_t ot​

o t = g ( V s t + b 2 ) o_t=g(Vs_t+b_2) ot​=g(Vst​+b2​) g g g()是激活函数,对于分类任务通常选用 s i g m o i d sigmoid sigmoid()。

3、Loss计算

输出状态 o t o_t ot​与目标输出 y t y_t yt​计算Loss: L = ∑ t L t = ∑ t L o s s ( o t , y t ) L=\sum_{t}L_t=\sum_{t}Loss(o_t,y_t) L=t∑​Lt​=t∑​Loss(ot​,yt​) L o s s Loss Loss是损失函数,对于分类任务通常选用交叉熵损失函数。

二、RNN参数更新方式 1、首先需要明确:上述的循环重复结构,都是共享参数的,也就是说不管在什么时刻,权重矩阵 U U U、 W W W、 V V V都是相同的。

好处:极大减少参数量+可以处理不定长序列。

2、梯度下降、反向传播过程

假设 t = 3 t=3 t=3的时刻,计算它的损失函数: s 3 = σ ( U x 3 + W s 2 + b 1 ) o 3 = g ( V s 3 + b 2 ) L 3 = 1 2 ( o 3 − y 3 ) 2 s_3=\sigma(Ux_3+Ws_{2}+b_1) \\ o_3=g(Vs_3+b_2) \\ L_3=\frac{1}{2}(o_3-y_3)^2 s3​=σ(Ux3​+Ws2​+b1​)o3​=g(Vs3​+b2​)L3​=21​(o3​−y3​)2那么求偏导的时候: ∂ L 3 ∂ V = ∂ L 3 ∂ o 3 ∂ o 3 ∂ V \frac{ \partial L_3 }{ \partial V}=\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial V} ∂V∂L3​​=∂o3​∂L3​​∂V∂o3​​ ∂ L 3 ∂ U = ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ U + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ U + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ s 1 ∂ s 1 ∂ U \frac{ \partial L_3 }{ \partial U}=\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial U}+\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial s_2}\frac{ \partial s_2 }{ \partial U}+\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial s_2}\frac{ \partial s_2 }{ \partial s_1}\frac{ \partial s_1 }{ \partial U} ∂U∂L3​​=∂o3​∂L3​​∂s3​∂o3​​∂U∂s3​​+∂o3​∂L3​​∂s3​∂o3​​∂s2​∂s3​​∂U∂s2​​+∂o3​∂L3​​∂s3​∂o3​​∂s2​∂s3​​∂s1​∂s2​​∂U∂s1​​ ∂ L 3 ∂ W = ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ W + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ W + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ s 1 ∂ s 1 ∂ W \frac{ \partial L_3 }{ \partial W}=\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial W}+\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial s_2}\frac{ \partial s_2 }{ \partial W}+\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial s_2}\frac{ \partial s_2 }{ \partial s_1}\frac{ \partial s_1 }{ \partial W} ∂W∂L3​​=∂o3​∂L3​​∂s3​∂o3​​∂W∂s3​​+∂o3​∂L3​​∂s3​∂o3​​∂s2​∂s3​​∂W∂s2​​+∂o3​∂L3​​∂s3​∂o3​​∂s2​∂s3​​∂s1​∂s2​​∂W∂s1​​因为 s 3 s_3 s3​是由前面的 s 1 s_1 s1​、 s 2 s_2 s2​递推出来的,所以 L L L对 U U U、 W W W求偏导的公式需要把前面的 s 1 s_1 s1​、 s 2 s_2 s2​带入进去: s 3 = σ ( U x 3 + W s 2 + b 1 ) = σ ( U x 3 + W ( σ ( U x 2 + W s 1 + b 1 ) ) + b 1 ) s_3=\sigma(Ux_3+Ws_{2}+b_1)\\ =\sigma(Ux_3+W(\sigma(Ux_2+Ws_{1}+b_1))+b_1) s3​=σ(Ux3​+Ws2​+b1​)=σ(Ux3​+W(σ(Ux2​+Ws1​+b1​))+b1​)由此能知道,时间序列越长,出现连乘的部分会越集中在后面。也就是通过时间的反向传播。

三、RNN和普通神经网络梯度消失的本质区别

普通神经网络:它不是按时间步进行反向传播的,因此不会有一项一项相加的部分,只有一个总体的连乘求偏导过程。它的梯度消失是总的梯度会趋于0的。 RNN:每一项一项进行相加,可以发现距离拉的越长,连乘的项就越多,远距离的梯度会趋于0的,近距离的梯度不会消失。RNN梯度消失的真正含义是总的梯度受近距离梯度的主导,远距离的梯度消失。

四、RNN梯度消失梯度爆炸及解决方式

梯度爆炸:采用梯度截断的方式 梯度消失:1、采用跨时域的残差连接 。 2、采用门控机制(LSTM、GRU)作为RNN基本单元,控制信息流入量



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3