]> git.sesse.net Git - voxel-flow/commitdiff
Enable checkpoint restore from file.
authorSteinar H. Gunderson <sgunderson@bigfoot.com>
Thu, 8 Feb 2018 18:54:49 +0000 (18:54 +0000)
committerSteinar H. Gunderson <sgunderson@bigfoot.com>
Thu, 8 Feb 2018 18:54:49 +0000 (18:54 +0000)
voxel_flow_train.py

index 6b2bcb440a7598a5cdda25ca76ab8a1e70b4b44b..e1fa96f2c3102f226148b18283eb1f5855cec4eb 100755 (executable)
@@ -82,10 +82,21 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
     # Build the summary operation from the last tower summaries.
     summary_op = tf.summary.merge_all()
 
-    # Build an initialization operation to run below.
-    init = tf.initialize_all_variables()
-    sess = tf.Session()
-    sess.run(init)
+    # Restore checkpoint from file.
+    if FLAGS.pretrained_model_checkpoint_path:
+      sess = tf.Session()
+      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
+      ckpt = tf.train.get_checkpoint_state(
+               FLAGS.pretrained_model_checkpoint_path)
+      restorer = tf.train.Saver()
+      restorer.restore(sess, ckpt.model_checkpoint_path)
+      print('%s: Pre-trained model restored from %s' %
+        (datetime.now(), ckpt.model_checkpoint_path))
+    else:
+      # Build an initialization operation to run below.
+      init = tf.initialize_all_variables()
+      sess = tf.Session()
+      sess.run(init)
 
     # Summary Writter
     summary_writer = tf.summary.FileWriter(