
I would like to apply layer normalization to a recurrent neural network using tf.keras. In TensorFlow 2.0, there is a LayerNormalization class in tf.keras.layers, but it's unclear how to use it within a recurrent layer such as LSTM, at each time step (as it was designed to be used). Should I create a custom cell, or is there a simpler way?

For example, applying dropout at each time step is as easy as setting the recurrent_dropout argument when creating an LSTM layer, but there is no recurrent_layer_normalization argument.


2 Answers


You can create a custom cell by subclassing the SimpleRNNCell class, like this:

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.activations import get as get_activation
from tensorflow.keras.layers import SimpleRNNCell, RNN
# LayerNormalization graduated out of tf.keras.layers.experimental;
# since TF 2.0 it is available as tf.keras.layers.LayerNormalization.
from tensorflow.keras.layers import LayerNormalization

class SimpleRNNCellWithLayerNorm(SimpleRNNCell):
    def __init__(self, units, **kwargs):
        activation = kwargs.get("activation", "tanh")
        kwargs["activation"] = None  # the base cell applies no activation itself
        super().__init__(units, **kwargs)
        # Restore the activation *after* super().__init__(), which would
        # otherwise overwrite self.activation with the linear (identity) one.
        self.activation = get_activation(activation)
        self.layer_norm = LayerNormalization()

    def call(self, inputs, states):
        outputs, new_states = super().call(inputs, states)
        norm_out = self.activation(self.layer_norm(outputs))
        return norm_out, [norm_out]

This implementation runs a regular SimpleRNN cell without any activation, applies layer norm to the resulting outputs, and then applies the activation. You can then use it like this:

model = Sequential([
    RNN(SimpleRNNCellWithLayerNorm(20), return_sequences=True,
        input_shape=[None, 20]),
    RNN(SimpleRNNCellWithLayerNorm(5)),
])

model.compile(loss="mse", optimizer="sgd")
X_train = np.random.randn(100, 50, 20)
Y_train = np.random.randn(100, 5)
history = model.fit(X_train, Y_train, epochs=2)

For GRU and LSTM cells, people generally apply layer norm on the gates (after the linear combination of the inputs and states, and before the sigmoid activation), so it is a bit trickier to implement. Alternatively, you may get good results by simply applying layer norm before applying the activation and recurrent_activation, which would be easier to implement.
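A gate-level version along those lines can be sketched as a custom cell that applies layer norm to each gate's pre-activation. This is only a minimal sketch (the class name LayerNormLSTMCellSketch is made up here, and it omits initializers, dropout, and the other options of the built-in LSTMCell):

```python
import tensorflow as tf

class LayerNormLSTMCellSketch(tf.keras.layers.Layer):
    """Hypothetical example cell: an LSTM that applies layer norm to each
    gate's pre-activation (before the sigmoid/tanh), as described above."""
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = [units, units]   # [h, c]
        self.output_size = units
        # One LayerNormalization per gate: input, forget, candidate, output.
        self.gate_norms = [tf.keras.layers.LayerNormalization()
                           for _ in range(4)]

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(
            name="kernel", shape=(input_dim, 4 * self.units))
        self.recurrent_kernel = self.add_weight(
            name="recurrent_kernel", shape=(self.units, 4 * self.units))
        self.bias = self.add_weight(
            name="bias", shape=(4 * self.units,), initializer="zeros")

    def call(self, inputs, states):
        h, c = states
        # Linear combination of inputs and previous state, for all four gates.
        z = inputs @ self.kernel + h @ self.recurrent_kernel + self.bias
        z_i, z_f, z_g, z_o = tf.split(z, 4, axis=-1)
        i = tf.sigmoid(self.gate_norms[0](z_i))  # layer norm before sigmoid
        f = tf.sigmoid(self.gate_norms[1](z_f))
        g = tf.tanh(self.gate_norms[2](z_g))     # layer norm before tanh
        o = tf.sigmoid(self.gate_norms[3](z_o))
        new_c = f * c + i * g
        new_h = o * tf.tanh(new_c)
        return new_h, [new_h, new_c]
```

Wrapped in tf.keras.layers.RNN(LayerNormLSTMCellSketch(units)), this behaves like an LSTM layer with per-gate layer norm.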

Answered on 2019-04-09T04:05:47.477

There is a pre-built LayerNormLSTMCell in TensorFlow Addons, ready to use out of the box.

See this documentation for more details. You may have to install tensorflow-addons before you can import this cell:

pip install tensorflow-addons
Answered on 2020-06-24T03:16:04.803