I want to automatically differentiate a fairly complex function that I would also like to parallelize. I am using TensorFlow 2.x with tf.GradientTape for the differentiation. I put together a toy example to illustrate the problem: automatic differentiation works fine without threads, but breaks when the exact same computation runs in two separate threads.
import pdb
import tensorflow as tf
import threading

# This ThreadWithResult is from https://stackoverflow.com/a/65447493/1935801 and works fine on its own
class ThreadWithResult(threading.Thread):
    def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None):
        def function():
            self.result = target(*args, **kwargs)
        super().__init__(group=group, target=function, name=name, daemon=daemon)

def my_function(x):
    return x*x + x*x*x

def my_function_threaded(x):
    def square(x):
        result = x*x
        return result

    def cube(x):
        result = x*x*x
        return result

    t1 = ThreadWithResult(target=square, args=(x,))
    t2 = ThreadWithResult(target=cube, args=(x,))
    t1.start()
    t2.start()
    t1.join()
    t2.join()

    y = t1.result + t2.result
    return y

x = tf.constant(3.0)
print("my_function(x) =", my_function(x))
print("my_function_threaded(x) =", my_function_threaded(x))

with tf.GradientTape() as tape:
    tape.watch(x)
    y = my_function(x)
dy_dx = tape.gradient(y, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)
print("Simple dy_dx", dy_dx)

with tf.GradientTape() as tape:
    tape.watch(x)
    y = my_function_threaded(x)
dy_dx = tape.gradient(y, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)
print("Threaded dy_dx", dy_dx)
As the output below shows, the gradient is broken as soon as threads are used for the same simple computation:
my_function(x) = tf.Tensor(36.0, shape=(), dtype=float32)
my_function_threaded(x) = tf.Tensor(36.0, shape=(), dtype=float32)
Simple dy_dx tf.Tensor(33.0, shape=(), dtype=float32)
Threaded dy_dx tf.Tensor(0.0, shape=(), dtype=float32)
Any suggestions/ideas on how to parallelize my function under a GradientTape would be greatly appreciated.
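One possible workaround, sketched below as a non-authoritative example: it assumes the reason for the zero gradient is that GradientTape's recording stack is thread-local, so a tape opened in the main thread does not see ops executed in worker threads. Instead of sharing one tape, each worker thread opens its own tape, computes its partial value and partial gradient, and the caller sums the contributions (valid here because the two sub-functions are added). The `grad_worker` helper is hypothetical, not part of any library.

```python
import threading
import tensorflow as tf

# Same result-carrying thread as in the question (kwargs default made safe).
class ThreadWithResult(threading.Thread):
    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None):
        kwargs = kwargs or {}
        def function():
            self.result = target(*args, **kwargs)
        super().__init__(group=group, target=function, name=name, daemon=daemon)

def grad_worker(fn, x):
    # Hypothetical helper: each thread records onto its OWN tape, so the
    # ops it runs are captured even though tape stacks are per-thread.
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = fn(x)
    return y, tape.gradient(y, x)

x = tf.constant(3.0)
t1 = ThreadWithResult(target=grad_worker, args=(lambda v: v * v, x))
t2 = ThreadWithResult(target=grad_worker, args=(lambda v: v * v * v, x))
t1.start(); t2.start()
t1.join(); t2.join()

# Values and gradients of the two summands are combined by the caller:
# y = x^2 + x^3 = 9 + 27 = 36, dy/dx = 2x + 3x^2 = 6 + 27 = 33.
y = t1.result[0] + t2.result[0]
dy_dx = t1.result[1] + t2.result[1]
print("y =", y)
print("dy_dx =", dy_dx)
```

This only works when the overall gradient decomposes linearly over the threaded pieces (as with a sum); for more general compositions the per-thread gradients would have to be combined via the chain rule by hand.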