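I assume you want to predict 50 time steps from the previous 125 time steps (as an example). Here is the most basic encoder-decoder structure for time series; it can be improved, for example with Luong attention.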
from tensorflow.keras import layers,models
input_timesteps=125
input_features=2
output_timesteps=50
output_features=2
units=100
#Input
encoder_inputs = layers.Input(shape=(input_timesteps,input_features))
#Encoder
encoder = layers.LSTM(units, return_state=True, return_sequences=False)
encoder_outputs, state_h, state_c = encoder(encoder_inputs) # because return_sequences=False => encoder_outputs=state_h
#Decoder
decoder = layers.RepeatVector(output_timesteps)(state_h)
decoder_lstm = layers.LSTM(units, return_sequences=True, return_state=False)
decoder = decoder_lstm(decoder, initial_state=[state_h, state_c])
#Output
out = layers.TimeDistributed(layers.Dense(output_features))(decoder)
model = models.Model(encoder_inputs, out)
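To train this model, the series first has to be cut into pairs of a 125-step input and the 50-step target that follows it. Below is a minimal sketch of that windowing in NumPy; the helper name make_windows and the random series are assumptions made for illustration, not part of the original answer.

```python
import numpy as np

def make_windows(series, input_timesteps=125, output_timesteps=50):
    """Slice a (time, features) array into encoder inputs and decoder targets."""
    series = np.asarray(series, dtype="float32")
    span = input_timesteps + output_timesteps
    n_windows = series.shape[0] - span + 1
    if n_windows <= 0:
        raise ValueError("series is shorter than one input+output window")
    # Each input window is immediately followed by its target window.
    X = np.stack([series[i:i + input_timesteps] for i in range(n_windows)])
    y = np.stack([series[i + input_timesteps:i + span] for i in range(n_windows)])
    return X, y

# Hypothetical data: 1000 time steps with 2 features each.
series = np.random.rand(1000, 2)
X, y = make_windows(series)
print(X.shape, y.shape)
```

Arrays shaped like this match the model's input (None, 125, 2) and output (None, 50, 2), so they could be passed to model.compile(optimizer="adam", loss="mse") followed by model.fit(X, y); the optimizer and loss here are only a plausible starting point.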
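So the core idea here is:
- Encode the time series into two states: state_h and state_c. Check this to understand how an LSTM cell works.
- Repeat state_h once for each time step you want to predict
- Decode with an LSTM whose initial state is the one computed by the encoder
- Use a dense layer to shape the number of features needed at each time step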
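The repeat step is easiest to see in terms of shapes. The sketch below reproduces in plain NumPy what RepeatVector does to the encoder state; the batch size of 4 and the random state_h are stand-ins chosen only for illustration.

```python
import numpy as np

batch, units, output_timesteps = 4, 100, 50

# Stand-in for the encoder's final hidden state, shape (batch, units).
state_h = np.random.rand(batch, units)

# Copy the same vector along a new time axis, once per step to predict.
repeated = np.repeat(state_h[:, np.newaxis, :], output_timesteps, axis=1)

print(repeated.shape)  # (4, 50, 100)
```

Every time step of the decoder input is therefore identical; the decoder LSTM has to produce a varying output sequence from its evolving internal state alone.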
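I suggest you test your architectures and visualize them using model.summary() and tf.keras.utils.plot_model(model, show_shapes=True). They give you a nice representation of the model, for example the summary: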
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_5 (InputLayer)            [(None, 125, 2)]     0
__________________________________________________________________________________________________
lstm_8 (LSTM)                   [(None, 100), (None, 41200       input_5[0][0]
__________________________________________________________________________________________________
repeat_vector_4 (RepeatVector)  (None, 50, 100)      0           lstm_8[0][1]
__________________________________________________________________________________________________
lstm_9 (LSTM)                   (None, 50, 100)      80400       repeat_vector_4[0][0]
                                                                 lstm_8[0][1]
                                                                 lstm_8[0][2]
__________________________________________________________________________________________________
time_distributed_4 (TimeDistrib (None, 50, 2)        202         lstm_9[0][0]
==================================================================================================
Total params: 121,802
Trainable params: 121,802
Non-trainable params: 0
__________________________________________________________________________________________________
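The parameter counts in the summary can be checked by hand. The sketch below uses the standard count for an LSTM layer (four gates, each with an input kernel, a recurrent kernel, and a bias) and for a dense layer; the helper names are made up for this example.

```python
def lstm_params(input_dim, units):
    # 4 gates, each with an input kernel, a recurrent kernel, and a bias vector
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, output_dim):
    # weight matrix plus bias vector
    return input_dim * output_dim + output_dim

encoder = lstm_params(input_dim=2, units=100)      # lstm_8
decoder = lstm_params(input_dim=100, units=100)    # lstm_9
head = dense_params(input_dim=100, output_dim=2)   # time_distributed_4

print(encoder, decoder, head, encoder + decoder + head)
# 41200 80400 202 121802
```

The decoder has roughly twice as many parameters as the encoder because its input is the 100-dimensional repeated state rather than the 2 raw features.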
And the model plot: