tf.app.flags.DEFINE_float('initial_learning_rate', 0.0003,
"""Initial learning rate.""")
+def _read_image(filename):
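+ """Decodes an image file into a 256x256x3 float32 tensor scaled to [-1, 1].
+
+ Assumes 256x256 RGB inputs; set_shape only fixes the static shape and does
+ not resize or validate at runtime.
+ """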
+ image_string = tf.read_file(filename)
+ image_decoded = tf.image.decode_image(image_string, channels=3)
+ image_decoded.set_shape([256, 256, 3])
+ return tf.cast(image_decoded, dtype=tf.float32) / 127.5 - 1.0
def train(dataset_frame1, dataset_frame2, dataset_frame3):
"""Trains a model."""
with tf.Graph().as_default():
+ # Create input.
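+ # All three pipelines use the same shuffle seed so that the
+ # frame-1/frame-2/frame-3 triplets stay aligned across datasets.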
+ data_list_frame1 = dataset_frame1.read_data_list_file()
+ dataset_frame1 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1))
+ dataset_frame1 = dataset_frame1.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+
+ data_list_frame2 = dataset_frame2.read_data_list_file()
+ dataset_frame2 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2))
+ dataset_frame2 = dataset_frame2.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+
+ data_list_frame3 = dataset_frame3.read_data_list_file()
+ dataset_frame3 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3))
+ dataset_frame3 = dataset_frame3.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+
+ batch_frame1 = dataset_frame1.batch(FLAGS.batch_size).make_initializable_iterator()
+ batch_frame2 = dataset_frame2.batch(FLAGS.batch_size).make_initializable_iterator()
+ batch_frame3 = dataset_frame3.batch(FLAGS.batch_size).make_initializable_iterator()
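+ # Initializable iterators must be initialized with sess.run() before the
+ # first batch can be drawn (see the session setup below).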
+
- # Create input and target placeholder.
+ # Create input and target tensors from the dataset iterators.
- input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
- target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))
+ input_placeholder = tf.concat([batch_frame1.get_next(), batch_frame3.get_next()], 3)
+ target_placeholder = batch_frame2.get_next()
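+ # The input stacks the first and third frames along the channel axis
+ # (6 channels); the middle frame is the prediction target.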
# input_resized = tf.image.resize_area(input_placeholder, [128, 128])
# target_resized = tf.image.resize_area(target_placeholder,[128, 128])
# Prepare model.
model = Voxel_flow_model()
- prediction = model.inference(input_placeholder) d
+ prediction, flow = model.inference(input_placeholder)
# reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
reproduction_loss = model.loss(prediction, target_placeholder)
# total_loss = reproduction_loss + prior_loss
+ total_loss = reproduction_loss
# Create summaries
summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
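+ # Note: tf.get_collection returns a copy of the collection list, so these
+ # appends do not alter the graph collection; tf.summary.* ops register
+ # themselves, and tf.summary.merge_all() picks them up regardless.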
- summaries.append(tf.scalar_summary('total_loss', total_loss))
- summaries.append(tf.scalar_summary('reproduction_loss', reproduction_loss))
- # summaries.append(tf.scalar_summary('prior_loss', prior_loss))
- summaries.append(tf.image_summary('Input Image', input_placeholder, 3))
- summaries.append(tf.image_summary('Output Image', prediction, 3))
- summaries.append(tf.image_summary('Target Image', target_placeholder, 3))
+ summaries.append(tf.summary.scalar('total_loss', total_loss))
+ summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
+ # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
+ summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3))
+ summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3))
+ summaries.append(tf.summary.image('Output Image', prediction, 3))
+ summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
+ summaries.append(tf.summary.image('Flow', flow, 3))
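+ # Assumes the flow tensor has three channels (e.g. x/y displacement plus a
+ # blend mask), which tf.summary.image can render directly.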
# Create a saver.
- saver = tf.train.Saver(tf.all_variables())
+ saver = tf.train.Saver(tf.global_variables())
# Build the summary operation from the last tower summaries.
- summary_op = tf.merge_all_summaries()
+ summary_op = tf.summary.merge_all()
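+ # The training op is assumed to be defined elsewhere in this function, e.g.
+ # update_op = tf.train.AdamOptimizer(FLAGS.initial_learning_rate).minimize(total_loss);
+ # the loop below runs it as update_op.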
- # Build an initialization operation to run below.
- init = tf.initialize_all_variables()
- sess = tf.Session()
- sess.run(init)
+ # Create the session and initialize the dataset iterators before any
+ # batches are drawn, regardless of whether a checkpoint is restored.
+ sess = tf.Session()
+ sess.run([batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer])
+
+ # Restore checkpoint from file.
+ if FLAGS.pretrained_model_checkpoint_path:
+ assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
+ ckpt = tf.train.get_checkpoint_state(
+ FLAGS.pretrained_model_checkpoint_path)
+ restorer = tf.train.Saver()
+ restorer.restore(sess, ckpt.model_checkpoint_path)
+ print('%s: Pre-trained model restored from %s' %
+ (datetime.now(), ckpt.model_checkpoint_path))
+ else:
+ # Build an initialization operation to run below.
+ init = tf.global_variables_initializer()
+ sess.run(init)
# Summary Writer
- summary_writer = tf.train.SummaryWriter(
+ summary_writer = tf.summary.FileWriter(
FLAGS.train_dir,
graph=sess.graph)
- # Training loop using feed dict method.
- data_list_frame1 = dataset_frame1.read_data_list_file()
- random.seed(1)
- shuffle(data_list_frame1)
-
- data_list_frame2 = dataset_frame2.read_data_list_file()
- random.seed(1)
- shuffle(data_list_frame2)
-
- data_list_frame3 = dataset_frame3.read_data_list_file()
- random.seed(1)
- shuffle(data_list_frame3)
-
data_size = len(data_list_frame1)
epoch_num = int(data_size / FLAGS.batch_size)
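+ # Note: epoch_num is the number of batches per epoch, not a count of epochs.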
- # num_workers = 1
-
- # load_fn_frame1 = partial(dataset_frame1.process_func)
- # p_queue_frame1 = PrefetchQueue(load_fn_frame1, data_list_frame1, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
- # load_fn_frame2 = partial(dataset_frame2.process_func)
- # p_queue_frame2 = PrefetchQueue(load_fn_frame2, data_list_frame2, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
- # load_fn_frame3 = partial(dataset_frame3.process_func)
- # p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
- for step in xrange(0, FLAGS.max_steps):
+ for step in range(0, FLAGS.max_steps):
batch_idx = step % epoch_num
- batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
- batch_data_list_frame2 = data_list_frame2[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
- batch_data_list_frame3 = data_list_frame3[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-
- # Load batch data.
- batch_data_frame1 = np.array([dataset_frame1.process_func(line) for line in batch_data_list_frame1])
- batch_data_frame2 = np.array([dataset_frame2.process_func(line) for line in batch_data_list_frame2])
- batch_data_frame3 = np.array([dataset_frame3.process_func(line) for line in batch_data_list_frame3])
-
- # batch_data_frame1 = p_queue_frame1.get_batch()
- # batch_data_frame2 = p_queue_frame2.get_batch()
- # batch_data_frame3 = p_queue_frame3.get_batch()
-
- feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3), target_placeholder: batch_data_frame2}
-
# Run single step update.
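+ # The dataset iterators feed the graph directly, so no feed_dict is needed.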
- _, loss_value = sess.run([update_op, total_loss], feed_dict = feed_dict)
+ _, loss_value = sess.run([update_op, total_loss])
if batch_idx == 0:
- # Shuffle data at each epoch.
- random.seed(1)
- shuffle(data_list_frame1)
- random.seed(1)
- shuffle(data_list_frame2)
- random.seed(1)
- shuffle(data_list_frame3)
print('Epoch Number: %d' % int(step / epoch_num))
- # Output Summary
if step % 10 == 0:
- # summary_str = sess.run(summary_op, feed_dict = feed_dict)
- # summary_writer.add_summary(summary_str, step)
- print("Loss at step %d: %f" % (step, loss_value))
+ print("Loss at step %d: %f" % (step, loss_value))
+
+ if step % 100 == 0:
+ # Output Summary
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
if step % 500 == 0:
# Run a batch of images
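+ # Note: this run draws a fresh batch from the iterators, separate from the
+ # batch consumed by the update step above.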
- prediction_np, target_np = sess.run([prediction, target_placeholder], feed_dict = feed_dict)
+ prediction_np, target_np = sess.run([prediction, target_placeholder])
for i in range(prediction_np.shape[0]):
file_name = FLAGS.train_image_dir + str(i) + '_out.png'
file_name_label = FLAGS.train_image_dir + str(i) + '_gt.png'
# target_resized = tf.image.resize_area(target_placeholder,[128, 128])
# Prepare model.
model = Voxel_flow_model(is_train=True)
- prediction = model.inference(input_placeholder)
+ prediction, flow = model.inference(input_placeholder)
# reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
reproduction_loss = model.loss(prediction, target_placeholder)