0

我试图在此链接的页面中间复制练习: https ://d2l.ai/chapter_recurrent-neural-networks/sequence.html

该练习使用正弦函数在 -1 到 1 之间创建 1000 个数据点,并使用循环网络来逼近该函数。

下面是我使用的代码。我将回去研究更多为什么这不起作用,因为当我很容易能够使用前馈网络来近似这个函数时,它对我来说没有多大意义。

      //get data
    ArrayList<DataSet> list = new ArrayList();
   
    DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);

    DataSet dsMain = dss.copy();

    if (!dss.isEmpty()){
        list.add(dss);
    }

   
    if (list.isEmpty()){

        return;
    }

    //format dataset
   list = DataSetFormatter.formatReccurnent(list, 0);

    //get network
    int history = 10;
    ArrayList<LayerDescription> ldlist = new ArrayList<>();
    LayerDescription l = new LayerDescription(1,history, Activation.RELU);
    ldlist.add(l);     
    LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
    ldlist.add(ll);

    ListenerDescription ld = new ListenerDescription(20, true, false);

    MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);


    //train network
    final List<DataSet> lister = list.get(0).asList();
    DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
    network.fit(iter, 50);
    network.rnnClearPreviousState();


    //test network
    ArrayList<DataSet> resList = new ArrayList<>();
    DataSet result = new DataSet();
    INDArray arr = Nd4j.zeros(lister.size()+1);     
    INDArray holder;

    if (list.size() > 1){
        //test on training data
        System.err.println("oops");

    }else{
        //test on original or scaled data
        for (int i = 0; i < lister.size(); i++) {

            holder = network.rnnTimeStep(lister.get(i).getFeatures());
            arr.putScalar(i,holder.getFloat(0));

        }
    }


    //add originaldata
    resList.add(dsMain);
    //result       
    result.setFeatures(dsMain.getFeatures());
  
    result.setLabels(arr);
    resList.add(result);

    //display
    DisplayData.plot2DScatterGraph(resList);

你能解释一下我需要 1 分 10 隐藏和 1 出 lstm 网络来逼近正弦函数的代码吗?

我没有使用任何归一化作为函数已经是 -1:1 并且我使用 Y 输入作为特征,然后使用以下 Y 输入作为标签来训练网络。

您注意到我正在构建一个可以更轻松地构建网络的类,并且我尝试对问题进行许多更改,但我厌倦了猜测。

以下是我的结果的一些示例。蓝色是数据 红色是结果

在此处输入图像描述

在此处输入图像描述

4

2 回答 2

1

This is one of those times were you go from wondering why was this not working to how in the hell were my original results were as good as they were.

My failing was not understanding the documentation clearly and also not understanding BPTT.

With feed forward networks each iteration is stored as a row and each input as a column. An example is [dataset.size, network inputs.size]

However with recurrent input its reversed with each row being a an input and each column an iteration in time necessary to activate the state of the lstm chain of events. At minimum my input needed to be [0, networkinputs.size, dataset.size] But could also be [dataset.size, networkinputs.size, statelength.size]

在我之前的示例中,我使用这种格式的数据训练网络 [dataset.size, networkinputs.size, 1]。因此,根据我对低分辨率的理解,lstm 网络根本不应该工作,但至少以某种方式产生了一些东西。

将数据集转换为列表也可能存在一些问题,因为我也更改了为网络提供数据的方式,但我认为问题的大部分是数据结构问题。

以下是我的新结果 不完美,但考虑到这是 5 个训练阶段,非常好

于 2020-06-23T20:52:48.183 回答
0

如果没有看到完整的代码,很难知道发生了什么。首先,我没有看到指定的 RnnOutputLayer。你可以看看这个,它向你展示了如何在 DL4J 中构建一个 RNN。如果您的 RNN 设置正确,这可能是一个调优问题。您可以在此处找到更多关于调音的信息。对于更新程序,Adam 可能是比 RMSProp 更好的选择。tanh 可能是激活输出层的不错选择,因为它的范围是 (-1,1)。其他要检查/调整的事情——学习率、时期数、数据设置(比如你想预测很远吗?)。

于 2020-06-22T23:47:30.440 回答