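I'm trying to implement a dot product layer, but it doesn't seem to work correctly. Here are the implementations I have so far:
Implementation 1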
// dot product of betas and factors
graphBuilder.addVertex("layer_product_",
        new ElementWiseVertex(ElementWiseVertex.Op.Product), "layer_output_Betas_", "layer_output_F_");
graphBuilder.addLayer("layer_dot_",
        new DenseLayer.Builder().nOut(1)
                .activation(new SumActivation()).build(), "layer_product_");
SumActivation.java
public class SumActivation extends BaseActivationFunction {

    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        Nd4j.getExecutioner().execAndReturn(new Sum(in, 1));
        return in;
    }

    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        assertShape(in, epsilon);
        Nd4j.getExecutioner().execAndReturn(new Sum(in, 1));
        return new Pair<>(in, null);
    }

    @Override
    public String toString() {
        return "reduce-sum";
    }
}
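For reference, this is the computation a "reduce-sum" step is expected to perform, written in plain Java arrays rather than ND4J so that no library API is assumed. The forward pass produces a new n x 1 array holding each row's sum. The backward pass copies the incoming gradient for row i into every column of row i, because each input element contributes to its row sum with derivative 1. The class and method names (`ReduceSumSketch`, `forward`, `backward`) are hypothetical, chosen only for illustration.

```java
import java.util.Arrays;

// Plain-Java sketch (no ND4J): reduce-sum over dimension 1 of an n x m matrix,
// plus the matching backward pass. Names are hypothetical, for illustration only.
public class ReduceSumSketch {

    // Forward: out[i][0] = sum_j in[i][j]; returns a NEW n x 1 array
    // (the input array is left unchanged).
    static double[][] forward(double[][] in) {
        double[][] out = new double[in.length][1];
        for (int i = 0; i < in.length; i++) {
            double s = 0.0;
            for (double v : in[i]) {
                s += v;
            }
            out[i][0] = s;
        }
        return out;
    }

    // Backward: d(out[i][0]) / d(in[i][j]) = 1, so the gradient with respect to
    // the input copies epsilon[i][0] into every column of row i (shape n x m).
    static double[][] backward(double[][] in, double[][] epsilon) {
        double[][] grad = new double[in.length][];
        for (int i = 0; i < in.length; i++) {
            grad[i] = new double[in[i].length];
            Arrays.fill(grad[i], epsilon[i][0]);
        }
        return grad;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2, 3}, {4, 5, 6}};
        System.out.println(Arrays.deepToString(forward(x))); // [[6.0], [15.0]]
        double[][] g = backward(x, new double[][]{{1.0}, {0.5}});
        System.out.println(Arrays.deepToString(g)); // [[1.0, 1.0, 1.0], [0.5, 0.5, 0.5]]
    }
}
```

Note that the output shape (n x 1) differs from the input shape (n x m). The gradient, on the other hand, has the input's shape again.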
Implementation 2
// dot product of betas and factors
graphBuilder.addVertex("layer_product_",
        new ElementWiseVertex(ElementWiseVertex.Op.Product), "layer_output_Betas_", "layer_output_F_");
graphBuilder.addLayer("layer_dot_",
        new GlobalPoolingLayer.Builder()
                .poolingDimensions(1)
                .poolingType(PoolingType.SUM)
                .build(), "layer_product_");
graphBuilder.inputPreProcessor("layer_dot_", new FeedForwardToRnnPreProcessor());
The input will be an n x 6 matrix, and the output of the dot product should be an n x 1 vector. Am I doing something wrong here?
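To make the target computation concrete, here is a plain-Java sketch (no ND4J or DL4J calls) of a row-wise dot product between two n x 6 matrices. It is an element-wise product followed by a sum over dimension 1, which yields n x 1. The sketch also includes the gradients a layer would have to pass back to each input. The names `RowDotSketch`, `betas`, and `f` are hypothetical stand-ins for the "Betas" and "F" outputs in the graph above.

```java
import java.util.Arrays;

// Plain-Java sketch (no ND4J) of a row-wise dot product between two n x m matrices:
// element-wise product followed by a sum over dimension 1. Names are hypothetical.
public class RowDotSketch {

    // Forward: out[i][0] = sum_j a[i][j] * b[i][j]  (shape n x 1).
    static double[][] forward(double[][] a, double[][] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("row counts differ");
        }
        double[][] out = new double[a.length][1];
        for (int i = 0; i < a.length; i++) {
            if (a[i].length != b[i].length) {
                throw new IllegalArgumentException("column counts differ in row " + i);
            }
            double s = 0.0;
            for (int j = 0; j < a[i].length; j++) {
                s += a[i][j] * b[i][j];
            }
            out[i][0] = s;
        }
        return out;
    }

    // Backward: d(out[i][0])/d(a[i][j]) = b[i][j] and d(out[i][0])/d(b[i][j]) = a[i][j],
    // each scaled by the incoming gradient epsilon[i][0]. Returns {gradA, gradB}.
    static double[][][] backward(double[][] a, double[][] b, double[][] epsilon) {
        double[][] gradA = new double[a.length][];
        double[][] gradB = new double[b.length][];
        for (int i = 0; i < a.length; i++) {
            gradA[i] = new double[a[i].length];
            gradB[i] = new double[b[i].length];
            for (int j = 0; j < a[i].length; j++) {
                gradA[i][j] = epsilon[i][0] * b[i][j];
                gradB[i][j] = epsilon[i][0] * a[i][j];
            }
        }
        return new double[][][]{gradA, gradB};
    }

    public static void main(String[] args) {
        // Two n x 6 inputs with n = 2.
        double[][] betas = {{1, 2, 3, 4, 5, 6}, {1, 1, 1, 1, 1, 1}};
        double[][] f     = {{1, 0, 1, 0, 1, 0}, {2, 2, 2, 2, 2, 2}};
        System.out.println(Arrays.deepToString(forward(betas, f))); // [[9.0], [12.0]]
    }
}
```

This is the result either implementation should reproduce: one scalar per row, with no learned weights involved. For comparison, a `DenseLayer` always applies its own weight matrix and bias before the activation function.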