这不是二元分类问题,而是多类(数字)分类问题:共有 10 个类别标签。因此应使用多项 logit 模型(multinomial logit),而不是二元逻辑回归。运行下面的代码可以看到,多项 logit 模型的整体预测准确率(0.8934)略高于 LDA(0.8854)。
# Attach the two modelling packages before any model is fitted.
library(nnet)  # provides multinom()
library(MASS)  # provides lda()

# Multinomial logit: V1 is the class label, every other column of `training`
# is a predictor. MaxNWts raises nnet's limit on the number of weights
# (presumably 2581 is the count this data set needs -- confirm if the
# predictors change).
model_mlogit <- multinom(V1 ~ ., data = training, MaxNWts = 2581)
predict_mlogit <- predict(model_mlogit, newdata = testing)

# Linear discriminant analysis on the same formula and training data.
# The predicted labels are read later from the `class` element of the result.
model_lda <- lda(V1 ~ ., data = training)
predict_lda <- predict(model_lda, newdata = testing)

# caret is attached only here, right before its single use, so the models
# above are still fitted when caret is not installed.
library(caret)

# Confusion matrix and per-class statistics for the multinomial logit model.
# NOTE(review): confusionMatrix() documents both arguments as factors --
# confirm that testing$V1 is a factor with the same levels as the predictions.
confusionMatrix(data = predict_mlogit, reference = testing$V1)
# output
Confusion Matrix and Statistics
Reference
Prediction 0 1 2 3 4 5 6 7 8 9
0 343 0 5 2 5 4 1 0 7 0
1 0 254 1 0 2 1 0 0 0 0
2 3 2 163 4 5 0 4 2 7 0
3 2 1 6 145 1 7 0 3 3 1
4 3 1 8 1 168 3 4 5 1 3
5 2 0 1 8 2 137 4 0 9 1
6 2 1 1 1 4 3 156 0 0 0
7 3 1 5 2 1 0 0 132 4 2
8 1 1 7 3 4 2 1 0 130 5
9 0 3 1 0 8 3 0 5 5 165
Overall Statistics
Accuracy : 0.8934
95% CI : (0.879, 0.9065)
No Information Rate : 0.1789
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.8803
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8 Class: 9
Sensitivity 0.9554 0.9621 0.82323 0.87349 0.84000 0.85625 0.91765 0.89796 0.78313 0.93220
Specificity 0.9854 0.9977 0.98507 0.98696 0.98395 0.98538 0.99347 0.99032 0.98696 0.98634
Pos Pred Value 0.9346 0.9845 0.85789 0.85799 0.85279 0.83537 0.92857 0.88000 0.84416 0.86842
Neg Pred Value 0.9902 0.9943 0.98074 0.98857 0.98232 0.98752 0.99239 0.99192 0.98057 0.99340
Prevalence 0.1789 0.1315 0.09865 0.08271 0.09965 0.07972 0.08470 0.07324 0.08271 0.08819
Detection Rate 0.1709 0.1266 0.08122 0.07225 0.08371 0.06826 0.07773 0.06577 0.06477 0.08221
Detection Prevalence 0.1829 0.1286 0.09467 0.08421 0.09816 0.08171 0.08371 0.07474 0.07673 0.09467
Balanced Accuracy 0.9704 0.9799 0.90415 0.93023 0.91198 0.92082 0.95556 0.94414 0.88505 0.95927
# Confusion matrix and per-class statistics for the LDA model; the predicted
# labels are taken from the `class` element of the predict() result.
confusionMatrix(data = predict_lda$class, reference = testing$V1)
#output
Confusion Matrix and Statistics
Reference
Prediction 0 1 2 3 4 5 6 7 8 9
0 342 0 7 3 1 6 1 0 5 0
1 0 251 2 0 4 0 0 1 0 0
2 0 0 157 3 6 0 3 0 2 0
3 4 2 4 142 0 16 0 2 11 0
4 3 5 12 3 174 3 3 7 7 4
5 1 0 2 9 0 125 3 0 4 0
6 5 3 1 0 2 0 157 0 0 0
7 0 0 1 1 2 0 0 129 0 5
8 3 1 12 4 1 5 3 1 135 3
9 1 2 0 1 10 5 0 7 2 165
Overall Statistics
Accuracy : 0.8854
95% CI : (0.8706, 0.899)
No Information Rate : 0.1789
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.8713
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8 Class: 9
Sensitivity 0.9526 0.9508 0.79293 0.85542 0.87000 0.78125 0.92353 0.87755 0.81325 0.93220
Specificity 0.9860 0.9960 0.99226 0.97882 0.97399 0.98971 0.99401 0.99516 0.98207 0.98470
Pos Pred Value 0.9370 0.9729 0.91813 0.78453 0.78733 0.86806 0.93452 0.93478 0.80357 0.85492
Neg Pred Value 0.9896 0.9926 0.97767 0.98686 0.98544 0.98121 0.99293 0.99037 0.98314 0.99338
Prevalence 0.1789 0.1315 0.09865 0.08271 0.09965 0.07972 0.08470 0.07324 0.08271 0.08819
Detection Rate 0.1704 0.1251 0.07823 0.07075 0.08670 0.06228 0.07823 0.06428 0.06726 0.08221
Detection Prevalence 0.1819 0.1286 0.08520 0.09018 0.11011 0.07175 0.08371 0.06876 0.08371 0.09616
Balanced Accuracy 0.9693 0.9734 0.89260 0.91712 0.92200 0.88548 0.95877 0.93636 0.89766 0.95845
[编辑] 如果不使用 caret,可以改用 table() 生成混淆矩阵,并手动计算准确率:
# Cross-tabulate multinomial logit predictions (rows) against the true labels
# (columns). Arguments stay positional so the printed row header remains
# `predict_mlogit`.
table(predict_mlogit, testing$V1)
# output
predict_mlogit 0 1 2 3 4 5 6 7 8 9
0 343 0 5 2 5 4 1 0 7 0
1 0 254 1 0 2 1 0 0 0 0
2 3 2 163 4 5 0 4 2 7 0
3 2 1 6 145 1 7 0 3 3 1
4 3 1 8 1 168 3 4 5 1 3
5 2 0 1 8 2 137 4 0 9 1
6 2 1 1 1 4 3 156 0 0 0
7 3 1 5 2 1 0 0 132 4 2
8 1 1 7 3 4 2 1 0 130 5
9 0 3 1 0 8 3 0 5 5 165
# Overall accuracy of the multinomial logit model: the share of test rows
# whose predicted class equals the true label in testing$V1.
sum(testing$V1 == predict_mlogit) / length(testing$V1)
# [1] 0.8933732
# Cross-tabulate LDA predictions (rows) against the true labels (columns).
table(predict_lda$class, testing$V1)
# output
0 1 2 3 4 5 6 7 8 9
0 342 0 7 3 1 6 1 0 5 0
1 0 251 2 0 4 0 0 1 0 0
2 0 0 157 3 6 0 3 0 2 0
3 4 2 4 142 0 16 0 2 11 0
4 3 5 12 3 174 3 3 7 7 4
5 1 0 2 9 0 125 3 0 4 0
6 5 3 1 0 2 0 157 0 0 0
7 0 0 1 1 2 0 0 129 0 5
8 3 1 12 4 1 5 3 1 135 3
9 1 2 0 1 10 5 0 7 2 165
# Overall accuracy of the LDA model: the share of test rows whose predicted
# class equals the true label in testing$V1.
sum(testing$V1 == predict_lda$class) / length(testing$V1)
# [1] 0.8854011