0

我有一个程序来预测新闻文章是否与某个主题有关。

有两个主要脚本:

1) bow_train.py - 生成一个词表和一个模型并将它们存储在两个文件中(arab.model 和 wordList.pkl)

2) bow_predict.py - 使用词表和模型对未知文章进行分类

使用的方法是逻辑回归而不是支持向量机,因为对于这类分类任务,逻辑回归的性能应该要好得多。

我想改进结果。是否有另一种方法可以让我强调某些关键字?例如,对于“阿拉伯之春”主题,我会输入一个关键字列表:[“抗议”、“动荡”、“革命”等],使带有这些关键字的文章比没有这些关键字的文章获得更高的概率。

bow_predict.py

import re
import os
import sys
import pickle
import operator

from collections import Counter

from liblinearutil import *

from bow_util import *

# Directory with the articles that should be classified. An optional first
# command-line argument overrides the default, mirroring how bow_train.py
# receives its directories.
rootdirAll = 'C:\\Users\\Jiyda\\Desktop\\bow_arab\\all\\'
if len(sys.argv) > 1:
    rootdirAll = sys.argv[1]

# Load the wordList and model produced by the training phase.
# NOTE: unpickling can execute arbitrary code; only load a wordList.pkl that
# was written locally by bow_train.py.
with open('wordList.pkl', 'rb') as wordListIn:
    wordList = pickle.load(wordListIn)

m = load_model('arab.model')

counterByFilepathAll = {}

# count and store term frequencies
for folder, subs, files in os.walk(rootdirAll):
    for filename in files:
        filepath       = os.path.join(folder, filename)
        wordsInArticle = get_words_from_file(filepath)
        counterByFilepathAll[filepath] = count_words(wordsInArticle)

# Fix the article order once so that feature rows and prediction values are
# paired through one explicit list instead of two separate dict iterations.
filepaths = list(counterByFilepathAll)

# generate features from term frequencies (bag-of-words)
denseData = [gen_features(counterByFilepathAll[filepath], wordList)
             for filepath in filepaths]

# assume output class is 1 (liblinear/libsvm always require a output class
# even for unknown data)
classList = [1] * len(filepaths)

# Predict using the model from the training phase.
# NOTE: without the '-b 1' option liblinear returns decision values in p_val,
# not probabilities, and p_acc is meaningless here because the labels above
# are dummies.
y, x                  = classList, denseData
p_label, p_acc, p_val = predict(y, x, m)

# store prediction values by filepath
probByFilepath = dict(zip(filepaths, p_val))

# sort by prediction value, most likely on-topic articles first
sortedByProb = sorted(probByFilepath.items(),
                      key=operator.itemgetter(1),
                      reverse=True)

# write to output file; the with-block closes it even if a write fails
with open('probsOut.txt', 'wb') as probsOut:
    for t in sortedByProb:
        probsOut.write(' '.join(str(s) for s in t) + '\n')

bow_train.py

import re
import os
import sys
import copy
import pickle

from collections import defaultdict
from collections import Counter

from liblinearutil import *

from bow_util import *

# Initialize directories for articles: the first argument is the directory
# with the Arab Spring articles, the second the directory with unrelated ones.
if len(sys.argv) < 3:
    sys.exit('usage: bow_train.py <arab_dir> <no_arab_dir>')

rootdirArab   = sys.argv[1]
rootdirNoArab = sys.argv[2]


def _count_articles(rootdir, vocabulary):
    """Return a dict mapping each file path under rootdir to its term counts.

    Every word encountered is also added in place to the set `vocabulary`,
    so each article is read and tokenized only once.
    """
    counterByFilepath = {}
    for folder, subs, files in os.walk(rootdir):
        for filename in files:
            filepath = os.path.join(folder, filename)
            # Materialize the words so they can be consumed twice even if
            # get_words_from_file returns a one-shot iterator.
            wordsInArticle = list(get_words_from_file(filepath))
            vocabulary.update(wordsInArticle)
            counterByFilepath[filepath] = count_words(wordsInArticle)
    return counterByFilepath


# collect the vocabulary and the per-article term frequencies in one pass
wordSet                 = set()
counterByFilepathArab   = _count_articles(rootdirArab, wordSet)
counterByFilepathNoArab = _count_articles(rootdirNoArab, wordSet)

# store sorted set in list so every feature has a fixed position
wordList = sorted(wordSet)

# save sorted list to output file for prediction phase; the with-block
# closes the file even if pickling fails
with open('wordList.pkl', 'wb') as wordListOut:
    pickle.dump(wordList, wordListOut)

# generate features. the features for one article are a list of the
# frequencies of each term in wordList found in the article
denseData  = [gen_features(counter, wordList)
              for counter in counterByFilepathArab.values()]
denseData += [gen_features(counter, wordList)
              for counter in counterByFilepathNoArab.values()]

# set output value to 1 for arab spring articles and -1 for non arab spring
# articles, in the same order the feature rows were appended
classList = [1] * len(counterByFilepathArab) + \
            [-1] * len(counterByFilepathNoArab)

# Train logistic regression model. liblinear's default solver (-s 1) is an
# L2-loss SVM, so logistic regression has to be requested explicitly with
# '-s 0'. To obtain cross validation results use parameter('-s 0 -v 5');
# train() then returns the accuracy instead of a model, so skip save_model.
y, x  = classList, denseData
prob  = problem(y, x)
param = parameter('-s 0')
m     = train(prob, param)

# store model in output file for prediction phase
save_model('arab.model', m)

# uncomment to check if training worked as expected
#p_label, p_acc, p_val = predict(y, x, m)
#ACC, MSE, SCC         = evaluations(y, p_label)
4

0 回答 0