I downloaded the T5-small model from the Spark NLP website and am running this code (taken almost verbatim from the example):
import com.johnsnowlabs.nlp.{DocumentAssembler, SparkNLP}
import com.johnsnowlabs.nlp.annotators.seq2seq.T5Transformer
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession

// Local Spark session with Kryo, as Spark NLP recommends
val spark = SparkSession.builder()
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.kryoserializer.buffer.max", "500M")
  .master("local").getOrCreate()
SparkNLP.start()

// Two sample sentences to summarize
val testData = spark.createDataFrame(Seq(
  (1, "Google has announced the release of a beta version of the popular TensorFlow machine learning library"),
  (2, "The Paris metro will soon enter the 21st century, ditching single-use paper tickets for rechargeable electronic cards.")
)).toDF("id", "text")

// Wrap raw text into Spark NLP document annotations
val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("documents")

// Load the downloaded T5-small model from local disk and set the task prefix
val t5 = T5Transformer.load("/tmp/t5-small")
  .setTask("summarize:")
  .setInputCols(Array("documents"))
  .setOutputCol("summaries")

new Pipeline().setStages(Array(documentAssembler, t5))
  .fit(testData)
  .transform(testData)
  .select("summaries.result").show(truncate = false)
The executor fails with this error:
Caused by: java.lang.IllegalArgumentException: No Operation named [encoder_input_ids] in the Graph
at org.tensorflow.Session$Runner.operationByName(Session.java:384)
at org.tensorflow.Session$Runner.parseOutput(Session.java:398)
at org.tensorflow.Session$Runner.feed(Session.java:132)
at com.johnsnowlabs.ml.tensorflow.TensorflowT5.process(TensorflowT5.scala:76)
I originally ran this on Spark 2.3.0, but the problem also reproduces on Spark 2.4.4. Other Spark NLP features work fine; only this T5 model fails. The model on disk:
$ ll /tmp/t5-small
drwxr-xr-x@ 6 XXX XXX 192 Dec 25 12:36 metadata
-rw-r--r--@ 1 XXX XXX 791656 Dec 22 18:32 t5_spp
-rw-r--r--@ 1 XXX XXX 175686374 Dec 22 18:32 t5_tensorflow
$ cat /tmp/t5-small/metadata/part-00000
{"class":"com.johnsnowlabs.nlp.annotators.seq2seq.T5Transformer","timestamp":1608475002145,
"sparkVersion":"2.4.4","uid":"T5Transformer_1e0a16435680","paramMap":{},
"defaultParamMap":{"task":"","lazyAnnotator":false,"maxOutputLength":200}}
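To compare the Spark version recorded in the saved stage with the one actually running, the metadata file can also be read programmatically. Below is a minimal plain-Scala sketch with no Spark or Spark NLP dependency; the object name `InspectModelMetadata` and the default path are illustrative assumptions, and the regex only handles flat string fields like those shown above. Note that this metadata records the Spark version, not the Spark NLP version.

```scala
import scala.io.Source

// Hypothetical helper: reads the one-line JSON metadata of a saved Spark ML stage
object InspectModelMetadata {
  // Extract a top-level string field (e.g. "sparkVersion") with a simple regex.
  // Good enough for this flat file; a real JSON parser would be more robust.
  def field(json: String, name: String): Option[String] = {
    val pattern = ("\"" + name + "\"\\s*:\\s*\"([^\"]*)\"").r
    pattern.findFirstMatchIn(json).map(_.group(1))
  }

  def main(args: Array[String]): Unit = {
    // Default path is an assumption matching the listing above
    val path = if (args.nonEmpty) args(0) else "/tmp/t5-small/metadata/part-00000"
    val src = Source.fromFile(path)
    val json = try src.mkString finally src.close() // works on Scala 2.11/2.12
    println(s"class        = ${field(json, "class").getOrElse("?")}")
    println(s"sparkVersion = ${field(json, "sparkVersion").getOrElse("?")}")
  }
}
```

The printed `sparkVersion` can then be checked against `spark.version` in the session that runs the pipeline.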
I'm new to Spark NLP, so I'm not sure whether this is a real bug or something I'm doing wrong. Any help would be appreciated.