前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >专栏 >LSTM:克服长期依赖难题的循环神经网络升级方案

LSTM:克服长期依赖难题的循环神经网络升级方案

原创
作者头像
JavaEdge
修改2025-03-20 10:01:38
修改2025-03-20 10:01:38
14400
代码可运行
举报
文章被收录于专栏:AI理论与前沿AI理论与前沿
运行总次数:0
代码可运行

1 简介

长短期记忆网络(Long Short-Term Memory)作为RNN的进阶架构,在序列建模领域具有里程碑意义。其核心突破在于通过智能门控系统,有效捕获跨越数百个时间步的语义关联,成功缓解了传统RNN存在的梯度消失/爆炸难题,在语音识别、金融预测等需要长程记忆的场景中表现卓越。

结构更复杂,核心结构可分四部分:

2 LSTM内部结构图

结构解释图:

2.1 遗忘门:智能记忆过滤器

结构图和计算公式
结构分析

类似传统RNN内部结构计算:

  • 先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接,得到x(t), h(t-1)
  • 再通过一个全连接层做变换,最后通过sigmoid函数进行激活得到f(t),可将f(t)看作门值,好比一扇门开合的大小程度,门值都将作用在通过该扇门的张量。遗忘门门值将作用的上一层的细胞状态上,代表遗忘过去的多少信息,又因为遗忘门门值是由x(t), h(t-1)计算得来,因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息

动态决定历史信息的保留比例,通过sigmoid函数输出0-1之间的遗忘系数。实际应用场景如在语言模型中自动遗忘不相关的主语信息。

过程演示
激活函数sigmiod

帮助调节流经网络的值,sigmoid函数将值压缩在0和1之间。

2.2 输入门:新知融合系统

结构图与计算公式
结构分析

输入门的计算公式有两个:

  • 产生输入门门值的公式, 它和遗忘门公式几乎相同, 区别只是在于它们之后要作用的目标上. 这个公式意味着输入信息有多少需要进行过滤
  • 与传统RNN的内部结构计算相同. 对于LSTM来讲, 它得到的是当前的细胞状态, 而不是像经典RNN一样得到的是隐含状态

可实现新信息的选择性记忆,如在股票预测中精准捕捉突发市场信号。

过程演示

2.3 细胞状态更新

结构图和计算公式
结构分析

没有全连接层,只是将刚得的遗忘门门值与上一个时间步得到的C(t-1)相乘,再加上输入门门值与当前时间步得到的未更新C(t)相乘。最终得到更新后的C(t)作为下一个时间步输入的一部分。

整个细胞状态更新过程就是对遗忘门和输入门的应用。可构建动态记忆高速公路,如医疗诊断场景中持续更新患者病史特征。

过程演示

2.4 输出门:信息蒸馏器

结构图和计算公式
结构分析

公式也两个,双阶段处理:

  • 计算输出门的门值,同遗忘门、输入门计算方式
  • 用这个门值产生隐含状态h(t),作用在更新后的细胞状态C(t)上,并做tanh激活,最终得到h(t)作为下一时间步输入的一部分。整个输出门的过程,就是为产生隐含状态h(t)

可智能生成当前时刻的特征表达,如在机器翻译中精准输出目标语言词汇。

过程演示

3 Bi-LSTM

双向LSTM,未改变LSTM本身任何的内部结构,只是将LSTM应用两次且方向不同,再将两次得到的LSTM结果进行拼接作为最终输出。

3.1 结构分析

图中对"我爱中国"这句话或叫这个输入序列,进行从左到右、从右到左两次LSTM处理,将得到的结果张量拼接作为最终输出。

这种结构能捕捉语言语法中一些特定的前置或后置特征,增强语义关联,但模型参数和计算复杂度也随之增加一倍,一般需对语料和计算资源进行评估后,决定是否使用该结构。

可使用于:

  • 正向LSTM捕捉历史依赖
  • 反向LSTM捕获未来特征

如医疗文本分析中同时考虑症状描述和诊断结果。

3.2 单向LSTM V.S 双向LSTM

