在从 R 中的一类 svm 进行预测时,我试图找出概率输出。我知道这不被支持libsvm
,我也知道这个问题已经在几年前被问过,但是包是当时不可用。我希望现在情况有所改变!此外,这个问题仍然有效,因为没有给出在 R 中实现的方法作为解决方案。
我找不到执行此操作的软件包,因此我自己尝试了两种方法来解决此问题:
- 获取决策值并通过使用 sigmoid 激活函数对其进行转换。这在本文中有所描述。请注意以下段落:
此外,SVM 还可以生成类概率作为输出而不是类标签。这可以通过 Platt 的后验概率 (Platt 2000) 的改进实现 (Lin, Lin, and Weng 2001) 来完成,其中将 sigmoid 函数拟合到二元 SVM 分类器的决策值 f,通过最小化来估计 A 和 B负对数似然函数
我的问题是,为了检查我的两个解决方案中的任何一个是否合理,我在一个二类 svm 问题上测试了这两种方法,因为e1071
使用libsvm
给出了二类问题的概率,因此这被视为“真相”。我发现我的两种方法都没有与libsvm
.
这是三个图表,显示了结果概率与已知决策值的关系。 点击查看图片。抱歉,我的声誉似乎太低,无法嵌入令人沮丧的图像!我不确定社区中声誉较高的人是否可以编辑嵌入?
我认为我的 Platt 方法在理论上更合理,但从图中可以看出,逻辑回归似乎在某种程度上太好了,与任一分类相关的概率都非常接近 1(正数)和 0(负数)。
我的 Platt 实现代码是
platt_scale <- function(oc_svm, X){
# Get SVM predictions
y_pred <- predict(oc_svm$best.model,X)
#y_pred <- as.factor(ifelse(y_pred==T,"pos","neg"))
# Train using logistic regression with cross-validation
require(caret)
model <- train(x = X,
y = y_pred,
method = "glm",
family=binomial(),
trControl = trainControl(method = "cv",
number = 5),
control = list(maxit = 50) #BROUGHT IN TO STOP WARNING MESSAGES
)
return(predict(model,
newdata = X,
type = "prob")[,1])
}
运行时我收到以下警告
glm.fit: fitted probabilities numerically 0 or 1 occurred
所以我显然做错了什么!我觉得修复这个功能可能是最好的方法,但我看不出我哪里出错了?我正在遵循我之前提到的方法,here
我得到如下决策值的 sigmoid
sig_mult <-e1071::sigmoid(decision_values)
示例是使用 Iris 数据集完成的,完整代码在这里
data(iris)
two_class<-iris[iris$Species %in% c("setosa","versicolor"),]
#Make Two-class SVM
svm_mult<-e1071::tune(svm,
train.x = two_class[,1:4],
train.y = factor(two_class[,5],levels=c("setosa", "versicolor")),
type="C-classification",
kernel="radial",
gamma=0.05,
cost=1,
probability = T,
tunecontrol = tune.control(cross = 5))
#Get related decision values
dec_vals_mult <-attr(predict(svm_mult$best.model,
two_class[,1:4],
decision.values = T #use decision values to get score
), "decision.values")
#Get related probabilities
prob_mult <-attr(predict(svm_mult$best.model,
two_class[,1:4],
probability = T #use decision values to get score
), "probabilities")[,1]
#transform decision values using sigmoid
sig_mult <-e1071::sigmoid(dec_vals_mult)
#Use Platt Implementation function to derive probabilities
platt_imp<-platt_scale(svm_mult,two_class[,1:4])
require(ggplot2)
data2<-as.data.frame(cbind(dec_vals_mult,sig_mult))
names(data2)<-c("Decision.Values","Sigmoid.Decision.Values(Prob)")
sig<-ggplot(data=data2,aes(x=Decision.Values,
y=`Sigmoid.Decision.Values(Prob)`,
colour=ifelse(Decision.Values<0,"neg","pos")))+
geom_point()+
ylim(0,1)+
theme(legend.position = "none")
data3<-as.data.frame(cbind(dec_vals_mult,prob_mult))
names(data3)<-c("Decision.Values","Probabilities")
actual<-ggplot(data=data3,aes(x=Decision.Values,
y=Probabilities,
colour=ifelse(Decision.Values<0,"neg","pos")))+
geom_point()+
ylim(0,1)+
theme(legend.position = "none")
data4<-as.data.frame(cbind(dec_vals_mult,platt_imp))
names(data4)<-c("Decision.Values","Platt")
plat_imp<-ggplot(data=data4,aes(x=Decision.Values,
y=Platt,
colour=ifelse(Decision.Values<0,"neg","pos")))+
geom_point()+
ylim(0,1)
require(ggpubr)
ggarrange(actual, plat_imp, sig,
labels = c("Actual", "Platt Implementation", "Sigmoid Transformation"),
ncol = 3,
label.x = -.05,
label.y = 1.001,
font.label = list(size = 8.5, color = "black", face = "bold", family = NULL),
common.legend = TRUE, legend = "bottom")