我正在关注 TFT 的 PyTorch 预测教程,并尝试用 DeepAR 模型替换 TFT 模型。但是,当我实例化模型时,我在 google colab 上的会话崩溃了,我不明白为什么。这是我的代码:
deepAR = DeepAR.from_dataset(
training,
learning_rate=4e-3,
log_val_interval=1,
cell_type='LSTM',
hidden_size=20,
rnn_layers=2,
dropout= 0.1,
loss=SMAPE(),
logging_metrics=SMAPE(),
log_interval=10, # uncomment for learning rate finder and otherwise, e.g. to 10 for logging every 10 batches
reduce_on_plateau_patience=4,
)
老实说,我不太确定参数设置,我放弃了与嵌入相关的参数,因为我不知道它们是什么以及如何正确设置它们。您能否帮助我了解为什么代码会导致会话崩溃并解释正确的嵌入相关参数设置是什么?
谢谢