# --- Input pipelines: one identical tf.data pipeline per input frame. ---
# NOTE(review): dataset_frame{1,2,3} must already be dataset-wrapper objects
# (providing read_data_list_file()) here — they are defined outside this chunk
# and are rebound below to the tf.data.Dataset built from their file lists.

data_list_frame1 = dataset_frame1.read_data_list_file()
dataset_frame1 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1))
# The same shuffle seed on all three pipelines keeps corresponding frames
# aligned after shuffling; prefetch overlaps decode with training.
dataset_frame1 = dataset_frame1.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
dataset_frame1 = dataset_frame1.prefetch(100)

data_list_frame2 = dataset_frame2.read_data_list_file()
dataset_frame2 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2))
dataset_frame2 = dataset_frame2.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
dataset_frame2 = dataset_frame2.prefetch(100)

data_list_frame3 = dataset_frame3.read_data_list_file()
dataset_frame3 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3))
dataset_frame3 = dataset_frame3.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
dataset_frame3 = dataset_frame3.prefetch(100)

batch_frame1 = dataset_frame1.batch(FLAGS.batch_size).make_initializable_iterator()
batch_frame2 = dataset_frame2.batch(FLAGS.batch_size).make_initializable_iterator()
# BUG FIX: batch_frame3 was never created, yet batch_frame3.initializer is run
# later in this file — build it symmetrically with the other two frames.
batch_frame3 = dataset_frame3.batch(FLAGS.batch_size).make_initializable_iterator()
# Prepare model.
model = Voxel_flow_model()
# inference() returns both the synthesized frame and the predicted voxel flow
# field; the flow is needed below for visualization.
prediction, flow = model.inference(input_placeholder)
# Only the reproduction (reconstruction) loss is used; the paper's prior_loss
# term is intentionally omitted here.
reproduction_loss = model.loss(prediction, target_placeholder)

# Image summaries for TensorBoard (3 samples each).
summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3))
summaries.append(tf.summary.image('Output Image', prediction, 3))
summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
# NOTE(review): tf.summary.image requires 1, 3, or 4 channels — confirm the
# flow tensor's channel count (a raw 2-channel flow would fail here).
summaries.append(tf.summary.image('Flow', flow, 3))
# Create a saver.
saver = tf.train.Saver(tf.all_variables())
restorer.restore(sess, ckpt.model_checkpoint_path)
print('%s: Pre-trained model restored from %s' %
(datetime.now(), ckpt.model_checkpoint_path))
+ sess.run([batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer])
else:
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# NOTE(review): this write lives inside the training loop's periodic dump
# section (the loop header and `i`/`target_np` bindings are outside this
# chunk); it saves the i-th target image of the current batch.
imwrite(file_name_label, target_np[i,:,:,:])

# Save checkpoint every 500 steps (tightened from 5000 so less work is lost on
# pre-emption), and always once more on the final step.
if step % 500 == 0 or (step + 1) == FLAGS.max_steps:
    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)
# Prepare model (training graph).
# BUG FIX: the constructor returns a single model object — it is
# model.inference() that returns the (prediction, flow) pair (as in the
# evaluation path above); unpacking the constructor would raise a TypeError.
model = Voxel_flow_model(is_train=True)
prediction, flow = model.inference(input_placeholder)
# Only the reproduction (reconstruction) loss is used; the paper's prior_loss
# term is intentionally omitted here.
reproduction_loss = model.loss(prediction, target_placeholder)