这是在语境中的Google Magenta package
,特别是旋律RNN model
。
我尝试使用自己的数据集训练 basic_rnn,它运行良好,生成了一个可用的检查点。但是,当我尝试改用 attention_rnn 时,通过将“attn_length=40”添加到 hparams,我得到错误“训练期间的 NaN 损失。”。我尝试将 attn_length 更改为其他值,例如 10 或 20,但仍然出现此错误。另外,我确保使用“attention_rnn”参数创建数据集,所以这应该不是问题。
有人有类似的问题吗?
以下是我使用的命令:
convert_dir_to_note_sequences
--input_dir=$INPUT_DIRECTORY
--output_file=$SEQUENCES_TFRECORD
--recursive
melody_rnn_create_dataset --config="attention_rnn" --input=".../mono_notesequences.tfrecord" --output_dir="..." --eval_ratio="0.10"
python ${MODEL}/melody_rnn_train.py --config=attention_rnn --run_dir=${OUTPUT} --sequence_example_file=${INPUT}/attention_rnn/training_melodies.tfrecord --hparams="batch_size=128,rnn_layer_sizes=[512,512],attn_length=40" --num_training_steps=20000