1

语境

caret用来拟合和调整模型。通常,使用交叉验证等重采样方法可以找到最佳参数。一旦选择了最佳参数,最终模型将使用最佳参数集拟合到整个训练数据。

除了要调整的参数(通过 传递tuneGrid)之外,还可以通过将参数传递给被调用的底层算法来传递参数train

我的问题

有没有办法指定仅用于最终模型的特定于模型的选项?

为了更加清楚:我确实想拟合所有中间模型(以获得可靠的性能估计),但我想用不同的参数拟合最终模型(除了最佳参数)。

具体用例

假设我想拟合bartMachine一些数据,然后在生产中使用最终模型。我通常会将调整后的模型保存到磁盘并根据需要加载它。但我只能保存/加载一个已经序列化的 bartMachine 模型,即我需要传递serialize=TbartMachinevia caret::train

但这将使所有模型序列化,这是非常不切实际的。我真的只需要序列化最终模型。有没有办法做到这一点?

library("caret")
library("bartMachine")
tgrid <- expand.grid(num_trees = 100,
                       k = c(2, 3),
                       alpha = 0.95, 
                       beta = 2,
                       nu =  3)
# The printed log shows that all intermediate models are being serialized
fit <- train(hp ~ ., 
             data=mtcars, 
             method="bartMachine",
             serialize=T,
             tuneGrid=tgrid,
             trControl = trainControl(method="cv", 5, verboseIter=T))
4

1 回答 1

1

要在不调整参数或重新采样的情况下将模型拟合到整个数据集,请将列车控制方法修改为无:

tgrid <- expand.grid(num_trees = 100,
                     k = 2,
                     alpha = 0.95, 
                     beta = 2,
                     nu =  3)
fit <- train(hp ~ ., 
             data=mtcars, 
             method="bartMachine",
             serialize=TRUE,
             tuneGrid=tgrid,
             trControl = trainControl(method="none"))

请注意,我已删除问题代码中的两个 k 值之一。否则会报错:Only one model should be specified in tuneGrid with no resampling。我建议用另一个 k 值构建一个单独的模型。

上面的代码给出了以下输出:

bartMachine initializing with 100 trees...
bartMachine vars checked...
bartMachine java init...
bartMachine factors created...
bartMachine before preprocess...
bartMachine after preprocess... 11 total features...
bartMachine sigsq estimated...
bartMachine training data finalized...
Now building bartMachine for regression ...
building BART with mem-cache speedup...
Iteration 100/1250  mem: 17.6/477.1MB
Iteration 200/1250  mem: 25.1/477.1MB
Iteration 300/1250  mem: 30.8/477.1MB
Iteration 400/1250  mem: 39.9/477.1MB
Iteration 500/1250  mem: 19/477.1MB
Iteration 600/1250  mem: 59.6/477.1MB
Iteration 700/1250  mem: 39.6/477.1MB
Iteration 800/1250  mem: 79.8/477.1MB
Iteration 900/1250  mem: 119.9/477.1MB
Iteration 1000/1250  mem: 40.7/477.1MB
Iteration 1100/1250  mem: 80.8/477.1MB
Iteration 1200/1250  mem: 121/477.1MB
done building BART in 1.289 sec 

burning and aggregating chains from all threads... done
evaluating in sample data...done
serializing in order to be saved for future R sessions...done

serialize 参数设置为 TRUE 在fit$finalModel

fit$finalModel$serialize
[1] TRUE

对于它的价值,bartMachine 内部 check_serialization 函数不会给出任何警告或错误(或任何其他输出):

bartMachine:::check_serialization(fit$finalModel)

我不清楚如何从fit$finalModel. 我认为它存储在fit$finalModel$java_bart_machine其中包含一个 rJava 指针。使用 bartMachine 所依赖的 rJava 包可能会获得更深入的了解。

更新:@antoine-sac 在下面的评论中声明“serialize=T 不会导致模型被保存,而是将样本序列化到模型中,这意味着它们在模型写入磁盘时被保存”。

于 2018-11-20T12:54:06.720 回答