从一个Tensor
中获取多个相同大小的切片可以使用tf.split
函数。tf.split
函数可以将一个Tensor
沿着指定的维度切分成多个子张量。
函数原型如下:
tf.split(value, num_or_size_splits, axis=0, num=None, name='split')
参数解释:
value
:要切分的Tensor
。num_or_size_splits
:切分后的子张量数量或者每个子张量的大小。如果是一个整数,则表示切分后的子张量数量;如果是一个列表或元组,则表示每个子张量的大小。axis
:指定切分的维度。num
:切分后的子张量数量,与num_or_size_splits
参数作用相同,二者只需指定一个即可。name
:操作的名称。下面是一个示例代码:
import tensorflow as tf
# 创建一个形状为[6, 4]的Tensor
x = tf.constant([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]])
# 沿着第一个维度将Tensor切分成两个子张量
slices = tf.split(x, num_or_size_splits=2, axis=0)
# 打印切分后的子张量
for i, slice in enumerate(slices):
print("Slice", i+1, ":", slice)
输出结果:
Slice 1 : tf.Tensor(
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]], shape=(3, 4), dtype=int32)
Slice 2 : tf.Tensor(
[[13 14 15 16]
[17 18 19 20]
[21 22 23 24]], shape=(3, 4), dtype=int32)
在这个示例中,我们创建了一个形状为[6, 4]的Tensor
,然后使用tf.split
函数将其沿着第一个维度切分成两个子张量。最后,我们打印出切分后的子张量。
推荐的腾讯云相关产品和产品介绍链接地址:
领取专属 10元无门槛券
手把手带您无忧上云