我正在尝试为使用相同句子嵌入阶段的不同任务创建多个分类器。
通过重用同一个句子嵌入层的内存消耗将显着减少,因为嵌入层使用大约 300mb 并且分类器大约是 50mb。
我当前的管道如下所示:
document.setInputCol(EXAMPLE_COL);
document.setOutputCol(DOCUMENT_COL);
UniversalSentenceEncoder sentenceEmbeddings = UniversalSentenceEncoder.pretrained("tfhub_use_multi_lg", "xx");
sentenceEmbeddings.setInputCols(new String[]{DOCUMENT_COL});
sentenceEmbeddings.setOutputCol(SENTENCE_EMBEDDINGS_COL);
ClassifierDLApproach dlApproach = new ClassifierDLApproach();
dlApproach.setInputCols(new String[]{SENTENCE_EMBEDDINGS_COL});
dlApproach.setOutputCol(CLASS_COL);
dlApproach.setLabelColumn(LABEL_COL);
dlApproach.setMaxEpochs(epochs);
dlApproach.setRandomSeed(42);
dlApproach.setEnableOutputLogs(false);
Pipeline classifierPipeline = new Pipeline();
classifierPipeline.setStages(new PipelineStage[]{document, sentenceEmbeddings, dlApproach});
我正在尝试做这样的事情:
DocumentAssembler document = new DocumentAssembler();
document.setInputCol("example");
document.setOutputCol("document");
UniversalSentenceEncoder sentenceEmbeddings = UniversalSentenceEncoder.pretrained("tfhub_use_multi_lg", "xx");
sentenceEmbeddings.setInputCols(new String[]{"document"});
sentenceEmbeddings.setOutputCol("sentence_embeddings");
EmbeddingsFinisher embeddingsFinisher = new EmbeddingsFinisher();
embeddingsFinisher.setInputCols(new String[]{"sentence_embeddings"});
embeddingsFinisher.setOutputCols(new String[]{"sentence_embeddings"});
embeddingsFinisher.setOutputAsVector(true);
embeddingsFinisher.setCleanAnnotations(true);
Pipeline prePipeline = new Pipeline();
prePipeline.setStages(new PipelineStage[]{document, sentenceEmbeddings, embeddingsFinisher});
PipelineModel preModel = prePipeline.fit(sparkSession.createDataset(Collections.emptyList(), Encoders.STRING()).toDF("example"));
Dataset<Row> output = preModel.transform(trainDataset);
然后使用该管道的嵌入输出作为训练集来拟合分类器。
有两种可能的方法来做到这一点:
- 方法一:
ClassifierDLApproach
直接从fit
方法中使用(这里我需要模拟上句嵌入阶段的输出和元数据 - 方法 2:使用只有自定义“嵌入文档汇编器”阶段和分类器阶段的新管道
在训练 prePipeline(句子嵌入管道)之后,所有其他分类器都可以重复使用以进行推理
我的主要问题是,如果没有某种 hack,就无法实现任何方法。
- 有没有办法在不求助于黑客的情况下做到这一点?
- 有没有更好的方法来解决这个问题?