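I am trying to build some intuition for machine learning. I looked through the examples at https://github.com/deeplearning4j/dl4j-0.4-examples and wanted to develop an example of my own. Basically, I took a simple function, a * a + b * b + c * c - a * b * c + a + b + c, generated 10000 outputs for random a, b, c, and tried to train my network on 90% of the inputs. The problem is that no matter what I do, my network cannot predict the remaining examples.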
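To make the target concrete, the function can be sketched on its own, independent of DL4J. This is only a minimal illustration: the class name `TargetFunction` and the method name `target` are hypothetical and do not appear in the code below. It evaluates the function at a few hand-checkable points.

```java
public class TargetFunction {

    // The function from the question: a*a + b*b + c*c - a*b*c + a + b + c
    public static double target(double a, double b, double c) {
        return a * a + b * b + c * c - a * b * c + a + b + c;
    }

    public static void main(String[] args) {
        // 3 * 0.25 - 0.125 + 1.5 = 2.125
        System.out.println(target(0.5, 0.5, 0.5));
        // All terms vanish at the origin
        System.out.println(target(0.0, 0.0, 0.0));
        // 3 - 1 + 3 = 5.0, the upper bound approached as a, b, c -> 1
        System.out.println(target(1.0, 1.0, 1.0));
    }
}
```

Each partial derivative, for example 2a - bc + 1 with respect to a, is positive on the unit cube, so for inputs drawn with Math.random() the labels lie in [0, 5). The a * b * c and squared terms make the target nonlinear in the inputs.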
Here is my code:
public class BasicFunctionNN {

    private static Logger log = LoggerFactory.getLogger(BasicFunctionNN.class);

    public static DataSetIterator generateFunctionDataSet() {
        Collection<DataSet> list = new ArrayList<>();
        for (int i = 0; i < 100000; i++) {
            double a = Math.random();
            double b = Math.random();
            double c = Math.random();

            double output = a * a + b * b + c * c - a * b * c + a + b + c;
            INDArray in = Nd4j.create(new double[]{a, b, c});
            INDArray out = Nd4j.create(new double[]{output});
            list.add(new DataSet(in, out));
        }
        // a single batch containing every generated example
        return new ListDataSetIterator(list, list.size());
    }

    public static void main(String[] args) throws Exception {
        DataSetIterator iterator = generateFunctionDataSet();

        Nd4j.MAX_SLICES_TO_PRINT = 10;
        Nd4j.MAX_ELEMENTS_PER_SLICE = 10;

        final int numInputs = 3;
        int outputNum = 1;
        int iterations = 100;

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(iterations).weightInit(WeightInit.XAVIER).updater(Updater.SGD).dropOut(0.5)
                .learningRate(.8).regularization(true)
                .l1(1e-1).l2(2e-4)
                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                .list(3)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(8)
                        .activation("identity")
                        .build())
                .layer(1, new DenseLayer.Builder().nIn(8).nOut(8)
                        .activation("identity")
                        .build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.RMSE_XENT)
                        .activation("identity")
                        .weightInit(WeightInit.XAVIER)
                        .nIn(8).nOut(outputNum).build())
                .backprop(true).pretrain(false)
                .build();

        //run the model
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(iterations)));

        //get the dataset; the iterator returns all examples in one batch
        DataSet next = iterator.next();
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(0.9);
        System.out.println(testAndTrain.getTrain());

        model.fit(testAndTrain.getTrain());

        //evaluate the model
        Evaluation eval = new Evaluation(10);
        DataSet test = testAndTrain.getTest();
        INDArray output = model.output(test.getFeatureMatrix());
        eval.eval(test.getLabels(), output);
        log.info(">>>>>>>>>>>>>>");
        log.info(eval.stats());
    }
}
I have also played with the learning rate, and much of the time the score does not improve:
10:48:51.404 [main] DEBUG o.d.o.solvers.BackTrackLineSearch - Exited line search after maxIterations termination condition; score did not improve (bestScore=0.8522868127536543, scoreAtStart=0.8522868127536543). Resetting parameters
I also tried relu as the activation function.