X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=voxel_flow_train.py;h=6b2bcb440a7598a5cdda25ca76ab8a1e70b4b44b;hb=d55e5f71d6a84b7f39ab143bc5b56bb24d35cba1;hp=f6d3abb67a0f062e6750c0b3bd8471174be1b8f6;hpb=c74057bc77fb9bc4eb75668896adfb85ddb99251;p=voxel-flow diff --git a/voxel_flow_train.py b/voxel_flow_train.py index f6d3abb..6b2bcb4 100755 --- a/voxel_flow_train.py +++ b/voxel_flow_train.py @@ -71,7 +71,8 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): summaries.append(tf.summary.scalar('total_loss', total_loss)) summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss)) # summaries.append(tf.summary.scalar('prior_loss', prior_loss)) - summaries.append(tf.summary.image('Input Image', input_placeholder, 3)) + summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3)); + summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3)); summaries.append(tf.summary.image('Output Image', prediction, 3)) summaries.append(tf.summary.image('Target Image', target_placeholder, 3)) @@ -118,7 +119,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): # load_fn_frame3 = partial(dataset_frame3.process_func) # p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers) - for step in xrange(0, FLAGS.max_steps): + for step in range(0, FLAGS.max_steps): batch_idx = step % epoch_num batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)] @@ -149,11 +150,13 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): shuffle(data_list_frame3) print('Epoch Number: %d' % int(step / epoch_num)) - # Output Summary if step % 10 == 0: - # summary_str = sess.run(summary_op, feed_dict = feed_dict) - # summary_writer.add_summary(summary_str, step) - print("Loss at step %d: %f" % (step, loss_value)) + print("Loss at step %d: %f" % (step, loss_value)) + + if step % 100 == 0: + # Output Summary + summary_str = sess.run(summary_op, feed_dict = feed_dict) + summary_writer.add_summary(summary_str, step) if step % 500 == 0: # Run a batch of images