I'm using this Google Colab notebook, https://colab.research.google.com/drive/1qxcQ2A1nNjFudAGN_mcMOnvV9sF_PkEb#scrollTo=aeXshJM-Cuaf, to try to generate text. When I run
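import gpt_2_simple as gpt2  # imported in an earlier cell of the notebook

sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              dataset=file_name,  # path to the dataset file, set in an earlier cell
              model_name='355M',
              steps=2000,
              restore_from='fresh',
              run_name='run1',
              print_every=10,
              sample_every=500,
              learning_rate=1e-5
              )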
I get this error:
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/gpt_2_simple/src/sample.py:17: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/gpt_2_simple/src/memory_saving_gradients.py:62: get_backward_walk_ops (from tensorflow.contrib.graph_editor.select) is deprecated and will be removed after 2019-06-06.
Instructions for updating:
Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.
Loading checkpoint models/355M/model.ckpt
INFO:tensorflow:Restoring parameters from models/355M/model.ckpt
0%| | 0/1 [00:00<?, ?it/s]Loading dataset...
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-6-b6ee97fb7ddf> in <module>()
      9     print_every=10,
     10     sample_every=500,
---> 11     learning_rate=1e-5
     12 )

/usr/local/lib/python3.7/dist-packages/gpt_2_simple/src/load_dataset.py in load_dataset(enc, path, combine)
     37                 reader = csv.reader(fp)
     38                 for row in reader:
---> 39                     raw_text += start_token + row[0] + end_token + "\n"
     40         else:
     41             # Plain text

IndexError: list index out of range
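The traceback points at the CSV branch of `load_dataset`, which takes `row[0]` from every row that `csv.reader` yields. One likely cause (an assumption, since the dataset file itself isn't shown) is a blank line somewhere in the CSV: `csv.reader` yields an empty list `[]` for such a line, so `row[0]` raises `IndexError`. The sketch below reproduces this with Python's standard `csv` module and shows one way to pre-clean the file. `original_loop`, `build_raw_text` and `clean_csv` are hypothetical illustration names, not part of `gpt_2_simple`.

```python
import csv
import io

START, END = "<|startoftext|>", "<|endoftext|>"


def rows_of(text):
    """Return the rows csv.reader yields for the given text."""
    return list(csv.reader(io.StringIO(text)))


def original_loop(text):
    """Mimics the loop shown in the traceback: crashes on an empty row."""
    raw_text = ""
    for row in csv.reader(io.StringIO(text)):
        raw_text += START + row[0] + END + "\n"  # IndexError when row == []
    return raw_text


def build_raw_text(text):
    """Same loop, but skips empty rows instead of crashing."""
    raw_text = ""
    for row in csv.reader(io.StringIO(text)):
        if not row:  # a blank line yields [], so there is no row[0]
            continue
        raw_text += START + row[0] + END + "\n"
    return raw_text


def clean_csv(src_path, dst_path):
    """Hypothetical helper: copy a CSV, dropping rows whose first column is missing or blank."""
    with open(src_path, newline="", encoding="utf8") as src, \
         open(dst_path, "w", newline="", encoding="utf8") as dst:
        writer = csv.writer(dst)
        for row in csv.reader(src):
            if row and row[0].strip():
                writer.writerow(row)
```

Assuming a blank line is the culprit, running something like `clean_csv(file_name, "cleaned.csv")` and then passing `dataset="cleaned.csv"` to `gpt2.finetune` would avoid the empty rows; the output filename is only an example.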