我刚开始尝试 R 包 mlr,我想知道是否可以自定义训练集和测试集。例如,一个时间序列的数据除了最后一个都是训练集,最后一个是测试集。
这是我的例子:
library(mlr)
library(survival)
data(lung)
myData2 <- lung %>%
select(time,status,age)
myData2$status = (myData2$status == 2)
myTrain <- c(1:(nrow(myData2)-1))
myTest <- nrow(myData2)
肺数据来自生存包。我只使用三个维度:时间、地位和年龄。现在,让我们假设它们并不意味着患者的年龄和他们可以存活多长时间。假设这是一位客户的墨水购买历史。
age=74 表示该客户当天购买了 74 瓶墨水,time=306 表示客户在 306 天后用完墨水。所以,我想使用除最后一行之外的所有数据建立一个生存模型。然后,当我有最后一行的数据,即age=58,表示客户当天购买了58瓶墨水时,我可以按时做出预测。接近 177 的数字将是一个很好的估计。所以,我的训练集和测试集是固定的,不需要重新采样。
此外,我需要更改超参数以进行比较。这是我的代码:
surv.task <- makeSurvTask(data=myData2,target=c('time','status'))
surv.lrn <- makeLearner("surv.cforest")
ps <- makeParamSet(
makeDiscreteParam('mincriterion',values=c(1.281552,2,3)),
makeDiscreteParam('ntree',values=c(100,200,300))
)
ctrl <- makeTuneControlGrid()
rdesc <- makeResampleDesc('Holdout',split=1,predict='train')
lrn = makeTuneWrapper(surv.lrn,control=ctrl,resampling=rdesc,par.set=ps,
measures = list(setAggregation(cindex,train.mean)))
mod <- train(learner=lrn,task=surv.task,subset=myTrain)
surv.pred <- predict(mod,task=surv.task,subset=myTest)
surv.pred
你可以看到我使用split=1
inmakeResampleDesc
因为我有固定的训练集,不需要重新采样。措施makeTuneWrapper
目前对我没有意义,因为我需要自定义自己的措施。由于固定的数据拆分,当使用不同的超参数时,我无法使用resample
或等函数tuneParams
对测试数据进行评估。
所以,我的问题是:当训练集和测试集固定时,mlr 能否为每个超参数提供全面的比较?如果是这样,该怎么做?
顺便说一句,看起来有makeFixedHoldoutInstance
可能做到这一点的功能,只是不知道如何使用它。比如我这样使用makeFixedHoldoutInstance
,就得到了这样的错误信息:
> f <- makeFixedHoldoutInstance(train.inds=myTrain,test.inds=myTest,size=length(myTrain)+1)
> lrn = makeTuneWrapper(surv.lrn,control=ctrl,resampling=f,par.set=ps)
> resample(learner=lrn,task=surv.task,resampling=f)
[Resample] holdout iter 1: [Tune] Started tuning learner surv.cforest for parameter set:
Type len Def Constr Req Tunable Trafo
mincriterion discrete - - 1.281552,2,3 - TRUE -
ntree discrete - - 100,200,300 - TRUE -
With control class: TuneControlGrid
Imputation value: -0
[Tune-x] 1: mincriterion=1.281552; ntree=100
Error in resample.fun(learner2, task, resampling, measures = measures, :
Size of data set: 227 and resampling instance: 228 differ!