我在 tensorflow 2.x 中实现了一个自定义层。我的要求是,程序应该在返回输出之前检查一个条件。
class SimpleRNN_cell(tf.keras.layers.Layer):
def __init__(self, M1, M2, fi=tf.nn.tanh, disp_name=True):
super(SimpleRNN_cell, self).__init__()
pass
def call(self, X, hidden_state, return_state=True):
y = tf.constant(5)
if return_state == True:
return y, self.h
else:
return y
我的问题是:我应该继续使用当前代码(假设tape.gradient(Loss, self.trainable_weights)
可以正常工作)还是应该使用tf.cond()
. 另外,如果可能,请说明在哪里使用tf.cond()
和在哪里不使用。我没有找到关于这个主题的太多内容。