summaries.append(tf.summary.scalar('total_loss', total_loss))
summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
# summaries.append(tf.summary.scalar('prior_loss', prior_loss))
- summaries.append(tf.summary.image('Input Image', input_placeholder, 3))
+ summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3));
+ summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3));
summaries.append(tf.summary.image('Output Image', prediction, 3))
summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
# Build the summary operation from the last tower summaries.
summary_op = tf.summary.merge_all()
- # Build an initialization operation to run below.
- init = tf.initialize_all_variables()
- sess = tf.Session()
- sess.run(init)
+ # Restore checkpoint from file.
+ if FLAGS.pretrained_model_checkpoint_path:
+ sess = tf.Session()
+ assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
+ ckpt = tf.train.get_checkpoint_state(
+ FLAGS.pretrained_model_checkpoint_path)
+ restorer = tf.train.Saver()
+ restorer.restore(sess, ckpt.model_checkpoint_path)
+ print('%s: Pre-trained model restored from %s' %
+ (datetime.now(), ckpt.model_checkpoint_path))
+ else:
+ # Build an initialization operation to run below.
+ init = tf.initialize_all_variables()
+ sess = tf.Session()
+ sess.run(init)
# Summary Writer
summary_writer = tf.summary.FileWriter(
shuffle(data_list_frame3)
print('Epoch Number: %d' % int(step / epoch_num))
- # Output Summary
if step % 10 == 0:
- # summary_str = sess.run(summary_op, feed_dict = feed_dict)
- # summary_writer.add_summary(summary_str, step)
- print("Loss at step %d: %f" % (step, loss_value))
+ print("Loss at step %d: %f" % (step, loss_value))
+
+ if step % 100 == 0:
+ # Output Summary
+ summary_str = sess.run(summary_op, feed_dict = feed_dict)
+ summary_writer.add_summary(summary_str, step)
if step % 500 == 0:
# Run a batch of images