1

我正在尝试在 Keras 中训练一个包含 LSTM 用于回归目的的循环模型。我想在线使用该模型,据我了解,我需要训练一个有状态的 LSTM。由于模型必须输出一系列值,我希望它计算每个预期输出向量的损失。但是,我担心我的代码不能以这种方式工作,如果有人能帮助我了解我是否做得对或是否有更好的方法,我将不胜感激。

模型的输入是一个 128 维向量的序列。训练集中的每个序列都有不同的长度。每次,模型都应该输出一个包含 3 个元素的向量。

我正在尝试训练和比较两个模型:A)一个具有 128 个输入和 3 个输出的简单 LSTM;B) 一个具有 128 个输入和 100 个输出的简单 LSTM + 一个具有 3 个输出的密集层;

对于模型 A)我写了以下代码:

# Model
model = Sequential()
model.add(LSTM(3, batch_input_shape=(1, None, 128),  return_sequences=True, activation = "linear", stateful = True))`
model.compile(loss='mean_squared_error', optimizer=Adam())

# Training
for i in range(n_epoch):
    for j in np.random.permutation(n_sequences):
        X = data[j] # j-th sequences
        X = X[np.newaxis, ...] # X has size 1 x NTimes x 128

        Y = dataY[j] # Y has size NTimes x 3

        history = model.fit(X, Y, epochs=1, batch_size=1, verbose=0, shuffle=False)
        model.reset_states()

使用这段代码,模型 A) 似乎训练得很好,因为输出序列接近训练集上的真实序列。但是,我想知道损失是否真的是通过考虑所有 NTimes 输出向量来计算的。

对于模型 B),由于密集层,我找不到任何方法来获得整个输出序列。因此,我写道:

# Model
model = Sequential()
model.add(LSTM(100, batch_input_shape=(1, None, 128), , stateful = True))
model.add(Dense(3,   activation="linear"))
model.compile(loss='mean_squared_error', optimizer=Adam())

# Training
for i in range(n_epoch):
    for j in np.random.permutation(n_sequences):
        X = data[j]  #j-th sequence
        X = X[np.newaxis, ...] # X has size 1 x NTimes x 128

        Y = dataY[j] # Y has size NTimes x 3

        for h in range(X.shape[1]):
            x = X[0,h,:]
            x = x[np.newaxis, np.newaxis, ...] # h-th vector in j-th sequence
            y = Y[h,:]
            y = y[np.newaxis, ...]
            loss += model.train_on_batch(x,y)
        model.reset_states() #After the end of the sequence

使用此代码,模型 B) 无法正常训练。在我看来,训练不会收敛,损失值会周期性地增加和减少,我也尝试仅将最后一个向量用作 Y,并且它们在整个训练序列 X 上调用拟合函数,但没有任何改进。

任何想法?谢谢!

4

1 回答 1

2

如果您希望序列的每一步仍然有三个输出,则需要像这样 TimeDistribute 您的 Dense 层:

model.add(TimeDistributed(Dense(3, activation="linear")))

这将密集层独立地应用于每个时间步。

https://keras.io/layers/wrappers/#timedistributed

于 2019-09-25T17:00:42.377 回答