I am using the code below to classify documents into three categories: sports, politics and money. I can see that the code computes precision, recall and F1, but I cannot find a way to use it to test a custom document of my own and predict its label.
from nltk.corpus import stopwords, reuters
from nltk import word_tokenize
from nltk.stem.porter import PorterStemmer
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import f1_score, precision_score, recall_score
cachedStopWords = stopwords.words("english")
def tokenize(text):
    min_length = 3
    words = [word.lower() for word in word_tokenize(text)]
    words = [word for word in words if word not in cachedStopWords]
    tokens = [PorterStemmer().stem(word) for word in words]
    p = re.compile('[a-zA-Z]+')
    filtered_tokens = [token for token in tokens if p.match(token) and len(token) >= min_length]
    return filtered_tokens
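# (For reference, and as I understand it: tokenize() lower-cases the text, removes English
# stop words, Porter-stems each word, and keeps only alphabetic tokens of length >= 3, so a
# made-up sentence like "Stocks rallied sharply in London trading" should come back roughly
# as ['stock', 'ralli', 'sharpli', 'london', 'trade'].)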
def represent(documents, representer):
    train_docs_id = [doc for doc in documents if doc.startswith("train")]
    test_docs_id = [doc for doc in documents if doc.startswith("test")]
    train_docs = [reuters.raw(doc_id) for doc_id in train_docs_id]
    test_docs = [reuters.raw(doc_id) for doc_id in test_docs_id]

    # Learn and transform train documents
    vectorised_train_documents = representer.fit_transform(train_docs)
    vectorised_test_documents = representer.transform(test_docs)

    # Transform multilabel labels
    mlb = MultiLabelBinarizer()
    train_labels = mlb.fit_transform([reuters.categories(doc_id) for doc_id in train_docs_id])
    test_labels = mlb.transform([reuters.categories(doc_id) for doc_id in test_docs_id])

    return (vectorised_train_documents, train_labels, vectorised_test_documents, test_labels)
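# (Note: represent() fits both the TfidfVectorizer and the MultiLabelBinarizer, but only the
# fitted vectorizer survives outside the function, because it is the same object that was
# passed in; the fitted MultiLabelBinarizer mlb stays local, so as written there is no way
# to map a binary prediction row back to category names without also returning mlb.)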
def evaluate(test_labels, predictions):
    precision = precision_score(test_labels, predictions, average='micro')
    recall = recall_score(test_labels, predictions, average='micro')
    f1 = f1_score(test_labels, predictions, average='micro')
    print("Micro-average quality numbers")
    print("Precision: {:.4f}, Recall: {:.4f}, F1-measure: {:.4f}".format(precision, recall, f1))

    precision = precision_score(test_labels, predictions, average='macro')
    recall = recall_score(test_labels, predictions, average='macro')
    f1 = f1_score(test_labels, predictions, average='macro')
    print("Macro-average quality numbers")
    print("Precision: {:.4f}, Recall: {:.4f}, F1-measure: {:.4f}".format(precision, recall, f1))
documents = reuters.fileids()
candidate = {'representer': TfidfVectorizer(tokenizer=tokenize),
             'estimator': OneVsRestClassifier(LinearSVC(random_state=42))}
train_docs, train_labels, test_docs, test_labels = represent(documents, candidate['representer'])
candidate['estimator'].fit(train_docs, train_labels)
predictions = candidate['estimator'].predict(test_docs)
evaluate(test_labels, predictions)
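What I am trying to get to is roughly the following. This is only a sketch of what I have in mind, not working code as-is: it assumes represent() is changed to also return the fitted MultiLabelBinarizer as mlb, and the sample text is made up.

# Sketch only: assumes `mlb` (the fitted MultiLabelBinarizer) has been returned from represent()
custom_doc = "The central bank raised interest rates to curb inflation."   # made-up example text
custom_vector = candidate['representer'].transform([custom_doc])           # reuse the already-fitted TF-IDF vectorizer
custom_prediction = candidate['estimator'].predict(custom_vector)          # binary indicator row, one column per label
print(mlb.inverse_transform(custom_prediction))                            # map the indicator row back to category names

Is something along these lines the right way to classify a new document, or is there a better approach with this code?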
Credits: https://github.com/miguelmalvarez/reuters-tc/blob/master/notebook/Classification_Reuters.ipynb