1

我正在尝试在这里处理一些数据并比较 glm 和 lda 的测试性能。

数据附在此处。

这是我尝试做这两个的总体计划:

training = read.csv("train.csv")
testing = read.csv("test.csv")

model_glm <- glm(V1 ~.,family=binomial(link='logit'),data=training)
pred_glm <- predict(model_glm, testing)

library(MASS)
model_lda <- lda(V1 ~ ., data=training)
predict_lda <- predict(model_lda, testing)

#Calculating classification error
err_lda <- (pred_lda) - test$V1
err2_lda <- err_lda[err_lda != 0]
classification_error_lda = length(err2_lda)/length(test$V1)

然而这些都不起作用。我以为有一个多项式家庭课程,但似乎不存在。另外,由于我的第一列是数字,而下一列都是灰度值,我认为我这样做V1 ~ .了,但我认为这对于这些情况也不正确。有人知道我的语法/设置是否错误吗?

编辑:我添加了我如何尝试计算 LDA 的分类错误。但是我不认为我原来的东西有效,因为它给出了:

(pred_lda) 中的错误 - test$V1:二元运算符的非数字参数

4

1 回答 1

1

这不是一个二元分类,而是一个多类(数字)分类问题,我们有 10 个类标签。因此,您需要使用多项式 logit 而不是逻辑回归。试试下面,我们可以看到,多项式logit模型的整体预测准确率高于lda。

library(nnet)
model_mlogit <- multinom(V1 ~ ., data = training, MaxNWts=2581)
predict_mlogit <- predict(model_mlogit, testing)
library(MASS)
model_lda <- lda(V1 ~ ., data=training)
predict_lda <- predict(model_lda, testing)
library(caret)
confusionMatrix(predict_mlogit,testing$V1)
# output 
Confusion Matrix and Statistics

          Reference
Prediction   0   1   2   3   4   5   6   7   8   9
         0 343   0   5   2   5   4   1   0   7   0
         1   0 254   1   0   2   1   0   0   0   0
         2   3   2 163   4   5   0   4   2   7   0
         3   2   1   6 145   1   7   0   3   3   1
         4   3   1   8   1 168   3   4   5   1   3
         5   2   0   1   8   2 137   4   0   9   1
         6   2   1   1   1   4   3 156   0   0   0
         7   3   1   5   2   1   0   0 132   4   2
         8   1   1   7   3   4   2   1   0 130   5
         9   0   3   1   0   8   3   0   5   5 165

Overall Statistics

               Accuracy : 0.8934         
                 95% CI : (0.879, 0.9065)
    No Information Rate : 0.1789         
    P-Value [Acc > NIR] : < 2.2e-16      

                  Kappa : 0.8803         
 Mcnemar's Test P-Value : NA             

Statistics by Class:

                     Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8 Class: 9
Sensitivity            0.9554   0.9621  0.82323  0.87349  0.84000  0.85625  0.91765  0.89796  0.78313  0.93220
Specificity            0.9854   0.9977  0.98507  0.98696  0.98395  0.98538  0.99347  0.99032  0.98696  0.98634
Pos Pred Value         0.9346   0.9845  0.85789  0.85799  0.85279  0.83537  0.92857  0.88000  0.84416  0.86842
Neg Pred Value         0.9902   0.9943  0.98074  0.98857  0.98232  0.98752  0.99239  0.99192  0.98057  0.99340
Prevalence             0.1789   0.1315  0.09865  0.08271  0.09965  0.07972  0.08470  0.07324  0.08271  0.08819
Detection Rate         0.1709   0.1266  0.08122  0.07225  0.08371  0.06826  0.07773  0.06577  0.06477  0.08221
Detection Prevalence   0.1829   0.1286  0.09467  0.08421  0.09816  0.08171  0.08371  0.07474  0.07673  0.09467
Balanced Accuracy      0.9704   0.9799  0.90415  0.93023  0.91198  0.92082  0.95556  0.94414  0.88505  0.95927

