Enable checkpoint restore from file.

diff --git a/voxel_flow_train.py b/voxel_flow_train.py
index 4440e798410e9ec4b16f638146a5d4fcf8862be4..e1fa96f2c3102f226148b18283eb1f5855cec4eb 100755
--- a/voxel_flow_train.py
+++ b/voxel_flow_train.py
@@ -71,7 +71,8 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
     summaries.append(tf.summary.scalar('total_loss', total_loss))
     summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
     # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
-    summaries.append(tf.summary.image('Input Image', input_placeholder, 3))
+    summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3));
+    summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3));
     summaries.append(tf.summary.image('Output Image', prediction, 3))
     summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
 
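The split above assumes input_placeholder stacks the two input frames along the channel axis, i.e. shape [batch, height, width, 6], with RGB channels 0:3 holding the earlier frame and 3:6 the later one. A minimal standalone sketch of the same channel-split summaries under that assumption (the placeholder shape and names here are illustrative, not taken from the repository):

import tensorflow as tf

# Hypothetical placeholder: two RGB frames stacked along the channel axis.
input_placeholder = tf.placeholder(tf.float32, [None, 256, 256, 6], name='input_frames')

frame_before = input_placeholder[:, :, :, 0:3]  # earlier frame, channels 0-2
frame_after = input_placeholder[:, :, :, 3:6]   # later frame, channels 3-5

# Log at most 3 images per batch for each frame so they show up as two separate
# entries in TensorBoard instead of one unviewable 6-channel tensor.
tf.summary.image('Input Image (before)', frame_before, max_outputs=3)
tf.summary.image('Input Image (after)', frame_after, max_outputs=3)
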
@@ -81,10 +82,21 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
     # Build the summary operation from the last tower summaries.
     summary_op = tf.summary.merge_all()
 
-    # Build an initialization operation to run below.
-    init = tf.initialize_all_variables()
-    sess = tf.Session()
-    sess.run(init)
+    # Restore checkpoint from file.
+    if FLAGS.pretrained_model_checkpoint_path:
+      sess = tf.Session()
+      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
+      ckpt = tf.train.get_checkpoint_state(
+               FLAGS.pretrained_model_checkpoint_path)
+      restorer = tf.train.Saver()
+      restorer.restore(sess, ckpt.model_checkpoint_path)
+      print('%s: Pre-trained model restored from %s' %
+        (datetime.now(), ckpt.model_checkpoint_path))
+    else:
+      # Build an initialization operation to run below.
+      init = tf.initialize_all_variables()
+      sess = tf.Session()
+      sess.run(init)
 
     # Summary Writer
     summary_writer = tf.summary.FileWriter(
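
For reference, the restore-or-initialize branch added above, pulled out as a standalone helper. This is a sketch against the TF 1.x API; it assumes FLAGS.pretrained_model_checkpoint_path names a checkpoint directory, since tf.train.get_checkpoint_state() reads that directory's "checkpoint" index file, and the helper name is made up here.

from datetime import datetime
import tensorflow as tf

def restore_or_init_session(checkpoint_dir):
  """Return a tf.Session, restored from checkpoint_dir if one is given."""
  sess = tf.Session()
  if checkpoint_dir:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    assert ckpt and ckpt.model_checkpoint_path, 'no checkpoint found in %s' % checkpoint_dir
    # Saver() with no arguments restores every saved variable in the current graph.
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
    print('%s: Pre-trained model restored from %s' %
          (datetime.now(), ckpt.model_checkpoint_path))
  else:
    # Fresh run: initialize everything (tf.global_variables_initializer in later TF 1.x).
    sess.run(tf.initialize_all_variables())
  return sess

Guarding on ckpt and ckpt.model_checkpoint_path also catches a directory that exists but does not yet contain a checkpoint, a case the plain tf.gfile.Exists() assert in the diff would let through.
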
@@ -149,11 +161,13 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
         shuffle(data_list_frame3)
         print('Epoch Number: %d' % int(step / epoch_num))
       
-      # Output Summary 
       if step % 10 == 0:
-        # summary_str = sess.run(summary_op, feed_dict = feed_dict)
-        # summary_writer.add_summary(summary_str, step)
-             print("Loss at step %d: %f" % (step, loss_value))
+        print("Loss at step %d: %f" % (step, loss_value))
+
+      if step % 100 == 0:
+        # Output Summary 
+        summary_str = sess.run(summary_op, feed_dict = feed_dict)
+        summary_writer.add_summary(summary_str, step)
 
       if step % 500 == 0:
         # Run a batch of images
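
Taken together, the logging changes above give the training loop roughly the cadence sketched below: a cheap scalar print every 10 steps and a full summary_op evaluation (which now includes the image summaries) every 100 steps. train_op, total_loss, feed_dict and FLAGS.max_steps are assumed from the surrounding training code and are not shown in this diff.

for step in range(FLAGS.max_steps):
  _, loss_value = sess.run([train_op, total_loss], feed_dict=feed_dict)

  if step % 10 == 0:
    # Cheap progress signal: just the scalar loss.
    print("Loss at step %d: %f" % (step, loss_value))

  if step % 100 == 0:
    # summary_op re-evaluates the image summaries too, so keep this less frequent.
    summary_str = sess.run(summary_op, feed_dict=feed_dict)
    summary_writer.add_summary(summary_str, step)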