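I have a trained multi-layer network, but I'm stuck on how to make predictions for additional time steps.
I tried to follow the character iterator example by writing this method: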
    public float[] sampleFromNetwork(INDArray testingData, int numTimeSteps, DataSetIterator iter) {
        int inputCount = this.getNumOfInputs();
        int outputCount = this.getOutputCount();
        float[] samples = new float[numTimeSteps];
        // Prime the network with the test data, then generate one value at a time,
        // feeding each sampled value back in as the next input
        this.network.rnnClearPreviousState();
        INDArray output = this.network.rnnTimeStep(testingData);
        output = output.tensorAlongDimension(output.size(2) - 1, 1, 0); // Gets the last time step output
        for (int i = 0; i < numTimeSteps; ++i) {
            // Set up next input (single time step) by sampling from previous output
            INDArray nextInput = Nd4j.zeros(1, inputCount);
            // Output is a probability distribution. Sample a class index from it and encode it as a one-hot input
            double[] outputProbDistribution = new double[outputCount];
            for (int j = 0; j < outputProbDistribution.length; j++) {
                outputProbDistribution[j] = output.getDouble(j);
            }
            int nextValue = sampleFromDistribution(outputProbDistribution, new Random());
            nextInput.putScalar(new int[]{0, nextValue}, 1.0f); // Prepare next time step input
            samples[i] = nextValue; // Store the sampled class index
            output = this.network.rnnTimeStep(nextInput); // Do one time step of forward pass
        }
        return samples;
    }
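The loop above depends on a sampleFromDistribution helper that the post does not show. For reference, here is a minimal sketch of what such a helper typically does (inverse-CDF sampling over class probabilities); the class name DiscreteSampler and this exact implementation are hypothetical, not the example's own code. It makes the mismatch visible: the helper returns a class index, which only has meaning when the output layer is a softmax over discrete classes.

```java
import java.util.Random;

// Hypothetical stand-in for a sampleFromDistribution helper:
// draws a class index from a discrete probability distribution.
final class DiscreteSampler {
    private DiscreteSampler() {}

    /**
     * Inverse-CDF sampling: walk the cumulative sum of the probabilities and
     * return the first index whose cumulative mass exceeds a uniform draw.
     * Assumes non-negative entries summing to (approximately) 1, as a softmax output would.
     */
    static int sampleFromDistribution(double[] distribution, Random rng) {
        double u = rng.nextDouble();          // uniform in [0, 1)
        double cumulative = 0.0;
        for (int i = 0; i < distribution.length; i++) {
            cumulative += distribution[i];
            if (u < cumulative) {
                return i;                     // index of the sampled class
            }
        }
        // Rounding can leave the total slightly below 1; fall back to the last class.
        return distribution.length - 1;
    }

    public static void main(String[] args) {
        double[] softmaxOutput = {0.1, 0.7, 0.2};
        int idx = sampleFromDistribution(softmaxOutput, new Random(0));
        System.out.println("sampled class index: " + idx);
    }
}
```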
But sampleFromDistribution() doesn't make sense here, because my network's output isn't a set of discrete classes.
Any ideas?
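For a regression network (continuous output, e.g. an identity activation on the output layer), one common approach is to skip sampling entirely and feed each predicted value straight back in as the next input. The sketch below assumes a univariate series where the network's single output is the next value of its single input; ClosedLoopForecast and the step function are hypothetical, the latter standing in for one call to rnnTimeStep so the loop can be shown without DL4J.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

// Hypothetical helper: closed-loop (iterative) forecasting for a regression RNN.
final class ClosedLoopForecast {
    private ClosedLoopForecast() {}

    /**
     * @param stepModel    one forward time step; stands in for rnnTimeStep on a
     *                     single-time-step input (assumption: 1 input, 1 output)
     * @param lastOutput   output at the final time step of the priming sequence,
     *                     i.e. the prediction for the first future step
     * @param numTimeSteps number of future steps to generate
     */
    static float[] forecast(DoubleUnaryOperator stepModel, double lastOutput, int numTimeSteps) {
        float[] predictions = new float[numTimeSteps];
        double current = lastOutput;
        for (int i = 0; i < numTimeSteps; i++) {
            // The continuous output is the prediction itself: no sampling needed.
            predictions[i] = (float) current;
            // Feed the prediction back as the next input and advance one time step.
            // (A real RNN also carries hidden state between calls; this toy does not.)
            current = stepModel.applyAsDouble(current);
        }
        return predictions;
    }

    public static void main(String[] args) {
        // Toy stand-in "network" that halves its input, just to exercise the loop.
        float[] p = forecast(x -> 0.5 * x, 8.0, 4);
        System.out.println(Arrays.toString(p)); // prints [8.0, 4.0, 2.0, 1.0]
    }
}
```

To wire this to the network, the step function could be something like `x -> network.rnnTimeStep(Nd4j.create(new double[][]{{x}})).getDouble(0)`, called after `rnnClearPreviousState()` and priming with `testingData` as in the question, with `lastOutput` read from the last time step of that priming output. This assumes the inputs fed back are on the same scale the model was trained on; if the training data was normalized, the predictions would need to be reverted to the original scale before use.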