0

我在 tensorflow 2.x 中实现了一个自定义层。我的要求是,程序应该在返回输出之前检查一个条件。

class SimpleRNN_cell(tf.keras.layers.Layer):
    def __init__(self, M1, M2, fi=tf.nn.tanh, disp_name=True):
        super(SimpleRNN_cell, self).__init__()
        pass        
    def call(self, X, hidden_state, return_state=True):
        y = tf.constant(5)
        if return_state == True:
            return y, self.h
        else:
            return y

我的问题是:我应该继续使用当前代码(假设tape.gradient(Loss, self.trainable_weights)可以正常工作)还是应该使用tf.cond(). 另外,如果可能,请说明在哪里使用tf.cond()和在哪里不使用。我没有找到关于这个主题的太多内容。

4

1 回答 1

2

tf.cond仅在基于可微计算图中的数据执行条件评估时才相关。( https://www.tensorflow.org/api_docs/python/tf/cond ) 这在 TF 1.0 中尤其必要,图形模式是默认设置。对于急切模式,GradientTape系统还允许使用 python 结构进行条件数据流,例如if ...:( https://www.tensorflow.org/guide/autodiff#control_flow )

然而,对于仅仅基于配置参数提供不同的行为,不依赖于计算图的数据并且在模型运行时固定,使用简单的 pythonif语句是正确的。

于 2021-02-13T12:41:25.447 回答