特性

单向LSTM

双向LSTM

参数数量

1x

2x

上下文感知

前向

全向

计算效率

中等

4 工程实践

Pytorch中LSTM工具在torch.nn包,通过torch.nn.LSTM可调用。

4.1 nn.LSTM类初始化参数

  • input_size: 输入张量x中特征维度的大小
  • hidden_size: 隐层张量h中特征维度的大小
  • num_layers: 隐含层的数量
  • bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用

4.2 nn.LSTM类实例化对象参数

  • input: 输入张量x
  • h0: 初始化的隐层张量h
  • c0: 初始化的细胞状态张量c

4.3 nn.LSTM使用示例

代码语言:python
代码运行次数:0
运行
复制
# 定义LSTM的参数含义: 
# (input_size,
# 	hidden_size,隐层维度
# 		num_layers) 堆叠3个LSTM层

# 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
# 定义隐藏层初始张量和细胞初始状态张量的参数含义:
# (num_layers * num_directions, batch_size, hidden_size)

>>> import torch.nn as nn
>>> import torch
# 构建深度双向LSTM
>>> rnn = nn.LSTM(5, 6, 2)
# 三维输入:(序列长度,批大小,特征维度)
>>> input = torch.randn(1, 3, 5)

# 初始化记忆系统
>>> h0 = torch.randn(2, 3, 6) # (层数*方向数,批大小,隐层维度)
>>> c0 = torch.randn(2, 3, 6)

# 前向计算
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],
         [ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],
         [-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
       grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.4647, -0.2364,  0.0645, -0.3996, -0.0500, -0.0152],
         [ 0.3852,  0.0704,  0.2103, -0.2524,  0.0243,  0.0477],
         [ 0.2571,  0.0608,  0.2322,  0.1815, -0.0513, -0.0291]],

        [[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],
         [ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],
         [-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
       grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.8083, -0.5500,  0.1009, -0.5806, -0.0668, -0.1161],
         [ 0.7438,  0.0957,  0.5509, -0.7725,  0.0824,  0.0626],
         [ 0.3131,  0.0920,  0.8359,  0.9187, -0.4826, -0.0717]],

        [[ 0.1240, -0.0526,  0.3035,  0.1099,  0.5915,  0.0828],
         [ 0.0203,  0.8367,  0.9832, -0.4454,  0.3917, -0.1983],
         [-0.2976,  0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],
       grad_fn=<StackBackward>)

4.4 参数调优要点

  • 隐层维度:一般设置为输入维度2-4倍
  • 深度堆叠:3-5层可获得较好收益
  • 双向选择:根据任务上下文需求决定

5 LSTM评价

5.1 优势

门结构有效减缓长序列问题的梯度消失或爆炸,虽不能杜绝,但在更长的序列问题上表现优于传统RNN。

5.2 缺点

内部结构复杂,训练效率在同等算力下比传统RNN低很多。

6 新趋势展望

通过深入理解LSTM的门控哲学,可构建更智能的时序模型,抢占AI应用先机。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1 简介
  • 2 LSTM内部结构图
    • 2.1 遗忘门:智能记忆过滤器
      • 结构图和计算公式
      • 结构分析
      • 过程演示
      • 激活函数sigmiod
    • 2.2 输入门:新知融合系统
      • 结构图与计算公式
      • 结构分析
      • 过程演示
    • 2.3 细胞状态更新
      • 结构图和计算公式
      • 结构分析
      • 过程演示
    • 2.4 输出门:信息蒸馏器
      • 结构图和计算公式
      • 结构分析
      • 过程演示
  • 3 Bi-LSTM
    • 3.1 结构分析
    • 3.2 单向LSTM V.S 双向LSTM
  • 4 工程实践
    • 4.1 nn.LSTM类初始化参数
    • 4.2 nn.LSTM类实例化对象参数
    • 4.3 nn.LSTM使用示例
    • 4.4 参数调优要点
  • 5 LSTM评价
    • 5.1 优势
    • 5.2 缺点
  • 6 新趋势展望
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档