2

我正在尝试将 RNNCell 子类化为自定义循环模型,但出现以下错误

File "python3.6/site-packages/tensorflow/python/ops/rnn.py", line 508, in dynamic_rnn
raise TypeError("cell must be an instance of RNNCell")

子类具有以下结构

import tensorflow as tf

class CSCell(tf.nn.rnn_cell.RNNCell):
def __init__(self, mem_size=64, word_size=4, batch_size=16, hidden_size=16, output_size=4):
    super(CSCell, self).__init__()
@property
def output_size(self):
    return self._output_size
@property
def state_size(self):
    return self._total_state_vector_size

def zero_state(self, batch_size, dtype):
    #zero vectors here 
    return zero_tensor

#(output, new_state) = self.__call__(inputs,state)
def __call__(self, inputs, state):
#implementation
    return (output, new_state)

在另一个文件中,我使用它如下:

cscell_tf_core = cscell_tf.CSCell(FLAGS.mem_size,
FLAGS.word_size,
FLAGS.batch_size,
FLAGS.hidden_size,
output_size=64)

output_sequence, _ = tf.nn.dynamic_rnn(
cell=cscell_tf_core,
inputs=dataset_tensors.observations,
time_major=True)

我的猜测是 dynamic_rnn 无法识别 CSCell 子类化,但我无法理解原因。我使用 TensorFlow 1.2 版。我被卡住了,任何方向都非常感谢。

4

0 回答 0