1

给定一个类的实例列表A[A() for _ in range(5)]我想随机选择其中一个(参见下面的代码示例)

class A:
    def __init__(self, a):
        self.a = a
    def __call__(self):
        return self.a
def f():
    a_list = [A(i) for i in range(5)]
    a = a_list[random.randint(0, 5)]()
    return a

f()

有没有一种方法可以f@tf.function不更改内容f且不调用所有项目的情况下进行装饰a_list

请注意,直接装饰f而不@tf.function对上述代码进行任何其他更改是不可行的,因为它总是会返回相同的结果。另外,我知道这可以通过a_list首先调用所有元素然后使用索引来实现tf.gather_ndA但是,如果调用类型的对象涉及深度神经网络,这将产生大量开销。

4

1 回答 1

3

我现在正在做同样的事情。这是我到目前为止所得到的。如果有人知道更好的方法,我也有兴趣听听。当我在昂贵的调用上运行它时,它比我计算并返回所有值要快得多。

@tf.function
def f2():
    a_list = [A(i) for i in range(5)]
    idx = tf.cast(tf.random.uniform(shape=[], maxval=4), tf.int32)
    return tf.switch_case(idx, a_list)

为了进行速度比较,我使用了昂贵的矩阵代数的调用方法。然后考虑一个调用每个函数的替代函数:

@tf.function
def f3():
    a_list = [A(i) for i in range(40)]
    results = [a() for a in a_list]
    return results

使用 40 个元素运行 f2:0.42643 秒

使用 40 个元素运行 f3:14.9153 秒

因此,对于仅选择一个分支的预期 40 倍加速,这看起来是正确的。

于 2020-09-11T01:53:53.700 回答