1

我正在尝试在 Keras 中创建简单的 RNN,它将学习这个数据集:

x_train = [
    [0,0,0,1,-1,-1,1,0,1,0,...,0,1,-1],
    [-1,0,0,-1,-1,0,1,1,1,...,-1,-1,0],
    ...
    [1,0,0,1,1,0,-1,-1,-1,...,-1,-1,0]
]

其中 1 表示增加一个指标,-1 表示减少,0 表示指标没有变化。每个数组有 83 个项目,用于 83 个指标,每个数组的输出(标签)是一个分类数组,显示这些指标对单个指标的影响:

[[ 0.  0.  1.]
 [ 1.  0.  0.],
 [ 0.  0.  1.],
 ...
 [ 0.  0.  1.],
 [ 1.  0.  0.]]

我在下面的代码中使用了Keras和:LSTM

def train(x, y, x_test, y_test):
    x_train = np.array(x)
    y_train = np.array(y)
    print x_train.shape
    y_train = to_categorical(y_train, 3)
    model = Sequential()
    model.add(LSTM(128,input_dim=83, input_length=3))
    model.add(Dropout(0.5))
    model.add(Dense(3, activation='softmax'))
    opt = optimizers.SGD(lr=0.1, decay=1e-2)
    model.compile(loss='categorical_crossentropy',
            optimizer=opt,
            metrics=['accuracy'])
    model.fit(x_train, y_train, batch_size=128, nb_epoch=200)

行的输出print x_train.shape(1618, 83),当我运行我的代码时,我得到这个错误:

Traceback (most recent call last):
  File "temp.py", line 171, in <module>
    load()
  File "temp.py", line 166, in load
    train(x, y, x_test, y_test)
  File "temp.py", line 63, in train
    model.fit(x_train, y_train, batch_size=128, nb_epoch=200)
  File "/usr/local/lib/python2.7/dist-packages/keras/models.py", line 652, in fit
    sample_weight=sample_weight)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1038, in fit
    batch_size=batch_size)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 963, in _standardize_user_data
    exception_prefix='model input')
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 100, in standardize_input_data
    str(array.shape))
Exception: Error when checking model input: expected lstm_input_1 to have 3 dimensions, but got array with shape (1618, 83)

我不想使用Embedding并想添加input_shapeLSTM图层中。

4

1 回答 1

1

LSTM 是循环层,这意味着输入数据必须是三维的,对应于二维输入形状。在实践中,这意味着数据必须具有形状(num_samples, timesteps, features)并且输入形状必须是(timesteps, features).

在您的情况下,您在数据和输入形状中都缺少时间步长维度。

于 2017-08-30T07:32:46.193 回答