0

我正在尝试为使用相同句子嵌入阶段的不同任务创建多个分类器。

通过重用同一个句子嵌入层的内存消耗将显着减少,因为嵌入层使用大约 300mb 并且分类器大约是 50mb。

我当前的管道如下所示:

        document.setInputCol(EXAMPLE_COL);
        document.setOutputCol(DOCUMENT_COL);

        UniversalSentenceEncoder sentenceEmbeddings = UniversalSentenceEncoder.pretrained("tfhub_use_multi_lg", "xx");
        sentenceEmbeddings.setInputCols(new String[]{DOCUMENT_COL});
        sentenceEmbeddings.setOutputCol(SENTENCE_EMBEDDINGS_COL);

        ClassifierDLApproach dlApproach = new ClassifierDLApproach();
        dlApproach.setInputCols(new String[]{SENTENCE_EMBEDDINGS_COL});
        dlApproach.setOutputCol(CLASS_COL);
        dlApproach.setLabelColumn(LABEL_COL);
        dlApproach.setMaxEpochs(epochs);
        dlApproach.setRandomSeed(42);
        dlApproach.setEnableOutputLogs(false);
        Pipeline classifierPipeline = new Pipeline();
        classifierPipeline.setStages(new PipelineStage[]{document, sentenceEmbeddings, dlApproach});

我正在尝试做这样的事情:

        DocumentAssembler document = new DocumentAssembler();
        document.setInputCol("example");
        document.setOutputCol("document");

        UniversalSentenceEncoder sentenceEmbeddings = UniversalSentenceEncoder.pretrained("tfhub_use_multi_lg", "xx");
        sentenceEmbeddings.setInputCols(new String[]{"document"});
        sentenceEmbeddings.setOutputCol("sentence_embeddings");


        EmbeddingsFinisher embeddingsFinisher = new EmbeddingsFinisher();
        embeddingsFinisher.setInputCols(new String[]{"sentence_embeddings"});
        embeddingsFinisher.setOutputCols(new String[]{"sentence_embeddings"});
        embeddingsFinisher.setOutputAsVector(true);
        embeddingsFinisher.setCleanAnnotations(true);

        Pipeline prePipeline = new Pipeline();
        prePipeline.setStages(new PipelineStage[]{document, sentenceEmbeddings, embeddingsFinisher});

        PipelineModel preModel = prePipeline.fit(sparkSession.createDataset(Collections.emptyList(), Encoders.STRING()).toDF("example"));

        Dataset<Row> output = preModel.transform(trainDataset);

然后使用该管道的嵌入输出作为训练集来拟合分类器。

有两种可能的方法来做到这一点:

  • 方法一:ClassifierDLApproach直接从fit方法中使用(这里我需要模拟上句嵌入阶段的输出和元数据
  • 方法 2:使用只有自定义“嵌入文档汇编器”阶段和分类器阶段的新管道

在训练 prePipeline(句子嵌入管道)之后,所有其他分类器都可以重复使用以进行推理

我的主要问题是,如果没有某种 hack,就无法实现任何方法。

  • 有没有办法在不求助于黑客的情况下做到这一点?
  • 有没有更好的方法来解决这个问题?
4

0 回答 0