我正在尝试将 RNNCell 子类化为自定义循环模型,但出现以下错误
File "python3.6/site-packages/tensorflow/python/ops/rnn.py", line 508, in dynamic_rnn
raise TypeError("cell must be an instance of RNNCell")
子类具有以下结构
import tensorflow as tf
class CSCell(tf.nn.rnn_cell.RNNCell):
def __init__(self, mem_size=64, word_size=4, batch_size=16, hidden_size=16, output_size=4):
super(CSCell, self).__init__()
@property
def output_size(self):
return self._output_size
@property
def state_size(self):
return self._total_state_vector_size
def zero_state(self, batch_size, dtype):
#zero vectors here
return zero_tensor
#(output, new_state) = self.__call__(inputs,state)
def __call__(self, inputs, state):
#implementation
return (output, new_state)
在另一个文件中,我使用它如下:
cscell_tf_core = cscell_tf.CSCell(FLAGS.mem_size,
FLAGS.word_size,
FLAGS.batch_size,
FLAGS.hidden_size,
output_size=64)
output_sequence, _ = tf.nn.dynamic_rnn(
cell=cscell_tf_core,
inputs=dataset_tensors.observations,
time_major=True)
我的猜测是 dynamic_rnn 无法识别 CSCell 子类化,但我无法理解原因。我使用 TensorFlow 1.2 版。我被卡住了,任何方向都非常感谢。