首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >使用tf.scatter_nd使Keras 'None‘批处理大小保持不变

使用tf.scatter_nd使Keras 'None‘批处理大小保持不变
EN

Stack Overflow用户
提问于 2020-10-04 01:28:54
回答 1查看 157关注 0票数 2

我需要向LSTM解码器输入一个池模块,我正在使用一个自定义层来构造它,其中编码器LSTM状态和Keras输入层作为输入。在这个自定义层中,我需要将更新分散到索引中:

代码语言:javascript
运行
AI代码解释
复制
updates: <tf.Tensor --- shape=(None, 225, 5, 32) dtype=float32>
indices: <tf.Tensor --- shape=(None, 225) dtype=int32>

使用tf.scatter_nd创建shape=张量(None,960,5,32),如下所示:

代码语言:javascript
运行
AI代码解释
复制
tf.scatter_nd(tf.expand_dims(indices, 2), updates, shape=[None, 960, 5, 32])

但问题是,这样做会产生错误,这是由于NoneType的形状,我不想在其中声明batch_size,因为它是一个Keras层,只有在学习过程中才是确定的。在这种情况下,代码的工作版本如下:

代码语言:javascript
运行
AI代码解释
复制
tf.scatter_nd(tf.expand_dims(indices, 2), updates, shape=[960, 5, 32])
        >>> <tf.Tensor 'ScatterNd_4:0' shape=(960, 5, 32) dtype=float32>

它忽略了输出中的batch_size。是否有任何替代方法来构造所需的输出张量而不是tf.scatter_nd,或者使其正常工作?

EN

回答 1

Stack Overflow用户

发布于 2022-11-22 00:20:43

我在tf.scatter_nd手术方面也有类似的问题。我通过使用tf.shape(input)[0]在运行时推断批大小来解决这个问题。因此,在您的示例中,以下代码应该可以工作:

代码语言:javascript
运行
AI代码解释
复制
bs = tf.shape(indices)[0]
tf.scatter_nd(tf.expand_dims(indices, 2), updates, shape=[bs, 960, 5, 32])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64193001

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文