5

在我的 scikits-learn 管道中,我想将自定义词汇表传递给 CountVectorizer():

text_classifier = Pipeline([
    ('count', CountVectorizer(vocabulary=myvocab)),
    ('tfidf', TfidfTransformer()),
    ('clf', LinearSVC(C=1000))
])

但是,据我所知,当我打电话时

text_classifier.fit(X_train, y_train)

Pipeline 使用 CountVectorizer() 的 fit_transform() 方法,忽略 myvocab。如何修改我的管道以使用 myvocab?谢谢!

4

1 回答 1

9

这是我五分钟前修复的 scikit-learn 中的一个错误。感谢您发现它。我建议你要么从 Github 升级到最新版本,要么将矢量化器从管道中分离出来作为解决方法:

count = CountVectorizer(vocabulary=myvocab)
X_vectorized = count.transform(X_train)

text_classifier = Pipeline([
    ('tfidf', TfidfTransformer()),
    ('clf', LinearSVC(C=1000))
])

text_classifier.fit(X_vectorized, y_train)

更新:自从发布此答案以来,此修复程序已包含在几个 scikit-learn 版本中。

于 2011-07-08T23:19:05.560 回答