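I really like using caret for at least the early stages of modeling, especially because its resampling methods are so easy to use. However, I'm working on a model where a fair number of cases have been added to the training set via semi-supervised self-training, and as a result my cross-validation results are badly biased. My solution is to measure model performance on a validation set, but I can't see a way to use a validation set directly within caret. Am I missing something, or is this just not supported? I know I could write my own wrappers to do what caret would normally do for me, but it would be really nice if there were a workaround that didn't require that.
Here is a simple example of what I'm dealing with: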
> library(caret)
> set.seed(1)
>
> #training/validation sets
> i <- sample(150,50)
> train <- iris[-i,]
> valid <- iris[i,]
>
> #make my model
> tc <- trainControl(method="cv")
> model.rf <- train(Species ~ ., data=train,method="rf",trControl=tc)
>
> #model parameters are selected using CV results...
> model.rf
100 samples
4 predictors
3 classes: 'setosa', 'versicolor', 'virginica'
No pre-processing
Resampling: Cross-Validation (10 fold)
Summary of sample sizes: 90, 90, 90, 89, 90, 92, ...
Resampling results across tuning parameters:
  mtry  Accuracy  Kappa  Accuracy SD  Kappa SD
  2     0.971     0.956  0.0469       0.0717
  3     0.971     0.956  0.0469       0.0717
  4     0.971     0.956  0.0469       0.0717
Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 2.
>
> #have to manually check validation set
> valid.pred <- predict(model.rf,valid)
> table(valid.pred,valid$Species)
valid.pred   setosa versicolor virginica
  setosa         17          0         0
  versicolor      0         20         1
  virginica       0          2        10
> mean(valid.pred==valid$Species)
[1] 0.94
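I originally thought I could do this by creating a custom summaryFunction() for a trainControl() object, but I can't see how to reference my model object to get predictions on the validation set (the documentation, http://caret.r-forge.r-project.org/training.html, lists only "data", "lev" and "model" as possible arguments). For example, this obviously won't work: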
tc$summaryFunction <- function(data, lev = NULL, model = NULL){
data.frame(Accuracy=mean(predict(<model object>,valid)==valid$Species))
}
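For comparison, here is a minimal sketch of what a working custom summaryFunction can see. caret passes `data` as a data frame of the held-out resample's observed (`obs`) and predicted (`pred`) values, and `lev` as the class levels; the transcript further down shows caret calling `trControl$summaryFunction(testOutput, classLevels, method)`, so `model` receives the method name (e.g. "rf"), not the fitted model object. The name `accuracySummary` and the toy data frame are hypothetical, used only for illustration.

```r
# Hypothetical summary function with the same signature as caret's defaultSummary.
# `data`  : held-out rows of the current resample, with columns obs and pred
# `lev`   : class levels
# `model` : the method name string (e.g. "rf"), not a fitted model
accuracySummary <- function(data, lev = NULL, model = NULL) {
  # caret expects a named numeric vector of performance metrics
  c(Accuracy = mean(data$pred == data$obs))
}

# Called by hand on a toy data frame shaped like what caret passes in
toy <- data.frame(
  obs  = factor(c("setosa", "versicolor", "virginica", "virginica")),
  pred = factor(c("setosa", "versicolor", "versicolor", "virginica"))
)
accuracySummary(toy, lev = levels(toy$obs), model = "rf")
```

Because it only ever sees predictions on caret's own held-out resample rows, a function like this cannot score an external validation set by itself; it would be plugged in with `trainControl(method = "cv", summaryFunction = accuracySummary)` and `metric = "Accuracy"` in `train()`.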
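Edit: In an attempt to come up with a really ugly fix, I've been looking at whether I can access the model object from another function's scope, but I don't even see the model stored anywhere. Hopefully there's some elegant solution that I'm not even close to seeing...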
> tc$summaryFunction <- function(data, lev = NULL, model = NULL){
+ browser()
+ data.frame(Accuracy=mean(predict(model,valid)==valid$Species))
+ }
> train(Species ~ ., data=train,method="rf",trControl=tc)
note: only 1 unique complexity parameters in default grid. Truncating the grid to 1 .
Called from: trControl$summaryFunction(testOutput, classLevels, method)
Browse[1]> lapply(sys.frames(),function(x) ls(envi=x))
[[1]]
[1] "x"
[[2]]
[1] "cons" "contrasts" "data" "form" "m" "na.action" "subset"
[8] "Terms" "w" "weights" "x" "xint" "y"
[[3]]
[1] "x"
[[4]]
[1] "classLevels" "funcCall" "maximize" "method" "metric" "modelInfo"
[7] "modelType" "paramCols" "ppMethods" "preProcess" "startTime" "testOutput"
[13] "trainData" "trainInfo" "trControl" "tuneGrid" "tuneLength" "weights"
[19] "x" "y"
[[5]]
[1] "data" "lev" "model"
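One possible workaround that avoids a custom wrapper is sketched below, assuming caret's `trainControl()` arguments `index` (rows used to fit each resample) and `indexOut` (rows used to score it) behave as described in caret's documentation. Defining a single "resample" whose fitting rows are the training set and whose scoring rows are the validation set makes `train()` tune on validation-set performance. The object names `i` and `model.val` are illustrative. One caveat under this assumption: after tuning, caret refits `finalModel` on every row of `data`, validation rows included.

```r
library(caret)  # also requires the randomForest package for method = "rf"

set.seed(1)
i <- sample(150, 50)  # rows held out for validation

# One "resample": fit on the training rows, score on the validation rows.
tc <- trainControl(
  method   = "cv",
  index    = list(Validation = setdiff(seq_len(nrow(iris)), i)),
  indexOut = list(Validation = i)
)

# Pass the combined data; the indices above refer to its rows.
model.val <- train(Species ~ ., data = iris, method = "rf", trControl = tc)

# Accuracy/Kappa per mtry, computed on the validation rows only
# (the SD columns are NA because there is just one resample).
model.val$results
```

The same structure would apply to a training set augmented by self-training: only the original labeled rows need to appear in `indexOut`.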