tanh
除了TensorFlow 中的默认值之外,我还想尝试其他一些传递函数BasicRNNCell
。
原来的实现是这样的:
class BasicRNNCell(RNNCell):
(...)
def __call__(self, inputs, state, scope=None):
"""Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
output = tanh(linear([inputs, state], self._num_units, True))
return output, output
...我将其更改为:
class MyRNNCell(BasicRNNCell):
(...)
def __call__(self, inputs, state, scope=None):
"""Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
with tf.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
output = my_transfer_function(linear([inputs, state], self._num_units, True))
return output, output
更改vs.variable_scope
为tf.variable_scope
, 是成功的,但linear
它是 > rnn_cell.py < 中的一个实现,并且本身不可用tf
。
我怎样才能让它工作?
我必须完全重新实现linear
吗?(我已经检查了代码,我想我也会在那里遇到依赖问题......)