I want to automatically differentiate a fairly complex function that I would like to parallelize.

I am using TensorFlow 2.x and differentiating with tf.GradientTape.

I put together a toy example to illustrate the point. Automatic differentiation works fine without threads, but breaks when the exact same computation is run in two separate threads.

import threading

import tensorflow as tf

# This ThreadWithResult is from https://stackoverflow.com/a/65447493/1935801 and works fine on its own
class ThreadWithResult(threading.Thread):
    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None):
        def function():
            # Store the target's return value so the main thread can read it after join().
            self.result = target(*args, **(kwargs or {}))
        super().__init__(group=group, target=function, name=name, daemon=daemon)

def my_function(x):
    return x*x + x*x*x

def my_function_threaded(x):
    def square(x):
        result = x*x
        return result

    def cube(x):
        result = x*x*x
        return result

    t1 = ThreadWithResult(target=square, args=(x,))
    t2 = ThreadWithResult(target=cube, args=(x,))

    t1.start()
    t2.start()

    t1.join()
    t2.join()

    y = t1.result + t2.result

    return y

x = tf.constant(3.0)
print("my_function(x) =", my_function(x))
print("my_function_threaded(x) =", my_function_threaded(x))

with tf.GradientTape() as tape:
    tape.watch(x)
    y = my_function(x)

dy_dx = tape.gradient(y, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)
print("Simple dy_dx", dy_dx)

with tf.GradientTape() as tape:
    tape.watch(x)
    y = my_function_threaded(x)

dy_dx = tape.gradient(y, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)
print("Threaded dy_dx", dy_dx)

As the output below shows, the gradient breaks when threads are used for the exact same simple computation.

my_function(x) = tf.Tensor(36.0, shape=(), dtype=float32)
my_function_threaded(x) = tf.Tensor(36.0, shape=(), dtype=float32)
Simple dy_dx tf.Tensor(33.0, shape=(), dtype=float32)
Threaded dy_dx tf.Tensor(0.0, shape=(), dtype=float32)
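
My working theory (happy to be corrected) is that the tape's recording state is thread-local, so the ops executed in the worker threads are never recorded on the tape entered in the main thread, leaving y unconnected to x (hence the zero from unconnected_gradients). Under that assumption, here is a minimal workaround sketch: since y = square(x) + cube(x), the sum rule gives dy/dx = d(square)/dx + d(cube)/dx, so each thread can run its own GradientTape and the main thread can simply add the per-term gradients. The *_with_grad helpers below are my own names, not a TensorFlow API:

def square_with_grad(x):
    # Each thread enters its own tape, so recording happens on the thread
    # that actually executes the ops.
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = x*x
    return y, tape.gradient(y, x)

def cube_with_grad(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = x*x*x
    return y, tape.gradient(y, x)

t1 = ThreadWithResult(target=square_with_grad, args=(x,))
t2 = ThreadWithResult(target=cube_with_grad, args=(x,))
t1.start()
t2.start()
t1.join()
t2.join()

y = t1.result[0] + t2.result[0]        # 36.0, as before
dy_dx = t1.result[1] + t2.result[1]    # 6.0 + 27.0 = 33.0
print("Per-thread tapes dy_dx", dy_dx)

This gives the right number in the toy example, but it assumes I know the decomposition into independent terms up front, which is exactly what I was hoping the tape would handle for me.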

Any suggestions/ideas on how to parallelize my function inside GradientTape would be much appreciated.
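
For completeness, the other direction I have looked at is not threading in Python at all and instead tracing the plain function with tf.function, on the assumption (unverified on my side) that the graph runtime is free to schedule the two independent branches concurrently:

@tf.function
def my_function_graph(x):
    # Same math as my_function; in graph mode the two terms are
    # independent ops with no data dependency between them.
    return x*x + x*x*x

with tf.GradientTape() as tape:
    tape.watch(x)
    y = my_function_graph(x)

print("tf.function dy_dx", tape.gradient(y, x))  # expect 33.0

This keeps the gradient intact, but it sidesteps rather than answers the threading question.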
