据我了解Estimator
API,该方法train
在给定数据上进行训练,直到input_fn
函数引发异常或达到步骤数。因此,如果我想循环直到估计器收敛(对于它的某些定义),我需要自己编写循环并测试我的标准。对于估算器来说,这样的事情e
:
prevloss = 999999999999
while True:
e.train(input_fn)
loss = e.evaluate(input_fn)['loss']
if abs(prevloss - loss) < 1e-4:
break
prevloss = loss
但是有一些事情对我来说仍然是模糊的。
- 什么时候调用输入函数?它应该总是返回相同的数据吗?它的正确用途是什么?
- 数据会在每次迭代时上传到 GPU 吗?
- 如果是这样,我该如何避免呢?
- 输入函数返回的数据是否只是作为常量嵌入到图中?
- 如果是这样,我如何使它可喂养?
- 总体而言,启动一个 tensorflow 作业需要 0.8 秒(即每次迭代 1.6 秒没有取得进展),这个循环在 GPU 上不是更好吗?
- 如果是这样,我该怎么做?