1

我已经设法让 BERT 模型在 johnsnowlabs-spark-nlp 库上工作。我可以将“训练模型”保存在磁盘上,如下所示。

拟合模型

df_bert_trained = bert_pipeline.fit(textRDD)

df_bert=df_bert_trained.transform(textRDD)

保存模型

df_bert_trained.write().overwrite().save("/home/XX/XX/trained_model")

然而,

首先,根据此处的文档https://nlp.johnsnowlabs.com/docs/en/concepts,据说可以将模型加载为

EmbeddingsHelper.load(path, spark, format, reference, dims, caseSensitive) 

但目前我不清楚变量“reference”代表什么。

其次,有没有人设法将 BERT 嵌入保存为 python 中的 pickle 文件?

4

1 回答 1

1

在 Spark NLP 中,BERT 作为预训练模型出现。这意味着它已经是一个经过训练、拟合等并以正确格式保存的模型。

话虽如此,没有理由再次安装或保存它。但是,一旦将 DataFrame 转换为具有每个令牌的 BERT 嵌入的新 DataFrame,您就可以保存它的结果。

例子:

使用 Spark NLP 包在 spark-shell 中启动 Spark 会话

spark-shell --packages JohnSnowLabs:spark-nlp:2.4.0
import com.johnsnowlabs.nlp.annotators._
import com.johnsnowlabs.nlp.base._

val documentAssembler = new DocumentAssembler()
      .setInputCol("text")
      .setOutputCol("document")

    val sentence = new SentenceDetector()
      .setInputCols("document")
      .setOutputCol("sentence")

    val tokenizer = new Tokenizer()
      .setInputCols(Array("sentence"))
      .setOutputCol("token")

    // Download and load the pretrained BERT model
    val embeddings = BertEmbeddings.pretrained(name = "bert_base_cased", lang = "en")
      .setInputCols("sentence", "token")
      .setOutputCol("embeddings")
      .setCaseSensitive(true)
      .setPoolingLayer(0)

    val pipeline = new Pipeline()
      .setStages(Array(
        documentAssembler,
        sentence,
        tokenizer,
        embeddings
      ))

// Test and transform

   val testData = Seq(
      "I like pancakes in the summer. I hate ice cream in winter.",
      "If I had asked people what they wanted, they would have said faster horses"
    ).toDF("text")

    val predictionDF = pipeline.fit(testData).transform(testData)

predictionDF是一个 DataFrame,其中包含数据集中每个标记的 BERT 嵌入。预BertEmbeddings训练模型来自 TF Hub,这意味着它们与 Google 发布的预训练权重完全相同。所有 5 种型号均可用:

  • bert_base_cased (en)
  • bert_base_uncased (en)
  • bert_large_cased (en)
  • bert_large_uncased (en)
  • bert_multi_cased (xx)

如果您有任何疑问或问题,请告诉我,我会更新我的答案。

参考资料

于 2020-02-14T11:06:22.673 回答