在TensorFlow中,要获取索引的张量,可以使用tf.gather
,tf.gather_nd
或tf.boolean_mask
函数
tf.gather
:根据给定的索引从张量中收集值。索引必须在整数范围内。import tensorflow as tf
# 创建一个张量
tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 指定索引
indices = tf.constant([0, 2])
# 使用tf.gather获取索引对应的值
gathered_tensor = tf.gather(tensor, indices)
# 输出gathered_tensor
print(gathered_tensor.numpy())
输出:
[[1 3]
[4 6]
[7 9]]
tf.gather_nd
:根据多维索引从张量中收集值。import tensorflow as tf
# 创建一个张量
tensor = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
# 指定多维索引
indices = tf.constant([[0, 0], [1, 1]])
# 使用tf.gather_nd获取索引对应的值
gathered_tensor = tf.gather_nd(tensor, indices)
# 输出gathered_tensor
print(gathered_tensor.numpy())
输出:
[1 8]
tf.boolean_mask
:根据布尔掩码过滤张量中的值。import tensorflow as tf
# 创建一个张量
tensor = tf.constant([1, 2, 3, 4, 5, 6])
# 创建一个布尔掩码
mask = tf.constant([True, False, True, False, True, False])
# 使用tf.boolean_mask过滤张量
filtered_tensor = tf.boolean_mask(tensor, mask)
# 输出filtered_tensor
print(filtered_tensor.numpy())
输出:
[1 3 5]
这些函数都可以根据给定的索引从张量中获取对应的值。您可以根据自己的需求选择合适的函数。
领取专属 10元无门槛券
手把手带您无忧上云