confusionMatrix(predict_lda$class,testing$V1)
#output
Confusion Matrix and Statistics

          Reference
Prediction   0   1   2   3   4   5   6   7   8   9
         0 342   0   7   3   1   6   1   0   5   0
         1   0 251   2   0   4   0   0   1   0   0
         2   0   0 157   3   6   0   3   0   2   0
         3   4   2   4 142   0  16   0   2  11   0
         4   3   5  12   3 174   3   3   7   7   4
         5   1   0   2   9   0 125   3   0   4   0
         6   5   3   1   0   2   0 157   0   0   0
         7   0   0   1   1   2   0   0 129   0   5
         8   3   1  12   4   1   5   3   1 135   3
         9   1   2   0   1  10   5   0   7   2 165

Overall Statistics

               Accuracy : 0.8854         
                 95% CI : (0.8706, 0.899)
    No Information Rate : 0.1789         
    P-Value [Acc > NIR] : < 2.2e-16      

                  Kappa : 0.8713         
 Mcnemar's Test P-Value : NA             

Statistics by Class:

                     Class: 0 Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8 Class: 9
Sensitivity            0.9526   0.9508  0.79293  0.85542  0.87000  0.78125  0.92353  0.87755  0.81325  0.93220
Specificity            0.9860   0.9960  0.99226  0.97882  0.97399  0.98971  0.99401  0.99516  0.98207  0.98470
Pos Pred Value         0.9370   0.9729  0.91813  0.78453  0.78733  0.86806  0.93452  0.93478  0.80357  0.85492
Neg Pred Value         0.9896   0.9926  0.97767  0.98686  0.98544  0.98121  0.99293  0.99037  0.98314  0.99338
Prevalence             0.1789   0.1315  0.09865  0.08271  0.09965  0.07972  0.08470  0.07324  0.08271  0.08819
Detection Rate         0.1704   0.1251  0.07823  0.07075  0.08670  0.06228  0.07823  0.06428  0.06726  0.08221
Detection Prevalence   0.1819   0.1286  0.08520  0.09018  0.11011  0.07175  0.08371  0.06876  0.08371  0.09616
Balanced Accuracy      0.9693   0.9734  0.89260  0.91712  0.92200  0.88548  0.95877  0.93636  0.89766  0.95845

[编辑] 没有caret

table(predict_mlogit,testing$V1)
# output
predict_mlogit   0   1   2   3   4   5   6   7   8   9
             0 343   0   5   2   5   4   1   0   7   0
             1   0 254   1   0   2   1   0   0   0   0
             2   3   2 163   4   5   0   4   2   7   0
             3   2   1   6 145   1   7   0   3   3   1
             4   3   1   8   1 168   3   4   5   1   3
             5   2   0   1   8   2 137   4   0   9   1
             6   2   1   1   1   4   3 156   0   0   0
             7   3   1   5   2   1   0   0 132   4   2
             8   1   1   7   3   4   2   1   0 130   5
             9   0   3   1   0   8   3   0   5   5 165
# accuracy
sum(predict_mlogit==testing$V1)/length(testing$V1)
# [1] 0.8933732

table(predict_lda$class,testing$V1)
# output
      0   1   2   3   4   5   6   7   8   9
  0 342   0   7   3   1   6   1   0   5   0
  1   0 251   2   0   4   0   0   1   0   0
  2   0   0 157   3   6   0   3   0   2   0
  3   4   2   4 142   0  16   0   2  11   0
  4   3   5  12   3 174   3   3   7   7   4
  5   1   0   2   9   0 125   3   0   4   0
  6   5   3   1   0   2   0 157   0   0   0
  7   0   0   1   1   2   0   0 129   0   5
  8   3   1  12   4   1   5   3   1 135   3
  9   1   2   0   1  10   5   0   7   2 165
# accuracy
sum(predict_lda$class==testing$V1)/length(testing$V1)
# [1] 0.8854011
于 2017-02-16T07:01:12.307 回答