我正在尝试为 NER 微调 BERT。我从这里下载了一个检查点(https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip)
我已经使用以下代码加载了会话和图表:
model = "./cased_L-12_H-768_A-12/bert_model"
new_saver = tf.train.import_meta_graph(model + ".ckpt.meta")
new_saver.restore(sess, model + '.ckpt')
graph = tf.get_default_graph()
现在,我试图从该图中获取输入占位符,以创建我自己的 feed_dict 并定义我自己的损失函数。我使用以下代码检查图表:
op = sess.graph.get_operations()
[m.values() for m in op]
下面列出了我找到的唯一占位符:
[(<tf.Tensor 'Placeholder:0' shape=(1, 128) dtype=int32>,),
(<tf.Tensor 'Placeholder_1:0' shape=(1, 128) dtype=int32>,),
(<tf.Tensor 'Placeholder_2:0' shape=(1, 128) dtype=int32>,),
这些占位符对我来说看起来不正确,原因如下:
我希望它们的大小为 (None,512),因为此 BERT 模型接受的序列长度最大为 512,并且不应该预先确定 batch_size。根据我在这里看到的这个大小,这个 BERT 模型一次接受 1 个序列,最大大小为 128。这是为什么呢?
我相信我们必须提供一组序列、它们的长度和相应的标签。在这三个占位符中,哪个是哪个?