1

我正在做文本分类。我Conv1D在 Keras 层之上使用了Embedding层。我的验证准确度为 0.68。这是我正在使用的数据集。这是我正在使用的代码:

   #Processing
import pandas as pd
import pickle

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
import numpy as np
from sklearn.preprocessing import LabelEncoder

from keras.models import Sequential
from keras.layers import Embedding, Flatten, Dense
from sklearn.preprocessing import LabelEncoder
from keras.layers import Embedding,Flatten,Dense,Conv1D,MaxPooling1D,GlobalMaxPooling1D
from keras.models import load_model

# df = pd.read_csv('text_emotion.csv')
#
# df.drop(['tweet_id', 'author'], axis=1, inplace=True)

# df = df[~df['sentiment'].isin(['empty', 'enthusiasm', 'boredom', 'anger'])]

# df = df.sample(frac=1).reset_index(drop=True)
df=pd.read_csv('emotion_merged_dataset.csv')
labels = df['sentiment']
# texts = df['content']
texts=df['text']
print (texts.shape)
#############################################
tokenizer = Tokenizer(3000)
tokenizer.fit_on_texts(texts)

sequences = tokenizer.texts_to_sequences(texts)
# print(sequences)

word_index = tokenizer.word_index
with open('tokenizer.pickle', 'wb') as handle:
    pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)

# with open('word_index.pickle', 'rb') as handle:
#     word_index_new = pickle.load(handle)

# print (word_index == word_index_new)

# print('Word index: '+str(word_index))
# print('Found %s unique tokens.' % len(word_index))

data = pad_sequences(sequences, maxlen=37)

encoder = LabelEncoder()
encoder.fit(labels)
encoded_Y = encoder.transform(labels)

from keras.utils import np_utils

labels = np_utils.to_categorical(encoded_Y)
print ('Labels: '+str(labels))

print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)
print('data: '+str(data))

indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
# print ('data:'+str(data[0]))
labels = labels[indices]
print(labels.shape)

model = Sequential()
model.add(Embedding(3000, 300, input_length=37))
# model.add(Flatten())

model.add(Conv1D(32,7,activation='relu'))
model.add(MaxPooling1D(3))
model.add(Conv1D(32,7,activation='relu'))
model.add(GlobalMaxPooling1D())
model.add(Dense(labels.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(data, labels, validation_split=0.2, epochs=10, batch_size=100)
model.save('model_keras_embedding_cnn.h5')
print (model.summary())

我正在腌制标记器并保存模型。然后我使用分词器预处理一个示例输入语句来检查我的模型。下面是用于测试的代码:

from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences
import pickle
model = load_model('model_keras_embedding_cnn.h5')
texts='I am really sad'
with open('tokenizer.pickle', 'rb') as handle:
    tokenizer_new = pickle.load(handle)
tokenizer_new.fit_on_texts(texts)
sequences = tokenizer_new.texts_to_sequences(texts)
data = pad_sequences(sequences, maxlen=37)
print (model.predict_classes(data,verbose=10))

我得到如下输出:

[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5].

我怎样才能获得真正的阶级标签(如恐惧、愤怒等)?我的方法是否正确,即保存标记器并再次使用它?我在概念上搞砸了某个地方吗?[编辑] 我在 JARS 的推荐下使用了 inverse_transform:

print (encoder_new.inverse_transform(pred))

输出是这样的:

['neutral' 'neutral' 'neutral' 'neutral' 'neutral' 'neutral' 'neutral'
 'neutral' 'neutral' 'neutral' 'neutral' 'neutral' 'neutral' 'neutral'
 'neutral']

有人可以解释输出吗?

4

1 回答 1

2

这个答案的功劳归于 JARS。所以我的输入应该是

texts=['I am very happy with the result']

结果是:

['joy']

整个代码是:

from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences
import pickle
model = load_model('model_keras_embedding_cnn.h5')
texts=['I am very happy with the result']
with open('tokenizer.pickle', 'rb') as handle:
    tokenizer_new = pickle.load(handle)
with open('encoder.pickle', 'rb') as handle:
    encoder_new = pickle.load(handle)
tokenizer_new.fit_on_texts(texts)
sequences = tokenizer_new.texts_to_sequences(texts)
data = pad_sequences(sequences, maxlen=37)
pred=model.predict_classes(data)
print (encoder_new.inverse_transform(pred))

如果您计算句子中的字符数,您会发现:15 个字符,这就是我得到的输出数。

于 2018-03-08T14:33:41.707 回答