I'm trying to implement a dot product layer, but it doesn't seem to work correctly. Here are the implementations I have so far:

Implementation 1

// dot product of betas and factors
graphBuilder.addVertex("layer_product_",
        new ElementWiseVertex(ElementWiseVertex.Op.Product), "layer_output_Betas_", "layer_output_F_");
graphBuilder.addLayer("layer_dot_",
        new DenseLayer.Builder().nOut(1)
                .activation(new SumActivation()).build(), "layer_product_");

SumActivation.java

public class SumActivation extends BaseActivationFunction {
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        Nd4j.getExecutioner().execAndReturn(new Sum(in, 1));
        return in;
    }

    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        assertShape(in, epsilon);
        Nd4j.getExecutioner().execAndReturn(new Sum(in, 1));
        return new Pair<>(in, null);
    }

    @Override
    public String toString() {
        return "reduce-sum";
    }
}
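For reference, here is a plain-Java sketch (no DL4J; the class and method names are hypothetical) of what a layer like the `DenseLayer` with `nOut(1)` in Implementation 1 computes: each row `x` is mapped to `x · w + b` with learned weights `w` and bias `b`, and only then is the activation applied to the resulting n x 1 column. The identity activation is an assumption for illustration. The sketch shows that this matches a plain row sum only when every weight is 1 and the bias is 0.

```java
import java.util.Arrays;

public class DenseVsSum {
    // Dense layer with nOut = 1 and (assumed) identity activation:
    // out[i] = sum_j x[i][j] * w[j] + b
    static double[] denseNOut1(double[][] x, double[] w, double b) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double z = b;
            for (int j = 0; j < w.length; j++) {
                z += x[i][j] * w[j];
            }
            out[i] = z;
        }
        return out;
    }

    // Plain sum of each row, i.e. what the "reduce-sum" step is meant to do.
    static double[] rowSum(double[][] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            for (double v : x[i]) {
                out[i] += v;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] product = {{1, 2, 3, 4, 5, 6}, {0.5, -1, 2, 0, 1, 1}};
        double[] ones = new double[6];
        Arrays.fill(ones, 1.0);
        System.out.println(Arrays.toString(denseNOut1(product, ones, 0.0))); // [21.0, 3.5]
        System.out.println(Arrays.toString(rowSum(product)));                // [21.0, 3.5]
        // With arbitrary (e.g. learned) weights and bias, the result is no longer a plain sum.
        double[] w = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6};
        System.out.println(Arrays.toString(denseNOut1(product, w, 0.1)));
    }
}
```

Since the activation receives an n x 1 array, a sum along dimension 1 would leave its values unchanged.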

Implementation 2

// dot product of betas and factors
graphBuilder.addVertex("layer_product_",
        new ElementWiseVertex(ElementWiseVertex.Op.Product), "layer_output_Betas_", "layer_output_F_");
graphBuilder.addLayer("layer_dot_",
        new GlobalPoolingLayer.Builder().poolingDimensions(1).poolingType(PoolingType.SUM).build(), "layer_product_");
graphBuilder.inputPreProcessor("layer_dot_", new FeedForwardToRnnPreProcessor());

The input will be an n x 6 matrix, and the output of the dot product should be an n x 1 vector. Am I doing something wrong here?
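To pin down the intended computation, here is a minimal plain-Java sketch (no DL4J; `RowwiseDot` and `rowwiseDot` are hypothetical names) of a row-wise dot product: multiply the two n x 6 inputs element by element, then sum each row to get an n x 1 result. In ND4J terms this is roughly `betas.mul(factors).sum(1)`, although depending on the ND4J version that call may return a rank-1 array of length n rather than an n x 1 column.

```java
import java.util.Arrays;

public class RowwiseDot {
    // Row-wise dot product: out[i][0] = sum_j a[i][j] * b[i][j].
    // The result is shaped n x 1 to match the expected layer output.
    static double[][] rowwiseDot(double[][] a, double[][] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("row counts differ");
        }
        double[][] out = new double[a.length][1];
        for (int i = 0; i < a.length; i++) {
            if (a[i].length != b[i].length) {
                throw new IllegalArgumentException("column counts differ in row " + i);
            }
            double acc = 0.0;
            for (int j = 0; j < a[i].length; j++) {
                acc += a[i][j] * b[i][j]; // element-wise product, accumulated per row
            }
            out[i][0] = acc;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] betas = {{1, 2, 3, 4, 5, 6}, {1, 0, 1, 0, 1, 0}};
        double[][] factors = {{1, 1, 1, 1, 1, 1}, {2, 3, 4, 5, 6, 7}};
        System.out.println(Arrays.deepToString(rowwiseDot(betas, factors))); // [[21.0], [12.0]]
    }
}
```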
