2

这个问题与低级 Tensorflow 1.x API 有关。给定 a TensorSession.run()我不清楚 Tensorflow 如何遍历计算图。

假设我有一些这样的代码:

a = tf.constant(1.0)
b = tf.subtract(a, 1.0)
c = tf.add(b, 2.0)
d = tf.multiply(c,3)

sess = tf.Session()
sess.run(d)

减法、加法和乘法运算并不都存储在 Tensord中,对吧?我知道Tensor对象有graphop字段;这些字段是否有一些如何递归访问以获得计算所需的所有操作d

编辑:添加输出

print(tf.get_default_graph().as_graph_def())
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Sub/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Sub"
  op: "Sub"
  input: "Const"
  input: "Sub/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Add/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
node {
  name: "Add"
  op: "Add"
  input: "Sub"
  input: "Add/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Mul/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 3.0
      }
    }
  }
}
node {
  name: "Mul"
  op: "Mul"
  input: "Add"
  input: "Mul/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 38
}
4

1 回答 1

2

这就是 Tensorflow 的静态计算图的重点。当您构建图表时,Tensorflow 会在后台隐式构建一个静态图表。然后,当您在图中执行一个节点时,Tensorflow 知道导致该节点的确切操作集。这有几个好处:

  1. 节省计算,因为只有指向您想要的节点的子图才会被执行。
  2. 整个计算被分成小的可微部分。
  3. 模型的每个部分都可以在不同的设备上执行,从而实现巨大的加速。

使用此命令,查看每个节点的输入:

print(tf.get_default_graph().as_graph_def())

例如,如果您在小图上执行此操作,您将看到以下内容,从节点开始d = tf.multiply(c,3)

name: "Mul"
op: "Mul"
input: "Add"

然后c = tf.add(b, 2.0)

name: "Add"
op: "Add"
input: "Sub"

然后b = tf.subtract(a, 1.0)

name: "Sub"
op: "Sub"
input: "Const"

最后a = tf.constant(1.0)

name: "Const"
op: "Const"
于 2019-08-13T00:01:35.057 回答