2

我正在训练colaboratory有时会失去与服务器的连接。在 90 分钟不活动后,VM 也将重置。

我想tf.train.Saver.save()用回调覆盖,以便我可以按时间或步骤间隔将检查点复制到我的 Google Cloud Storage 帐户。

见:https ://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py

#
# override tf_saver, add callback after save()
#
import os
import tensorflow as tf
from tensorflow.python.training import saver as tf_saver

## override saver
class Saver_with_callback(tf_saver.Saver):
    _callback_op = None
    def __init__(self, callback_op, **kwargs ):
        self._callback_op = callback_op
        super(tf_saver.Saver, self).__init__(**kwargs)

    def save(self, sess, save_path, **kwargs):
        """
        see: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py

        """
        model_checkpoint_path = super.save(sess, save_path, **kwargs)
        if self._callback_op is not None:
            ## call on a new thread?
            self._callback_op(sess, save_path, 
                              model_checkpoint_path=model_checkpoint_path, 
                              **kwargs)
        return model_checkpoint_path

但是当我运行时出现错误slim.learning.train(saver=callback_saver)

    final_loss = slim.learning.train(train_op, log_dir, 
                        init_fn=init_fn,
                        global_step=global_step,
                        number_of_steps=steps,
                        save_summaries_secs=300,
                        save_interval_secs=600,
                        saver=callback_saver,
                        #    saver=tf_saver.Saver(),                                     
                       )

错误:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-41-dfb09327cccd> in <module>()
    149                         save_summaries_secs=300,
    150                         save_interval_secs=600,
--> 151                         saver=callback_saver,
    152 #                         saver=tf_saver.Saver,
    153                        )

/anaconda/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/slim/python/slim/learning.py in train(train_op, logdir, train_step_fn, train_step_kwargs, log_every_n_steps, graph, master, is_chief, global_step, number_of_steps, init_op, init_feed_dict, local_init_op, init_fn, ready_op, summary_op, save_summaries_secs, summary_writer, startup_delay_steps, saver, save_interval_secs, sync_optimizer, session_config, session_wrapper, trace_every_n_steps)
    730       save_summaries_secs=save_summaries_secs,
    731       save_model_secs=save_interval_secs,
--> 732       init_fn=init_fn)
    733 
    734   if summary_writer is not None:

/anaconda/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/supervisor.py in __init__(self, graph, ready_op, ready_for_local_init_op, is_chief, init_op, init_feed_dict, local_init_op, logdir, summary_op, saver, global_step, save_summaries_secs, save_model_secs, recovery_wait_secs, stop_grace_secs, checkpoint_basename, session_manager, summary_writer, init_fn)
    304     self._meta_graph_def = meta_graph.create_meta_graph_def(
    305         graph_def=graph.as_graph_def(add_shapes=True),
--> 306         saver_def=self._saver.saver_def if self._saver else None)
    307     self._is_chief = is_chief
    308     self._coord = coordinator.Coordinator()

AttributeError: 'Saver_with_callback' object has no attribute 'saver_def'

``

isinstance(callback_saver, tf_saver.Saver)==True如果我使用saver=tf_saver.Saver()它可以正常工作。

4

1 回答 1

1

您没有调用in的__init__函数。tf_saver.SaverSaver_with_callback.__init__()

调用时调用__init__了父类的函数。tf_saver.Saversuper(tf_saver.Saver, self).__init__(**kwargs)

这是因为super(tf_saver.Saver, self)返回 的父类tf_saver.Saver,而不是tf_saver.Saver您期望的本身。

你应该打电话

super(Saver_with_callback, self).__init__(**kwargs)

或者对于 Python3,简单地说

super().__init__(**kwargs)

于 2018-02-14T18:36:02.597 回答