In my sklearn classification model, when I manually set the TfidfVectorizer parameter ngram_range=(4, 4) I get 0.58 as the f1_macro result, while for unigrams (1, 1), for example, the result is 0.49.
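The manual runs were roughly like this (a minimal sketch; it reuses the same pipeline settings and the x / xTest / y / yTest split from the full code below, only with ngram_range fixed by hand):

manual_model = Pipeline([
    ('tfidf', TfidfVectorizer(analyzer='char_wb', min_df=0.0007, lowercase=False,
                              tokenizer=tokenizeTfidf, ngram_range=(4, 4))),  # fixed by hand
    ('clf', SGDClassifier(tol=None, loss='hinge', random_state=38, max_iter=5))
])
manual_model.fit(x, y)                          # training split, see the code below
pred = manual_model.predict(xTest)              # held-out test split, see the code below
print(f1_score(yTest, pred, average='macro'))   # ~0.58 for (4, 4), ~0.49 for (1, 1)
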
The problem is that when I use GridSearchCV to select the best parameters, it does not give me the best ones; it just returns the result for the first element of the parameter grid. Please have a look at my code to make this clearer:
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, TfidfTransformer
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, average_precision_score
from sklearn.model_selection import train_test_split, GridSearchCV
import re
from os import walk
import csv
import operator
# module-level state --
co = dict()   # corpus word -> frequency, filled by corpDict()
lex = []      # lexicon words, filled by readLexicons()
def tokenizeManu(txt):
    txt = clean_str(txt)
    return txt.split()
def tokenizeTfidf(txt):
    return txt.split()  # just splits the text, no other processing
def repAllTxt(txt):
    # mask letters with '*' and digits with '#'
    out = re.sub("[a-z]|[A-Z]", '*', txt)
    out = re.sub("[0-9]", '#', out)
    return out
def corpDict(x):
    # build the global word -> frequency dictionary over the whole corpus
    count = CountVectorizer(ngram_range=(1, 1), tokenizer=tokenizeManu, lowercase=False)
    countFit = count.fit_transform(x)
    vocab = count.get_feature_names()
    dist = np.sum(countFit.toarray(), axis=0)
    for tag, freq in zip(vocab, dist):
        co[tag] = freq
    # print(len(co))
def clean_str(string):
    string = re.sub(r"[^A-Za-z0-9()\-\":.$,!?\'\`]", r" ", string)
    string = re.sub(r"([()\-\":.,!?\'\`])", r" \1 ", string)
    string = re.sub(r"\'s", r" \'s", string)
    string = re.sub(r"\'m", r" \'m", string)
    string = re.sub(r"\'ve", r" \'ve", string)
    string = re.sub(r"n\'t", r" n\'t", string)
    string = re.sub(r"\'re", r" \'re", string)
    string = re.sub(r"\'d", r" \'d", string)
    string = re.sub(r"\'ll", r" \'ll", string)
    string = re.sub(r"\s{2,}", r" ", string)
    return string.strip()
def readLexicons():
    path = 'lexicons'
    # Load data from files
    f = []
    for (dirpath, dirnames, filenames) in walk(path):
        for i in filenames:
            f.append(str(dirpath + '\\' + i))
    lexList = []
    for pa in f:
        if pa.endswith('txt'):
            with open(pa, encoding="utf8") as inf:
                reader = csv.reader(inf, delimiter='\n', quoting=csv.QUOTE_NONE)
                col = list(zip(*reader))
                lexList.extend(col[0])
        else:
            file_object = open(pa, "r")
            file_object = file_object.read()
            file_object = re.findall(r'((?<=word1=)\w+)', file_object)
            lexList.extend(file_object)
    lex.extend(lexList)
def prepTxtStar(X, kValue, maintainLex):
    # keep the kValue most frequent words; mask every other known word via repAllTxt()
    sorted_co = sorted(co.items(), key=operator.itemgetter(1), reverse=True)[:kValue]
    sorted_co = [i[0] for i in sorted_co]
    for row in range(len(X)):
        c = str(X[row]).split()
        for i in range(len(c)):
            if c[i] in co.keys():
                if c[i] not in sorted_co:
                    if maintainLex == 0:
                        c[i] = repAllTxt(c[i])
                    else:
                        if c[i] not in lex:
                            c[i] = repAllTxt(c[i])
        X[row] = ' '.join(c)
    for x in X[:3]:
        print(x)
    return X
def readFiles():
    path = 'datasetpaaaaaaaaaaath/ds.txt'
    f = []
    for (dirpath, dirnames, filenames) in walk(path):
        for i in filenames:
            f.append(str(dirpath + '\\' + i))
    x = []
    y = []
    lexList = []
    for pa in f:
        if pa.endswith('txt'):
            with open(pa, encoding="utf8") as inf:
                reader = csv.reader(inf, delimiter='\t', quoting=csv.QUOTE_NONE)
                col = list(zip(*reader))
                x.extend(col[2])
                y.extend(col[3])
    return x, y
if __name__ == "__main__":
    xOri, yOri = readFiles()
    xOri = [clean_str(i) for i in xOri]
    readLexicons()
    corpDict(xOri)
    xOri = prepTxtStar(xOri, kValue=10000000, maintainLex=0)
    x, xTest, y, yTest = train_test_split(xOri, yOri, test_size=0.32, random_state=42)
    model = Pipeline([
        ('tfidf', TfidfVectorizer(analyzer='char_wb', min_df=0.0007, lowercase=False, tokenizer=tokenizeTfidf)),
        ('clf', SGDClassifier(tol=None, loss='hinge', random_state=38, max_iter=5))
    ])
    # Grid search
    parameters = {
        'tfidf__ngram_range': [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6)]
    }
    gs_clf = GridSearchCV(model, parameters, n_jobs=-1, scoring='f1_macro')
    gs_clf = gs_clf.fit(x, y)
    predicted = gs_clf.predict(xTest)
    for param_name in sorted(parameters.keys()):
        print("%s: %r" % (param_name, gs_clf.best_params_[param_name]))
    print('F1 Macro: ', f1_score(yTest, predicted, average='macro'))
With this example I get the following output:
tfidf__ngram_range: (1, 1)
F1 Macro: 0.4927875243664717
So it picks the first element of the parameter grid, (1, 1), even though the best element according to the f1_score is (4, 4)!
What is going wrong here? Am I missing something?
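To double-check what the grid search itself scored, the per-candidate cross-validation scores can be printed like this (a small sketch reusing the gs_clf object from the code above; cv_results_, mean_test_score and best_score_ are standard GridSearchCV attributes/keys):

for params, mean_score in zip(gs_clf.cv_results_['params'],
                              gs_clf.cv_results_['mean_test_score']):
    print(params, '-> mean CV f1_macro: %.4f' % mean_score)
print('Best CV score:', gs_clf.best_score_)
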
Edit: the full source code has been added together with the dataset: dataset