LSTM网络
LSTM网络
什么是LSTM?
LSTM(Long Short-Term Memory,长短期记忆网络)是一种特殊的循环神经网络(RNN)结构,由Hochreiter和Schmidhuber在1997年提出,专门设计用来解决传统RNN中的长期依赖问题。简单来说,它能够”记住”长序列中的重要信息,并”遗忘”不重要的信息,从而在处理长序列数据时表现出色。
LSTM的核心思想
传统RNN在处理长序列时会遇到梯度消失或梯度爆炸问题,导致无法有效学习长距离依赖关系。LSTM通过引入精心设计的内部结构解决了这一问题。
想象一下,当你阅读一本小说时,你的大脑会选择性地记住重要情节,而对一些细枝末节则可能会逐渐淡忘。LSTM网络的工作原理与此类似,它通过精心设计的”门”结构来控制信息的流动:
遗忘门(Forget Gate):决定哪些旧信息需要丢弃。这个门会输出一个0到1之间的值,0表示”完全遗忘”,1表示”完全保留”。
输入门(Input Gate):决定哪些新信息需要存储到细胞状态中。由两部分组成:sigmoid层决定更新哪些值,tanh层创建新的候选值。
细胞状态(Cell State):贯穿整个网络的记忆通道,允许信息在序列中长期传递。
输出门(Output Gate):决定哪些信息需要输出。根据细胞状态过滤信息,确定下一个隐藏状态。
结构图如下所示。
工作流程
遗忘阶段:LSTM首先决定从上一状态中丢弃哪些信息。遗忘门接收前一时刻的隐藏状态和当前输入,输出一个0到1之间的向量。
记忆阶段:接着,LSTM决定要在当前细胞状态中存储哪些新信息。这包括两个步骤:
输入门决定更新哪些值
创建新的候选值,可能会被添加到细胞状态中
更新阶段:旧的细胞状态与遗忘门相乘,然后加上输入门与候选值的乘积,完成细胞状态的更新。
输出阶段:最后,LSTM决定输出什么。输出基于细胞状态,但会经过过滤。先运行sigmoid层确定细胞状态的哪些部分会输出,然后将细胞状态通过tanh处理并与sigmoid输出相乘。
LSTM的优势
与传统RNN相比,LSTM具有以下优点:
- 能够学习长期依赖关系,避免梯度消失问题
- 可以有效处理时序数据中的长距离关系
- 对噪声有较强的鲁棒性
应用场景
LSTM广泛应用于:
- 自然语言处理(文本生成、机器翻译)
- 语音识别
- 时间序列预测
- 异常检测
简化的LSTM公式
LSTM的核心计算可以简化为以下几个步骤:
遗忘门:$$ f_t = σ(W_f·[h_{t-1}, x_t] + b_f) $$
输入门:$$ i_t = σ(W_i·[h_{t-1}, x_t] + b_i) $$
候选记忆:$$ \tilde{C}t = tanh(W_C·[h{t-1}, x_t] + b_C) $$
记忆更新:$$ C_t = f_t * C_{t-1} + i_t * C̃_t $$
输出门:$$ o_t = σ(W_o·[h_{t-1}, x_t] + b_o) $$
隐藏状态:$$ h_t = o_t * tanh(C_t) $$
σ是sigmoid函数,将值压缩到0-1之间
$tanh$是双曲正切函数,将值压缩到-1到1之间
表示元素乘法(Hadamard积)
$W$和$b$是可学习的权重和偏置参数
$h_{t-1}$是前一时间步的隐藏状态
$x_t$是当前时间步的输入
$C_{t-1}$是前一时间步的细胞状态