2

我试图弄清楚如何编写正确的 input_fn。我已经使用 Tensorflow 估计器和数据集重建了一个 keras UNet 实现。它正在工作,但它比 Keras 实现慢得多。我的输入管道直接从 SSD 读取文件,而 Keras 实现将整个 DataSet 加载到内存。

我想在 DataSets 中做同样的事情,不幸的是,缓存构建ds.cache()在每次中断时都会被清除estimator.train。因为我试图只使用一个 GPU,所以我必须中断它来验证模型。

我看了一下实现,Dataset.cache我猜缓存绑定到保存数据集的图上,每次训练中断时都会释放这个图。

我有什么选择可以让我的输入管道留在张量流图中?我想避免在 python 和 tensorflow 之间来回传递整个训练数据。

我可以在另一个图表中使用张量吗?所以我可以用我的 input_fn 创建一个外部(到 Estimator 的 API)图,然后以某种方式在由 estimator.train 创建的图中使用它。

沿着这条线的东西:

train_graph = tf.Graph()
with training_graph.as_default():
    train_get_next = input_fn(training_set, ...) () # tensor returned by it.get_next()

def train_input_fn_with_ext_cache():
    batch = # some how obtain result of train_get_next 
    return batch

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn_with_ext_cache,  max_steps=3000)
eval_spec = ...

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
4

0 回答 0