
I downloaded the IMDB dataset and split it into training and test subsets so that I can try ELMo embeddings on it:

X_train, X_test, y_train, y_test = train_test_split(trainData, trainLabels, test_size = 0.20)
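For reference, here is a minimal stdlib sketch of what an 80/20 split like the one above does: shuffle the indices once, keep samples and labels paired, and cut off `test_size` of them as the test set. `split_train_test` is a hypothetical stand-in, not scikit-learn's implementation; it only assumes that, as in scikit-learn, a fractional `test_size` is rounded up to a whole number of test samples.

```python
import math
import random

def split_train_test(samples, labels, test_size=0.20, seed=0):
    """Hypothetical stand-in for train_test_split: returns (X_train, X_test, y_train, y_test)."""
    if len(samples) != len(labels):
        raise ValueError("samples and labels must have the same length")
    indices = list(range(len(samples)))
    random.Random(seed).shuffle(indices)          # one shuffle keeps each sample with its label
    n_test = math.ceil(len(samples) * test_size)  # round the test count up
    test_idx, train_idx = indices[:n_test], indices[n_test:]

    def pick(seq, idx):
        return [seq[i] for i in idx]

    return (pick(samples, train_idx), pick(samples, test_idx),
            pick(labels, train_idx), pick(labels, test_idx))

reviews = [f"review {i}" for i in range(10)]
sentiment = [i % 2 for i in range(10)]
X_tr, X_te, y_tr, y_te = split_train_test(reviews, sentiment)
print(len(X_tr), len(X_te))  # 8 2
```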

Then I pass X_train and y_train to a function that builds the ELMo embedding model:

getElmoW(X_train, y_train)

The functions are defined as follows:

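# Assumed imports: TF 1.x with tensorflow_hub and standalone Keras
import tensorflow as tf
import tensorflow_hub as hub
from keras import backend as k
from keras.layers import Input, Lambda, Dense
from keras.models import Model

# TF1-style hub module; trainable=True lets training update the ELMo weights
embed = hub.Module("https://tfhub.dev/google/elmo/2", trainable=True)

def ElmoEmbed(x):
    # x has shape (batch, 1): squeeze it to a (batch,) string tensor and take the
    # "default" output, a mean-pooled 1024-dimensional vector per sentence
    return embed(tf.squeeze(tf.cast(x, tf.string)), signature="default", as_dict=True)["default"]

def getElmoW(Xtrain, yTrain):
    input_text = Input(shape=(1,), dtype=tf.string)
    embedding = Lambda(ElmoEmbed, output_shape=(1024,))(input_text)
    dense = Dense(256, activation='relu')(embedding)
    pred = Dense(1, activation='softmax')(dense)  # single output unit
    model = Model(inputs=[input_text], outputs=pred)
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    with tf.Session() as session:
        k.set_session(session)
        # initialise the variables and lookup tables created by the hub module
        session.run(tf.compat.v1.global_variables_initializer())
        session.run(tf.compat.v1.tables_initializer())
        # this call raises the ValueError shown below
        history = model.fit(Xtrain, yTrain, epochs=20, batch_size=16)
        print(history)
        model.save_weights("responseElmo.h5")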
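One detail of the model above worth checking numerically is the last layer, `Dense(1, activation='softmax')`. Softmax normalises across the units of a layer, so with a single unit it returns 1.0 whatever the logit is, while a sigmoid maps the same logit to a probability that actually varies. The sketch below is plain Python written for illustration, not Keras code; `softmax` and `sigmoid` here are local helper functions.

```python
import math

def softmax(logits):
    """Softmax over the units of one example (a list of logits)."""
    m = max(logits)                               # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    """Numerically stable logistic function for a single logit."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

for logit in (-3.0, 0.0, 2.5):
    # the one-unit softmax column is constant; the sigmoid output follows the logit
    print(logit, softmax([logit]), round(sigmoid(logit), 4))
```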

X_train and y_train contain the following data:

24364    This movie is a classic in every sense of the ...
12933    Next to  Star Wars  and  The Wizard of Oz   th...
11578    What can I say  I ignored the reviews and went...
2513     Fans of creature feature films have to endure ...

24364    1
12933    1
11578    0
2513     0
        ..
20422    1
8983     0
Name: labels, Length: 250, dtype: int64

Each contains only 250 samples. The problem is that when I run the model I get the following error:

ValueError: You are passing a target array of shape (250, 1) while using as loss `categorical_crossentropy`. `categorical_crossentropy` expects targets to be binary matrices (1s and 0s) of shape (samples, classes). If your targets are integer classes, you can convert them to the expected format via:
```
from keras.utils import to_categorical
y_binary = to_categorical(y_int)
```

Alternatively, you can use the loss function `sparse_categorical_crossentropy` instead, which does expect integer targets.
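The error says `categorical_crossentropy` wants targets of shape (samples, classes), whereas 0/1 labels like the ones above arrive as a single column of shape (250, 1). The sketch below shows the one-hot conversion the message suggests, using a hypothetical pure-Python `to_one_hot` in place of `keras.utils.to_categorical`. With two classes every row gets two columns, so a model trained on such targets would also need two output units rather than one.

```python
def to_one_hot(labels, num_classes=None):
    """Hypothetical stand-in for keras.utils.to_categorical on a flat list of integer labels."""
    if num_classes is None:
        num_classes = max(labels) + 1
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} outside [0, {num_classes})")
        row = [0.0] * num_classes
        row[label] = 1.0              # a single 1.0 in the column of the label's class
        rows.append(row)
    return rows

y_int = [1, 1, 0, 0, 1, 0]            # same 0/1 format as y_train above
y_binary = to_one_hot(y_int)
print(len(y_binary), len(y_binary[0]))  # 6 2, i.e. (samples, classes)
```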

My targets are integers (0 or 1), so I don't think I need sparse categorical cross-entropy. What am I missing?
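Setting the shape check aside, the two loss formulas can be compared directly on 0/1 labels. The sketch below reimplements them in plain Python for one example at a time; it is not the Keras code, and the clipping constant `EPS` is an assumption of the sketch that only keeps `log()` finite. With a one-unit softmax the prediction is always `[1.0]`, so categorical cross-entropy on a one-column target is essentially zero for every sample, while binary cross-entropy on a sigmoid probability depends on the label.

```python
import math

EPS = 1e-7  # assumption: keep probabilities inside (0, 1) so log() stays finite

def clip(p):
    return min(max(p, EPS), 1.0 - EPS)

def categorical_ce(y_true_row, y_pred_row):
    """-sum_k y_k * log(p_k) for a single example."""
    return -sum(y * math.log(clip(p)) for y, p in zip(y_true_row, y_pred_row))

def binary_ce(y_true, p):
    """-(y * log(p) + (1 - y) * log(1 - p)) for a single example."""
    p = clip(p)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))

labels = [1, 1, 0, 0]
# one-column targets against a constant [1.0] prediction: every loss is ~0
print(all(categorical_ce([y], [1.0]) < 1e-6 for y in labels))
# binary cross-entropy for a predicted probability of 0.8
print([round(binary_ce(y, 0.8), 4) for y in labels])
```

In Keras terms, the usual pairings for two classes are a one-unit sigmoid output with `binary_crossentropy`, or a two-unit softmax output with one-hot targets and `categorical_crossentropy` (or integer targets and `sparse_categorical_crossentropy`).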
