我有一个多分类任务,我得到了n-hot类型的预测,比如
n_hot_prediction = [[0, 1, 1],
[0, 1, 0],
[1, 0, 1]]
和另一个top_k数组,如
top_k_prediction = [[1, 2],
[0, 1],
[0, 1]]
首先,我希望得到一个函数,它的工作原理如下:
tf.function1(n_hot_prediction) #output: [[1, 2], [1], [0, 2]]
其次,我找到了另一个函数,它的工作原理如下:
tf.function2(top_k_prediction) #output: [[0, 1, 1], [1, 1, 0], [1, 1, 0]]
有没有像tf.function1和tf.function2一样工作的函数或方法?
发布于 2020-07-01 12:10:49
您的第二个函数实现起来非常简单:
import tensorflow as tf
@tf.function
def multi_hot(x, depth=None):
x = tf.convert_to_tensor(x)
if depth is None:
depth = tf.math.reduce_max(x) + 1
r = tf.range(tf.dtypes.cast(depth, x.dtype))
eq = tf.equal(tf.expand_dims(x, axis=-1), r)
return tf.cast(tf.reduce_any(eq, axis=-2), x.dtype)
x = [[1, 2], [0, 1], [0, 1]]
tf.print(multi_hot(x))
# [[0 1 1]
# [1 1 0]
# [1 1 0]]
对于第一个,结果不是一个适当的张量,所以你可以改为制作一个粗糙的张量,屏蔽连续值的张量:
import tensorflow as tf
@tf.function
def as_labels(x):
mask = tf.dtypes.cast(x, tf.bool)
s = tf.shape(mask)
r = tf.reshape(tf.range(s[-1]), tf.concat([tf.ones(tf.rank(x) - 1, tf.int32), [-1]], axis=0))
r = tf.tile(r, tf.concat([s[:-1], [1]], axis=0))
return tf.ragged.boolean_mask(r, mask)
x = [[0, 1, 1], [0, 1, 0], [1, 0, 1]]
print(as_labels(x).to_list())
# [[1, 2], [1], [0, 2]]
https://stackoverflow.com/questions/62675441
复制相似问题