0

签名部分的 tensorflow 文档中,我们有以下代码片段

@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if tf.equal(step % 10, 0):
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())

我有一个关于变量的小问题step,它是一个整数而不是张量,签名支持内置的 python 类型,例如整数。因此,tf.equal(step%10,0)可以将其更改为简单step%10 == 0对吗?

4

2 回答 2

3

你是对的。整数变量 step 仍然是 Python 变量,即使转换为它的图形表示。调用可以看到转换结果tf.autograph.to_code(train.python_function)

不报告所有代码,只报告step变量相关部分,你会看到

  def loop_body(loop_vars, loss_1, step_1):
    with ag__.function_scope('loop_body'):
      x, y = loop_vars
      step_1 += 1

仍然是一个 python 操作(否则step_1.assign_add(1)如果第 1 步是 a tf.Tensor)。

有关签名和 tf.function 的更多信息,我建议阅读文章https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/轻松解释什么时候会发生什么一个函数被转换。

于 2019-03-22T14:21:17.660 回答
0

虽然这在生成的代码中不可见,但 step 变量实际上会被 for 循环自动装箱为张量,该循环正在转换为 TF while_loop。

您可以通过添加打印语句来验证:

    loss = train_one_step(model, optimizer, x, y)
    print(step)
    if tf.equal(step % 10, 0):
于 2019-04-01T01:43:43.507 回答