]> git.sesse.net Git - voxel-flow/blobdiff - voxel_flow_train.py
Fix and reenable summary outputs.
[voxel-flow] / voxel_flow_train.py
index 856ba857070fdd007f977eb7f7d76859f22987e0..6b2bcb440a7598a5cdda25ca76ab8a1e70b4b44b 100755 (executable)
@@ -52,7 +52,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
 
     # Prepare model.
     model = Voxel_flow_model()
-    prediction = model.inference(input_placeholder) d
+    prediction = model.inference(input_placeholder)
     # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
     reproduction_loss = model.loss(prediction, target_placeholder)
     # total_loss = reproduction_loss + prior_loss
@@ -68,18 +68,19 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
 
     # Create summaries
     summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
-    summaries.append(tf.scalar_summary('total_loss', total_loss))
-    summaries.append(tf.scalar_summary('reproduction_loss', reproduction_loss))
-    # summaries.append(tf.scalar_summary('prior_loss', prior_loss))
-    summaries.append(tf.image_summary('Input Image', input_placeholder, 3))
-    summaries.append(tf.image_summary('Output Image', prediction, 3))
-    summaries.append(tf.image_summary('Target Image', target_placeholder, 3))
+    summaries.append(tf.summary.scalar('total_loss', total_loss))
+    summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
+    # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
+    summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3));
+    summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3));
+    summaries.append(tf.summary.image('Output Image', prediction, 3))
+    summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
 
     # Create a saver.
     saver = tf.train.Saver(tf.all_variables())
 
     # Build the summary operation from the last tower summaries.
-    summary_op = tf.merge_all_summaries()
+    summary_op = tf.summary.merge_all()
 
     # Build an initialization operation to run below.
     init = tf.initialize_all_variables()
@@ -87,7 +88,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
     sess.run(init)
 
     # Summary Writter
-    summary_writer = tf.train.SummaryWriter(
+    summary_writer = tf.summary.FileWriter(
       FLAGS.train_dir,
       graph=sess.graph)
 
@@ -118,7 +119,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
     # load_fn_frame3 = partial(dataset_frame3.process_func)
     # p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
 
-    for step in xrange(0, FLAGS.max_steps):
+    for step in range(0, FLAGS.max_steps):
       batch_idx = step % epoch_num
       
       batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
@@ -149,11 +150,13 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
         shuffle(data_list_frame3)
         print('Epoch Number: %d' % int(step / epoch_num))
       
-      # Output Summary 
       if step % 10 == 0:
-        # summary_str = sess.run(summary_op, feed_dict = feed_dict)
-        # summary_writer.add_summary(summary_str, step)
-             print("Loss at step %d: %f" % (step, loss_value))
+        print("Loss at step %d: %f" % (step, loss_value))
+
+      if step % 100 == 0:
+        # Output Summary 
+        summary_str = sess.run(summary_op, feed_dict = feed_dict)
+        summary_writer.add_summary(summary_str, step)
 
       if step % 500 == 0:
         # Run a batch of images