0

由于递归性质,我已经能够通过一次输入一个项目来激活一个只有 1 个输入神经元的 lstm 序列。

但是,当我尝试用相同的技术训练网络时,它永远不会收敛。训练永远持续下去。

这就是我正在做的事情,我将自然语言字符串转换为二进制,然后一次输入一个数字。我转换成二进制的原因是因为网络只取 0 到 1 之间的值。

我知道训练有效,因为当我使用与输入神经元一样多的值的数组进行训练时,在本例中为 1 所以:[0],它会收敛并训练得很好。

我想我可以单独传递每个数字,但是每个数字都会有一个单独的理想输出。当数字再次出现在另一个训练集中的另一个理想输出时,它不会收敛,因为例如 0 怎么可能属于 0 和 1 类?如果我在这个假设上错了,请告诉我。

如何使用序列训练这个 lstm,以便在激活时对相似的序列进行类似的分类?

这是我的整个培训师文件:https://github.com/theirf/synaptic/blob/master/src/trainer.js

这是在工作人员身上训练网络的代码:

workerTrain: function(set, callback, options) {

    var that = this;
    var error = 1;
    var iterations = bucketSize = 0;
    var input, output, target, currentRate;
    var length = set.length;

    var start = Date.now();

    if (options) {
        if (options.shuffle) {
            function shuffle(o) { //v1.0
                for (var j, x, i = o.length; i; j = Math.floor(Math.random() *
          i), x = o[--i], o[i] = o[j], o[j] = x);
                return o;
            };
          }
          if(options.iterations) this.iterations = options.iterations;
          if(options.error) this.error = options.error;
          if(options.rate) this.rate = options.rate;
          if(options.cost) this.cost = options.cost;
          if(options.schedule) this.schedule = options.schedule;
          if (options.customLog){
            // for backward compatibility with code that used customLog
            console.log('Deprecated: use schedule instead of customLog')
            this.schedule = options.customLog;
          }
    }

    // dynamic learning rate
    currentRate = this.rate;
    if(Array.isArray(this.rate)) {
        bucketSize = Math.floor(this.iterations / this.rate.length);
    }

    // create a worker
    var worker = this.network.worker();

    // activate the network
    function activateWorker(input)
        {
            worker.postMessage({ 
                action: "activate",
                input: input,
                memoryBuffer: that.network.optimized.memory
            }, [that.network.optimized.memory.buffer]);
        }

        // backpropagate the network
        function propagateWorker(target){
            if(bucketSize > 0) {
                  var currentBucket = Math.floor(iterations / bucketSize);
                  currentRate = this.rate[currentBucket];
            }
            worker.postMessage({ 
                action: "propagate",
                target: target,
                rate: currentRate,
                memoryBuffer: that.network.optimized.memory
            }, [that.network.optimized.memory.buffer]);
        }

        // train the worker
        worker.onmessage = function(e){
            // give control of the memory back to the network
            that.network.optimized.ownership(e.data.memoryBuffer);

            if(e.data.action == "propagate"){
                if(index >= length){
                    index = 0;
                    iterations++;
                    error /= set.length;

                    // log
                    if(options){
                        if(this.schedule && this.schedule.every && iterations % this.schedule.every == 0)
                        abort_training = this.schedule.do({
                            error: error,
                            iterations: iterations
                        });
                        else if(options.log && iterations % options.log == 0){
                            console.log('iterations', iterations, 'error', error);
                        };
                        if(options.shuffle) shuffle(set);
                    }

                    if(!abort_training && iterations < that.iterations && error > that.error){
                        activateWorker(set[index].input);
                    }
                    else{
                        // callback
                        callback({
                           error: error,
                           iterations: iterations,
                           time: Date.now() - start
                        })
                    }
                    error = 0;
                }
                else{
                    activateWorker(set[index].input);
               }
        }

        if(e.data.action == "activate"){
            error += that.cost(set[index].output, e.data.output);
            propagateWorker(set[index].output); 
            index++;
        }
    }
4

2 回答 2

1

自然语言字符串不应转换为二进制以进行规范化。改用 one-hot 编码:

在此处输入图像描述

此外,我建议您看看Neataptic而不是 Synaptic。它修复了 Synaptic 中的许多错误,并提供更多功能供您使用。它在训练期间有一个特殊的选项,称为clear. 这告诉网络在每次训练迭代时重置上下文,因此它知道它是从头开始的。

于 2017-05-15T19:45:55.987 回答
0

为什么您的网络只有 1 个二进制输入?网络输入应该是有意义的。神经网络很强大,但你给他们的任务非常艰巨。

相反,您应该有多个输入,每个字母一个。或者更理想的是,每个单词一个。

于 2015-05-22T02:00:38.517 回答