1

我正在尝试使用 SGDClassifier 对 MNIST 问题使用在线(外)学习算法, 但似乎准确性并不总是在增加。

在这种情况下我该怎么办?以最准确的方式保存分类器?SGDClassifier 是否收敛到某个最优解?

这是我的代码:

import numpy as np
from sklearn.linear_model.stochastic_gradient import SGDClassifier
from sklearn.datasets import fetch_mldata
from sklearn.utils import shuffle

#use all digits
mnist = fetch_mldata("MNIST original")
X_train, y_train = mnist.data[:70000] / 255., mnist.target[:70000]

X_train, y_train = shuffle(X_train, y_train)
X_test, y_test = X_train[60000:70000], y_train[60000:70000]  

step =1000
batches= np.arange(0,60000,step)
all_classes = np.array([0,1,2,3,4,5,6,7,8,9])
classifier = SGDClassifier()
for curr in batches:
 X_curr, y_curr = X_train[curr:curr+step], y_train[curr:curr+step]
 classifier.partial_fit(X_curr, y_curr, classes=all_classes)
 score= classifier.score(X_test, y_test)
 print score

print "all done"

我在 MNIST 上测试了 linearSVM 与 SGD,使用 10k 个样本进行训练,10k 个样本进行测试,得到 0.883 13,95 和 0.85 1,32,因此 SGD 更快但准确度较低。

#test linearSVM vs SGD
t0 = time.time()
clf = LinearSVC()
clf.fit(X_train, y_train)
score= clf.score(X_test, y_test)
print score
print (time.time()-t0)

t1 = time.time()
clf = SGDClassifier()
clf.fit(X_train, y_train)
score= clf.score(X_test, y_test)
print score
print (time.time()-t1)

我也在这里找到了一些信息 https://stats.stackexchange.com/a/14936/16843

更新: 超过 1 次(10 次)通过数据达到 90.8 % 的最佳准确度。所以它可以是解决方案。SGD 的另一个特性是在传递给分类器之前必须对数据进行洗牌。

4

1 回答 1

3

首先说明:您使用SGDClassifier的是默认参数:它们可能不是该数据集的最佳值:也可以尝试其他值(特别是对于 alpha,正则化参数)。

现在回答你的问题,线性模型不太可能在像 MNIST 这样的数字图像分类任务的数据集上做得很好。您可能想尝试线性模型,例如:

  • SVC(kernel='rbf')(但不可扩展,尝试训练集的一小部分)而不是增量/核外
  • ExtraTreesClassifier(n_estimator=100)或更多,但也不是核心。子估计器的数量越多,训练所需的时间就越长。

您还可以尝试Nystroem 近似SVC(kernel='rbf')方法是使用Nystroem(n_components=1000, gamma=0.05)拟合数据的一小部分子集(例如 10000 个样本)转换数据集,然后将整个转换后的训练集传递给线性模型,例如SGDClassifier:它需要对数据集进行 2 次传递。

在 github 上还有一个对 1​​ 个隐藏层感知器的拉取请求,它的计算速度应该比ExtraTreesClassifier在 MNIST 上的测试集准确度要快并且接近 98%(并且还提供了一个用于核外学习的 partial_fit API)。

编辑:预期分数估计的波动SGDClassifier:SGD代表随机梯度下降,这意味着一次只考虑一个示例:分类错误的样本可能会导致模型权重的更新,这种方式是有害的对于其他样本,您需要对数据进行不止一次传递,以使学习率降低到足以获得对验证准确度的更平滑估计。您可以在 for 循环中使用itertools.repeat在数据集上执行多次(例如 10 次)。

于 2013-09-19T14:15:45.307 回答