首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >基于LCPN方法的分层分类

基于LCPN方法的分层分类
EN

Stack Overflow用户
提问于 2022-10-04 01:22:18
回答 1查看 73关注 0票数 0

目标:

我正在研究一个层次分类问题,并希望使用本地父节点分类器(使用Tensorflow的LCPN方法)来解决它。为了做到这一点,我必须创建基于层次数据集的本地分类器。

例如,,我手动为CIFAR-10数据集创建了一个分层树结构,该数据集遵循这个。等级结构如下:

基于这个结构,它需要一个,总共需要6个局部分类器

  1. 1级分类器:
    • 用于分类运输和动物。

  2. 2级分类器:
    • 1用于分类的类别-天空、水、道路(类别运输的子类)
    • 1用于分类鸟类、爬行动物、每种、中等类(动物亚纲)

  3. 3级分类器:
    • 1分类汽车和卡车(道路的子类)
    • 1用于分类猫和狗(宠物类的子类)
    • 1用于分类鹿和马(中等类的子类)

注意:,我想得到第3级(10个类)的所有预测。如果级别1的分类器输出级别2的类,该类在第3级中没有多个子类,则应该为该示例自动分配级别3中的相应类。例如:如果第一个分类器识别一个样本为传输,那么它将选择分类器来分类运输的子类(天空、水、道路)。如果二级分类器将该样本分类为天空,那么将不再需要另一个分类器来对子类进行分类,因为它只有一个子类,即类飞机。但是对于我的实现,我希望最终的预测是三级预测,输出是飞机。

执行情况:

为了实现这一点,到目前为止,我做了以下工作:

  1. 我已经使用treelib从数据集中确定了本地分类器的数量和类的数量。它决定了本地分类器所需的输出数。
  2. 我正在使用tf.data.Dataset.filter生成一个dataset管道,它将提供一个过滤的数据集来训练模型。因为我要用相关的样本来训练本地分类器。例如,用于确定级别1类传输的子类的分类器将使用级别1类传输下的所有类的样本进行培训。所以,我想过滤掉属于动物类或动物的任何亚类的样本。
  3. 在此之后,我必须实现一个决策树来从模型中进行预测。

现在,我正在为使用这种方法的实现而奋斗。对这种问题有什么更好的解决办法吗?或者有别的办法吗?

EN

回答 1

Stack Overflow用户

发布于 2022-10-04 05:06:37

当您在一个输入中创建一个包含10个样本的初始信息时,您可以这样做,并在每个层中捕获结果。

你知道分类问题,你可以通过找到最大值或最小值、范围或临界值来进行分类。

为什么笔记本在WiFi上选择局域网接口?速度和优先级,甚至其中一些比LAN更快。

样本:每次你吃赫拉奇的时候都要打印出来。

代码语言:javascript
复制
import tensorflow as tf

class MyDenseLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs):
        super(MyDenseLayer, self).__init__()
        self.num_outputs = num_outputs
        
    def build(self, input_shape):
        self.kernel = self.add_weight("kernel",
        shape=[int(input_shape[-1]),
        self.num_outputs])

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)


start = 3
limit = 33
delta = 3

# Create DATA
sample = tf.range(start, limit, delta)
sample = tf.cast( sample, dtype=tf.float32 )

# Initail, ( 10, 1 )
sample = tf.constant( sample, shape=( 10, 1 ) )
layer = MyDenseLayer(10)
data = layer(sample)

# Layer 1, ( 10, 2 )
layer = MyDenseLayer(2)
data = layer(data)

# Layer 2, ( 10, 7 )
layer = MyDenseLayer(7)
data = layer(data)

# Layer 3, ( 10, 10 )
layer = MyDenseLayer(10)
data = layer(data)

print( data )

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73942276

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档