I am running into a strange problem when using ctree inside a for loop. If I run the following code inside a loop, R freezes.
data = read.csv("train.csv") #data description https://www.kaggle.com/c/titanic-gettingStarted/data
treet = ctree(Survived ~ ., data = data)
print(plot(treet))
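Sometimes I get the message "More than 52 levels in a predicting factor, truncated for printout", and my tree is displayed in a very strange way. Sometimes it works just fine. Really, really strange!
My loop code: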
functionPlot <- function(traine, i) {
  print(i) # print only once, then RStudio freezes
  tempd <- ctree(Survived ~ ., data = traine)
  print(plot(tempd))
}
for(i in 1:2) {
  smp_size <- floor(0.70 * nrow(data))
  train_ind <- sample(seq_len(nrow(data)), size = smp_size)
  set.seed(100 + i)
  train <- data[train_ind, ]
  test <- data[-train_ind, ]
  #
  functionPlot(train,i)
}
Answered on 2015-04-21 02:38:51
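The ctree() function expects that (a) an appropriate class (numeric, factor, etc.) is used for each variable, and (b) only useful predictors are used in the model formula.
As for (b), you supply variables that are really just character strings (like Name) rather than factors. These either need to be preprocessed appropriately or omitted from the analysis.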
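A quick way to spot such variables is to look at each column's class and its number of distinct values. The sketch below uses a small made-up data frame (the `toy` object and its values are hypothetical, not the Kaggle data) and assumes strings were read as factors, which was the `read.csv()` default before R 4.0.0. A factor with roughly one level per row, like Name, is an identifier rather than a predictor, and is a likely source of the "more than 52 levels" message.

```r
# Toy stand-in for the Titanic training data (hypothetical values).
toy <- data.frame(
  Survived = c(0, 1, 1, 0),
  Pclass   = c(3, 1, 3, 2),
  Name     = c("Passenger A", "Passenger B", "Passenger C", "Passenger D"),
  Sex      = c("male", "female", "female", "male"),
  stringsAsFactors = TRUE  # mimics the read.csv() default before R 4.0.0
)

# First class of every column
col_class <- sapply(toy, function(x) class(x)[1])

# Number of distinct values per column
n_distinct <- sapply(toy, function(x) length(unique(x)))

print(data.frame(class = col_class, distinct = n_distinct))

# Factor or character columns with one distinct value per row look like identifiers
id_like <- names(toy)[col_class %in% c("factor", "character") &
                        n_distinct == nrow(toy)]
print(id_like)
```

On the real data the same check can be run on the data frame returned by `read.csv("train.csv")`; columns it flags are candidates for dropping from the formula.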
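Even after that, you won't get the best results, because some variables (like Survived and Pclass) are coded numerically but are really categorical variables that should be factors; this is point (a). If you look at the scripts from https://www.kaggle.com/c/titanic/forums/t/13390/introducing-kaggle-scripts you will also see how the data preparation can be carried out. Here, I use: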
titanic <- read.csv("train.csv")
titanic$Survived <- factor(titanic$Survived,
levels = 0:1, labels = c("no", "yes"))
titanic$Pclass <- factor(titanic$Pclass)
titanic$Name <- as.character(titanic$Name)
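One caveat for newer R versions: since R 4.0.0, `read.csv()` no longer converts strings to factors by default, so columns such as Sex and Embarked arrive as character vectors and may need an explicit conversion before modelling. A minimal sketch on made-up data (the `titanic_toy` object and its values are hypothetical; recoding empty strings as NA is an optional choice, not part of the original answer):

```r
# Toy stand-in for the character columns of train.csv (hypothetical values).
titanic_toy <- data.frame(
  Sex      = c("male", "female", "female", "male"),
  Embarked = c("S", "C", "", "Q"),
  stringsAsFactors = FALSE  # the read.csv() default since R 4.0.0
)

# Optional: treat empty strings as missing so "" does not become a factor level.
titanic_toy$Embarked[titanic_toy$Embarked == ""] <- NA

# Convert the categorical columns to factors.
titanic_toy$Sex      <- factor(titanic_toy$Sex)
titanic_toy$Embarked <- factor(titanic_toy$Embarked)

str(titanic_toy)
```

Alternatively, passing `stringsAsFactors = TRUE` to `read.csv()` restores the old behaviour for all string columns, at the cost of also turning Name into a factor.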
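As for (b), I then call ctree() with only those variables that have been sufficiently preprocessed for a meaningful analysis. (I use the newer, recommended implementation from the partykit package.)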
library("partykit")
ct <- ctree(Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked,
data = titanic)
plot(ct)
print(ct)
This yields the following graphical output:
and the following print output:
Model formula:
Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked
Fitted party:
[1] root
| [2] Sex in female
| | [3] Pclass in 1, 2: yes (n = 170, err = 5.3%)
| | [4] Pclass in 3
| | | [5] Fare <= 23.25: yes (n = 117, err = 41.0%)
| | | [6] Fare > 23.25: no (n = 27, err = 11.1%)
| [7] Sex in male
| | [8] Pclass in 1
| | | [9] Age <= 52: no (n = 88, err = 43.2%)
| | | [10] Age > 52: no (n = 34, err = 20.6%)
| | [11] Pclass in 2, 3
| | | [12] Age <= 9
| | | | [13] Pclass in 3: no (n = 71, err = 18.3%)
| | | | [14] Pclass in 2: yes (n = 13, err = 30.8%)
| | | [15] Age > 9: no (n = 371, err = 11.3%)
Number of inner nodes: 7
Number of terminal nodes: 8
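The same idea can be carried over to the loop from the question. The sketch below is one possible arrangement, not the answer's own code: the `dat` data frame is simulated, `split_data()` is a hypothetical helper, and the tree is only fitted if partykit is installed. Two details differ from the question's loop: `set.seed()` is called before `sample()`, so the seed actually controls the split, and `plot()` is called directly, since it draws the tree itself and does not need to be wrapped in `print()`.

```r
# Simulated stand-in data; replace with the preprocessed `titanic` data frame.
set.seed(1)
n <- 200
sex <- factor(sample(c("female", "male"), n, replace = TRUE))
p_yes <- ifelse(sex == "female", 0.8, 0.2)  # made-up survival probabilities
dat <- data.frame(
  Survived = factor(ifelse(runif(n) < p_yes, "yes", "no"), levels = c("no", "yes")),
  Pclass   = factor(sample(1:3, n, replace = TRUE)),
  Sex      = sex,
  Age      = round(runif(n, 1, 80))
)

# Hypothetical helper: reproducible train/test split, seed set BEFORE sampling.
split_data <- function(d, seed, frac = 0.70) {
  set.seed(seed)
  idx <- sample(seq_len(nrow(d)), size = floor(frac * nrow(d)))
  list(train = d[idx, ], test = d[-idx, ])
}

for (i in 1:2) {
  parts <- split_data(dat, seed = 100 + i)
  if (requireNamespace("partykit", quietly = TRUE)) {
    # Only preprocessed, meaningful predictors appear in the formula.
    ct_i <- partykit::ctree(Survived ~ Pclass + Sex + Age, data = parts$train)
    plot(ct_i)  # draws the tree; no print() around it
  }
}
```

With identifier-like columns kept out of the formula, each iteration fits only a handful of low-cardinality predictors, which should avoid the slow fits and unreadable plots described in the question.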
https://stackoverflow.com/questions/29755661