前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何计算 LSTM 的参数量

如何计算 LSTM 的参数量

作者头像
Alan Lee
发布2020-10-29 10:34:47
2.5K0
发布2020-10-29 10:34:47
举报
文章被收录于专栏:Small Code

理论上的参数量

之前翻译了 Christopher Olah 的那篇著名的 Understanding LSTM Networks,这篇文章对于整体理解 LSTM 很有帮助,但是在理解 LSTM 的参数数量这种细节方面,略有不足。本文就来补充一下,讲讲如何计算 LSTM 的参数数量。

建议阅读本文前先阅读 Understanding LSTM Networks 的原文或我的译文

首先来回顾下 LSTM。一层 LSTM 如下:

这里的xt​ 实际上是一个句子的 embedding(不考虑 batch 维度),shape 一般为 [seq_length, embedding_size]。图中的A 就是 cell,xt​ 中的词依次进入这个 cell 中进行处理。可以看到其实只有这么一个 cell,所以每次词进去处理的时候,权重是共享的,将这个过程平铺展开,就是下面这张图了:

实际上我觉得这里 x t x_t xt​ 并不准确,第一个 x t x_t xt​ 应该指的是整句话,而第二个 x t x_t xt​ 应该指的是这句话中最后一个词,所以为了避免歧义,我认为可以将第一个 x t x_t xt​ 重命名为 x x x,第二个仍然保留,即现在 x x x 表示一句话,该句话有 t + 1 t+1 t+1 个词, x t x_t xt​ 表示该句话的第 t + 1 t+1 t+1 个词, t ∈ [ 0 , t ] t \in [0, t] t∈[0,t]。

一个不那么小的数被多次相乘之后会变得很小,一个不那么大的数被多次相乘之后会变得很大。所以,这也是普通 RNN 容易出现梯度消失/爆炸的问题的原因

扯远了点。

代码语言:javascript
复制
(embedding_size + hidden_size) * hidden_size + hidden_size

一个 cell 有 4 个这样结构相同的网络,那么一个 cell 的总参数量就是直接 × 4:

代码语言:javascript
复制
((embedding_size + hidden_size) * hidden_size + hidden_size) * 4

注意这 4 个权重可不是共享的,都是独立的网络。

代码语言:javascript
复制
import tensorflow as tf

model = tf.keras.model.Sequential(
    tf.keras.layers.Embedding(1000, 128),
    tf.keras.layers.LSTM(units=64),
    tf.keras.layers.Dense(10)
)
model.summary()

输入如下:

代码语言:javascript
复制
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_1 (Embedding)      (None, None, 128)         128000    
_________________________________________________________________
lstm_1 (LSTM)                (None, 64)                49408     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 178,058
Trainable params: 178,058
Non-trainable params: 0
_________________________________________________________________

代码语言:javascript
复制
inputs = tf.random.normal([64, 100, 128])  # [batch_size, seq_length, embedding_size]
whole_seq_output, final_memory_state, final_carry_state = tf.keras.layers.LSTM(64, return_sequences=True, return_state=True)(inputs)
print(f"{whole_seq_output.shape=}")
print(f"{final_memory_state.shape=}")
print(f"{final_carry_state.shape=}")

输出:

代码语言:javascript
复制
whole_seq_output.shape=TensorShape([32, 100, 64])  # 100 表示有 100 个词,即 100 个 time step
final_memory_state.shape=TensorShape([32, 64])
final_carry_state.shape=TensorShape([32, 64])

OK,LSTM 的参数量应该挺清晰了,欢迎在评论区留下你的想法。

Reference

END

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020/10/24 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 理论上的参数量
  • Reference
  • END
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档