# Prepare model.
model = Voxel_flow_model()
- prediction = model.inference(input_placeholder) d
+ prediction = model.inference(input_placeholder)
# reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
reproduction_loss = model.loss(prediction, target_placeholder)
# total_loss = reproduction_loss + prior_loss
# Create summaries
summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
- summaries.append(tf.scalar_summary('total_loss', total_loss))
- summaries.append(tf.scalar_summary('reproduction_loss', reproduction_loss))
- # summaries.append(tf.scalar_summary('prior_loss', prior_loss))
- summaries.append(tf.image_summary('Input Image', input_placeholder, 3))
- summaries.append(tf.image_summary('Output Image', prediction, 3))
- summaries.append(tf.image_summary('Target Image', target_placeholder, 3))
+ summaries.append(tf.summary.scalar('total_loss', total_loss))
+ summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
+ # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
+ summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3));
+ summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3));
+ summaries.append(tf.summary.image('Output Image', prediction, 3))
+ summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
# Create a saver.
saver = tf.train.Saver(tf.all_variables())
# Build the summary operation from the last tower summaries.
- summary_op = tf.merge_all_summaries()
+ summary_op = tf.summary.merge_all()
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
sess.run(init)
# Summary Writter
- summary_writer = tf.train.SummaryWriter(
+ summary_writer = tf.summary.FileWriter(
FLAGS.train_dir,
graph=sess.graph)
# load_fn_frame3 = partial(dataset_frame3.process_func)
# p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
- for step in xrange(0, FLAGS.max_steps):
+ for step in range(0, FLAGS.max_steps):
batch_idx = step % epoch_num
batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
shuffle(data_list_frame3)
print('Epoch Number: %d' % int(step / epoch_num))
- # Output Summary
if step % 10 == 0:
- # summary_str = sess.run(summary_op, feed_dict = feed_dict)
- # summary_writer.add_summary(summary_str, step)
- print("Loss at step %d: %f" % (step, loss_value))
+ print("Loss at step %d: %f" % (step, loss_value))
+
+ if step % 100 == 0:
+ # Output Summary
+ summary_str = sess.run(summary_op, feed_dict = feed_dict)
+ summary_writer.add_summary(summary_str, step)
if step % 500 == 0:
# Run a batch of images