10

我想了解更多关于 NLP 的知识。我遇到了这段代码。TfidfVectorizer.fit_transform但是我对打印结果的结果感到困惑。我熟悉 tfidf 是什么,但我不明白这些数字的含义。

import tensorflow as tf
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import os
import io
import string
import requests
import csv
import nltk
from zipfile import ZipFile

sess = tf.Session()

batch_size = 100
max_features = 1000

save_file_name = os.path.join('smsspamcollection', 'SMSSpamCollection.csv')
if os.path.isfile(save_file_name):
    text_data = []
    with open(save_file_name, 'r') as temp_output_file:
        reader = csv.reader(temp_output_file)
        for row in reader:
            text_data.append(row)

else:
    zip_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
    r = requests.get(zip_url)
    z = ZipFile(io.BytesIO(r.content))
    file = z.read('SMSSpamCollection')

    # Format data 
    text_data = file.decode()
    text_data = text_data.encode('ascii', errors='ignore')
    text_data = text_data.decode().split('\n')
    text_data = [x.split('\t') for x in text_data if len(x) >= 1]

    # And write to csv 
    with open(save_file_name, 'w') as temp_output_file:
        writer = csv.writer(temp_output_file)
        writer.writerows(text_data)

texts = [x[1] for x in text_data]
target = [x[0] for x in text_data]
target = [1 if x == 'spam' else 0 for x in target]

# Normalize the text
texts = [x.lower() for x in texts]  # lower
texts = [''.join(c for c in x if c not in string.punctuation) for x in texts]  # remove punctuation
texts = [''.join(c for c in x if c not in '0123456789') for x in texts]  # remove numbers
texts = [' '.join(x.split()) for x in texts]  # trim extra whitespace


def tokenizer(text):
    words = nltk.word_tokenize(text)
    return words


tfidf = TfidfVectorizer(tokenizer=tokenizer, stop_words='english', max_features=max_features)
sparse_tfidf_texts = tfidf.fit_transform(texts)
print(sparse_tfidf_texts)

输出是:

(0, 630) 0.37172623140154337 (0, 160) 0.36805562944957004 (0, 38) 0.3613966215413548 (0, 545) 0.2561101665717327 (0, 326) 0.2645280991765623 (0, 967) 0.3277447602873963 (0, 421) 0.3896274380321477 (0, 227) 0.28102915589024796 (0 , 323) 0.22032541100275282 (0, 922) 0.2709848154866997 (1, 577) 0.4007895093299793 (1, 425) 0.5970064521899725 (1, 943) 0.6310763941180291 (1, 878) 0.29102173465492637 (2, 282) 0.1771481430848552 (2, 243) 0.5517018054305785 (2, 955 ) 0.2920174942032025 (2, 138) 0.30143666813167863 (2, 946) 0.2269933441326121 (2, 165) 0.3051095293405041 (2, 268) 0.2820392223588522 (2, 780) 0.24119626642264894 (2, 823) 0.1890454397278538 (2, 674) 0.256251970757827 (2, 874) 0.19343834015314287 : : (5569, 648) 0.24171652492226922
(5569, 123) 0.23011909339432202 (5569, 957) 0.24817919217662862
(5569, 549) 0.28583789844730134 (5569, 863) 0.3026729783085827
(5569, 844) 0.20228305447951195 (5569, 146) 0.2514415602877767
(5569, 595) 0.2463259875380789 (5569, 511) 0.3091904754885042
(5569 , 230) 0.2872728684768659 (5569, 638) 0.34151390143548765
(5569, 83) 0.3464271621701711 (5570, 370) 0.4199910200421362
(5570, 46) 0.48234172093857797 (5570, 317) 0.4171646676697801
(5570, 281) 0.6456993475093024 (5572, 282) 0.25540827228532487
(5572, 385 ) 0.36945842040023935 (5572, 448) 0.25540827228532487
(5572, 931) 0.3031800542518209 (5572, 192) 0.29866989620926737
(5572, 303) 0.43990016711221736 (5572, 87) 0.45211284173737176
(5572, 332) 0.3924202767503492 (5573, 866) 1.0

如果有人可以解释输出,我会非常高兴。

4

1 回答 1

20

请注意,您正在打印稀疏矩阵,因此与打印标准密集矩阵相比,输出看起来不同。请参阅下面的主要组件:

  • 元组代表:(document_id, token_id)
  • 元组后面的值表示给定文档中给定标记的 tf-idf 分数
  • 不存在的元组的 tf-idf 分数为 0

如果要查找token_id对应的令牌,请检查get_feature_names方法。

于 2018-06-18T09:41:00.400 回答