长短期记忆网络(Long Short-Term Memory)作为RNN的进阶架构,在序列建模领域具有里程碑意义。其核心突破在于通过智能门控系统,有效捕获跨越数百个时间步的语义关联,成功缓解了传统RNN存在的梯度消失/爆炸难题,在语音识别、金融预测等需要长程记忆的场景中表现卓越。
结构更复杂,核心结构可分四部分:
结构解释图:
类似传统RNN内部结构计算:
动态决定历史信息的保留比例,通过sigmoid函数输出0-1之间的遗忘系数。实际应用场景如在语言模型中自动遗忘不相关的主语信息。
帮助调节流经网络的值,sigmoid函数将值压缩在0和1之间。
输入门的计算公式有两个:
可实现新信息的选择性记忆,如在股票预测中精准捕捉突发市场信号。
没有全连接层,只是将刚得的遗忘门门值与上一个时间步得到的C(t-1)相乘,再加上输入门门值与当前时间步得到的未更新C(t)相乘。最终得到更新后的C(t)作为下一个时间步输入的一部分。
整个细胞状态更新过程就是对遗忘门和输入门的应用。可构建动态记忆高速公路,如医疗诊断场景中持续更新患者病史特征。
公式也两个,双阶段处理:
可智能生成当前时刻的特征表达,如在机器翻译中精准输出目标语言词汇。
双向LSTM,未改变LSTM本身任何的内部结构,只是将LSTM应用两次且方向不同,再将两次得到的LSTM结果进行拼接作为最终输出。
图中对"我爱中国"这句话或叫这个输入序列,进行从左到右、从右到左两次LSTM处理,将得到的结果张量拼接作为最终输出。
这种结构能捕捉语言语法中一些特定的前置或后置特征,增强语义关联,但模型参数和计算复杂度也随之增加一倍,一般需对语料和计算资源进行评估后,决定是否使用该结构。
可使用于:
如医疗文本分析中同时考虑症状描述和诊断结果。
特性 | 单向LSTM | 双向LSTM |
---|---|---|
参数数量 | 1x | 2x |
上下文感知 | 前向 | 全向 |
计算效率 | 高 | 中等 |
Pytorch中LSTM工具在torch.nn包,通过torch.nn.LSTM可调用。
# 定义LSTM的参数含义:
# (input_size,
# hidden_size,隐层维度
# num_layers) 堆叠3个LSTM层
# 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
# 定义隐藏层初始张量和细胞初始状态张量的参数含义:
# (num_layers * num_directions, batch_size, hidden_size)
>>> import torch.nn as nn
>>> import torch
# 构建深度双向LSTM
>>> rnn = nn.LSTM(5, 6, 2)
# 三维输入:(序列长度,批大小,特征维度)
>>> input = torch.randn(1, 3, 5)
# 初始化记忆系统
>>> h0 = torch.randn(2, 3, 6) # (层数*方向数,批大小,隐层维度)
>>> c0 = torch.randn(2, 3, 6)
# 前向计算
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],
[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],
[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.4647, -0.2364, 0.0645, -0.3996, -0.0500, -0.0152],
[ 0.3852, 0.0704, 0.2103, -0.2524, 0.0243, 0.0477],
[ 0.2571, 0.0608, 0.2322, 0.1815, -0.0513, -0.0291]],
[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],
[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],
[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.8083, -0.5500, 0.1009, -0.5806, -0.0668, -0.1161],
[ 0.7438, 0.0957, 0.5509, -0.7725, 0.0824, 0.0626],
[ 0.3131, 0.0920, 0.8359, 0.9187, -0.4826, -0.0717]],
[[ 0.1240, -0.0526, 0.3035, 0.1099, 0.5915, 0.0828],
[ 0.0203, 0.8367, 0.9832, -0.4454, 0.3917, -0.1983],
[-0.2976, 0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],
grad_fn=<StackBackward>)
门结构有效减缓长序列问题的梯度消失或爆炸,虽不能杜绝,但在更长的序列问题上表现优于传统RNN。
内部结构复杂,训练效率在同等算力下比传统RNN低很多。
通过深入理解LSTM的门控哲学,可构建更智能的时序模型,抢占AI应用先机。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。