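The decision trees in our current project are built with the conditional inference tree (ctree) algorithm. For a binary ctree, I can extract the split variables (the factor levels sent to each branch) with the code below: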
library(party)  # provides ctree(), ctree_control() and the BinaryTree class

# Fit the ctree decision tree
prod_discount_data_ctree <- ctree(Discount ~ Prod, data = prod_discount_data,
                                  controls = ctree_control(minsplit = 30))
plot(prod_discount_data_ctree)

# Levels of the split factor at the root node
lvls <- levels(prod_discount_data_ctree@tree$psplit$splitpoint)
# Levels sent to the left leaf node (coded 1)
left.df <- lvls[prod_discount_data_ctree@tree$psplit$splitpoint == 1]
# Levels sent to the right leaf node (coded 0)
right.df <- lvls[prod_discount_data_ctree@tree$psplit$splitpoint == 0]
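This works well when the tree has a single inner node (depth = 1) that splits into 2 leaf nodes. But if the root (node 1) splits into further inner nodes (nodes 2 and 5), which in turn split into leaf nodes (node 2 into {3, 4}, node 5 into {6, 7}), how do I traverse deeper and get the split variables of the leaf nodes? For this example, I would like the split variables for nodes 3, 4, 6 and 7 as 4 lists.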
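Posted on 2015-04-09 03:14:42
I tried every option I could think of and finally found a way to traverse the ctree and get the split variables for each leaf node. I am pasting the code snippet here in case anyone wants to refer to it in the future.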
# Assumes a tree of depth 2 in which both children of the root split again
if (nrow(SubBrandright_total) > 200) {
  sec_discount_data <- subset(SubBrandright_total, select = c(Discount, Sector))
  sec_discount_data_ctree <- ctree(Discount ~ Sector, data = sec_discount_data,
                                   controls = ctree_control(minsplit = 30))
  sec_lvls_r <- levels(sec_discount_data_ctree@tree$psplit$splitpoint)

  # Testing if the node is terminal [TRUE] or not [FALSE]
  # print(sec_discount_data_ctree@tree$terminal)
  # print(sec_discount_data_ctree@tree$left$terminal)
  # print(sec_discount_data_ctree@tree$left$left$terminal)
  # print(sec_discount_data_ctree@tree$left$right$terminal)

  # Left branch: levels reaching node 2, then its left leaf.
  # intersect() keeps only levels that actually reached the parent node.
  sec_left.df <- sec_lvls_r[sec_discount_data_ctree@tree$psplit$splitpoint == 1]
  sec_left_left.df <- intersect(
    sec_left.df,
    sec_lvls_r[sec_discount_data_ctree@tree$left$psplit$splitpoint == 1])
  # Using setdiff to get the right leaf node: parent node minus left leaf node
  sec_left_right.df <- setdiff(sec_left.df, sec_left_left.df)
  print("Sector Segmentation")
  print(sec_left_left.df)
  print(sec_left_right.df)

  # Right branch: levels reaching node 5, then its right leaf
  sec_right.df <- sec_lvls_r[sec_discount_data_ctree@tree$psplit$splitpoint == 0]
  sec_right_right.df <- intersect(
    sec_right.df,
    sec_lvls_r[sec_discount_data_ctree@tree$right$psplit$splitpoint == 0])
  # Using setdiff to get the left leaf node: parent node minus right leaf node
  sec_right_left.df <- setdiff(sec_right.df, sec_right_right.df)
  print(sec_right_left.df)
  print(sec_right_right.df)
}
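The snippet above hard-codes a tree of depth 2. The same idea can be written recursively so that it works for any depth. The following is only a minimal sketch, not part of the original answer: `leaf_levels` is a hypothetical helper, it assumes a single nominal split variable (as in `Discount ~ Sector`), and it assumes each node is a list with the fields `terminal`, `left`, `right` and `psplit$splitpoint` used above, plus a `nodeID` field. To keep it runnable without data, it is exercised on a hand-built list that mimics that structure rather than on a fitted tree.

```r
# Hypothetical helper (not part of party): for every terminal node, collect
# the factor levels that reach it. `reach` holds the levels arriving at `node`.
leaf_levels <- function(node, reach = NULL) {
  if (isTRUE(node$terminal)) {
    out <- list(reach)
    names(out) <- as.character(node$nodeID)  # nodeID field is an assumption
    return(out)
  }
  sp <- node$psplit$splitpoint
  lv <- levels(sp)                    # level names stored as an attribute
  if (is.null(reach)) reach <- lv     # at the root every level is present
  go_left  <- intersect(reach, lv[sp == 1])  # coded 1 -> left child
  go_right <- setdiff(reach, go_left)        # remaining levels -> right child
  c(leaf_levels(node$left, go_left), leaf_levels(node$right, go_right))
}

# Hand-built stand-in for the tree in the question:
# node 1 -> {2, 5}, node 2 -> {3, 4}, node 5 -> {6, 7}
lv <- c("A", "B", "C", "D")
leaf <- function(id) list(nodeID = id, terminal = TRUE)
inner <- function(id, codes, left, right) {
  list(nodeID = id, terminal = FALSE,
       psplit = list(splitpoint = structure(codes, levels = lv)),
       left = left, right = right)
}
mock_tree <- inner(1, c(1, 1, 0, 0),
                   inner(2, c(1, 0, 0, 0), leaf(3), leaf(4)),
                   inner(5, c(0, 0, 1, 0), leaf(6), leaf(7)))

res <- leaf_levels(mock_tree)
print(res)
```

If the assumptions hold for a fitted tree, the call would be `leaf_levels(sec_discount_data_ctree@tree)`, returning one list element per leaf node, named by node ID.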
https://stackoverflow.com/questions/29525208