I am using this Google Colab notebook, https://colab.research.google.com/drive/1qxcQ2A1nNjFudAGN_mcMOnvV9sF_PkEb#scrollTo=aeXshJM-Cuaf, to try to generate text. When I run:

sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=2000,
              restore_from='fresh',
              run_name='run1',
              print_every=10,
              sample_every=500,
              learning_rate=1e-5
              )

I get this error:

WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/gpt_2_simple/src/sample.py:17: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/gpt_2_simple/src/memory_saving_gradients.py:62: get_backward_walk_ops (from tensorflow.contrib.graph_editor.select) is deprecated and will be removed after 2019-06-06.
Instructions for updating:
Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.
Loading checkpoint models/355M/model.ckpt
INFO:tensorflow:Restoring parameters from models/355M/model.ckpt
  0%|          | 0/1 [00:00<?, ?it/s]Loading dataset...

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-6-b6ee97fb7ddf> in <module>()
      9               print_every=10,
     10               sample_every=500,
---> 11               learning_rate=1e-5
     12               )

1 frames
/usr/local/lib/python3.7/dist-packages/gpt_2_simple/src/load_dataset.py in load_dataset(enc, path, combine)
     37                 reader = csv.reader(fp)
     38                 for row in reader:
---> 39                     raw_text += start_token + row[0] + end_token + "\n"
     40         else:
     41             # Plain text

IndexError: list index out of range
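
Judging from the traceback, the failure happens in the CSV branch of load_dataset, at row[0]. Python's csv.reader yields an empty list for a blank line, so one possible cause is an empty row somewhere in the dataset file. The following is only a minimal diagnostic sketch under that assumption; find_empty_rows is a hypothetical helper and not part of gpt_2_simple, and file_name is assumed to point at a .csv file.

```python
import csv


def find_empty_rows(path):
    """Return the 1-based line numbers of rows that csv.reader parses as empty.

    Hypothetical helper for diagnosis only; not part of gpt_2_simple.
    """
    empty = []
    # newline="" is the csv module's recommended way to open files for reading.
    with open(path, "r", encoding="utf8", errors="ignore", newline="") as fp:
        reader = csv.reader(fp)
        for row in reader:
            # A blank line is returned as [], so row[0] would raise IndexError.
            if not row:
                empty.append(reader.line_num)
    return empty
```

Calling find_empty_rows(file_name) before gpt2.finetune would list any blank lines; if it returns a non-empty list, removing those lines from the file may be enough to get past this error.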