首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在tensorflow中将n热向量转换为多标签?

如何在tensorflow中将n热向量转换为多标签?
EN

Stack Overflow用户
提问于 2020-07-01 19:06:12
回答 1查看 61关注 0票数 1

我有一个多分类任务,我得到了n-hot类型的预测,比如

代码语言:javascript
运行
复制
n_hot_prediction = [[0, 1, 1],
                    [0, 1, 0],
                    [1, 0, 1]]

和另一个top_k数组,如

代码语言:javascript
运行
复制
top_k_prediction = [[1, 2],
                    [0, 1],
                    [0, 1]]

首先,我希望得到一个函数,它的工作原理如下:

代码语言:javascript
运行
复制
tf.function1(n_hot_prediction)  #output: [[1, 2], [1], [0, 2]]

其次,我找到了另一个函数,它的工作原理如下:

代码语言:javascript
运行
复制
tf.function2(top_k_prediction) #output: [[0, 1, 1], [1, 1, 0], [1, 1, 0]]

有没有像tf.function1和tf.function2一样工作的函数或方法?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-07-01 20:10:49

您的第二个函数实现起来非常简单:

代码语言:javascript
运行
复制
import tensorflow as tf

@tf.function
def multi_hot(x, depth=None):
    x = tf.convert_to_tensor(x)
    if depth is None:
        depth = tf.math.reduce_max(x) + 1
    r = tf.range(tf.dtypes.cast(depth, x.dtype))
    eq = tf.equal(tf.expand_dims(x, axis=-1), r)
    return tf.cast(tf.reduce_any(eq, axis=-2), x.dtype)

x = [[1, 2], [0, 1], [0, 1]]
tf.print(multi_hot(x))
# [[0 1 1]
#  [1 1 0]
#  [1 1 0]]

对于第一个,结果不是一个适当的张量,所以你可以改为制作一个粗糙的张量,屏蔽连续值的张量:

代码语言:javascript
运行
复制
import tensorflow as tf

@tf.function
def as_labels(x):
    mask = tf.dtypes.cast(x, tf.bool)
    s = tf.shape(mask)
    r = tf.reshape(tf.range(s[-1]), tf.concat([tf.ones(tf.rank(x) - 1, tf.int32), [-1]], axis=0))
    r = tf.tile(r, tf.concat([s[:-1], [1]], axis=0))
    return tf.ragged.boolean_mask(r, mask)

x = [[0, 1, 1], [0, 1, 0], [1, 0, 1]]
print(as_labels(x).to_list())
# [[1, 2], [1], [0, 2]]
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62675441

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档