From: Steinar H. Gunderson Date: Thu, 8 Feb 2018 18:54:49 +0000 (+0000) Subject: Enable checkpoint restore from file. X-Git-Url: https://git.sesse.net/?p=voxel-flow;a=commitdiff_plain;h=e46b852bfceb21baf504f2dbe39cf1fd31f47cc3 Enable checkpoint restore from file. --- diff --git a/voxel_flow_train.py b/voxel_flow_train.py index 6b2bcb4..e1fa96f 100755 --- a/voxel_flow_train.py +++ b/voxel_flow_train.py @@ -82,10 +82,21 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): # Build the summary operation from the last tower summaries. summary_op = tf.summary.merge_all() - # Build an initialization operation to run below. - init = tf.initialize_all_variables() - sess = tf.Session() - sess.run(init) + # Restore checkpoint from file. + if FLAGS.pretrained_model_checkpoint_path: + sess = tf.Session() + assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path) + ckpt = tf.train.get_checkpoint_state( + FLAGS.pretrained_model_checkpoint_path) + restorer = tf.train.Saver() + restorer.restore(sess, ckpt.model_checkpoint_path) + print('%s: Pre-trained model restored from %s' % + (datetime.now(), ckpt.model_checkpoint_path)) + else: + # Build an initialization operation to run below. + init = tf.initialize_all_variables() + sess = tf.Session() + sess.run(init) # Summary Writter summary_writer = tf.summary.FileWriter(