为什么RNNs/LSTM/GRUs的隐藏状态通常只在一个时代结束后才重新初始化,而不是在批处理完成后重新初始化?
发布于 2023-01-25 07:06:02
只有在使用截断反向传播(TBPTT)的情况下,RNN的隐藏状态才会被初始化。如果您没有使用TBPTT,那么您将在每个批处理中重置RNN隐藏状态。
在TBPTT训练方法中,梯度通过LSTM的隐藏状态在批处理的时间维上传播,然后在下一批中使用最后的隐藏状态作为LSTM的输入状态。这允许LSTM在训练时使用更长的上下文,同时限制梯度计算的后退步骤数。
当您将隐藏状态从一个批处理重用到下面的时候,您不希望每批重新设置它们。
通常,深度学习框架提供某种标志来具有“有状态”的隐藏状态,如果数据准备逻辑支持TBPTT,则以这种方式启用TBPTT (您需要确保连续批实际包含“后续数据”,这样TBPTT才有意义)。
我知道使用TBPTT是常见的两种情况:
训练集是序列列表,可能来自几个文档(LM)或完整的时间序列。在数据准备过程中,将创建批处理,以便批处理中的每个序列都是序列在前一批处理中相同位置上的延续。这允许在计算预测时具有文档级/长时间序列上下文。
在这些情况下,您的数据比批处理中的序列长度维度长。这可能是由于可用GPU内存中的限制(因此限制了最大批处理大小),或者是由于任何其他原因而设计的。
注意:我重用了我的另一个答案的一些部分,这些部分解释了Keras中stateful = True
标志的含义。
https://datascience.stackexchange.com/questions/118030
复制相似问题