0

我正在编写 CTGAN 代码并希望它以分布式方式进行训练。因此我正在使用tf.distribute.Strategy.mirroredstrategy() 在我正在关注的tensorflow 文档教程中,提到你应该从一个名为 Distribute_trainstep() 的函数中调用你的 train_step 代码,并用 tf.function 装饰它. 像这样:

@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

这很简单,但是在 tf.function 中装饰 train_step 中的所有内容会使 train_step 中的所有 numpy 代码变得无用。我应该怎么办?是否有替代方法,仅通过有选择地在 train_step 中包装函数?还是我必须用 tensorflow 替换所有 numpy 操作?

4

0 回答 0