# Build the tf.data input pipeline for each of the three consecutive frames.
# NOTE(review): leftover unified-diff "+" markers were stripped from the
# prefetch lines — with the markers in place this file does not parse.
# All three shuffles use the same fixed seed so the streams are permuted
# identically — presumably to keep frame triplets aligned; confirm that the
# three data-list files are ordered consistently.
data_list_frame1 = dataset_frame1.read_data_list_file()
dataset_frame1 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1))
dataset_frame1 = dataset_frame1.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
dataset_frame1 = dataset_frame1.prefetch(100)

data_list_frame2 = dataset_frame2.read_data_list_file()
dataset_frame2 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2))
dataset_frame2 = dataset_frame2.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
dataset_frame2 = dataset_frame2.prefetch(100)

data_list_frame3 = dataset_frame3.read_data_list_file()
dataset_frame3 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3))
dataset_frame3 = dataset_frame3.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
dataset_frame3 = dataset_frame3.prefetch(100)

# Initializable iterators (TF1 API): their .initializer ops must be run on
# the session before the first batch is fetched.
batch_frame1 = dataset_frame1.batch(FLAGS.batch_size).make_initializable_iterator()
batch_frame2 = dataset_frame2.batch(FLAGS.batch_size).make_initializable_iterator()
restorer.restore(sess, ckpt.model_checkpoint_path)
print('%s: Pre-trained model restored from %s' %
(datetime.now(), ckpt.model_checkpoint_path))
+ sess.run([batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer])
else:
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# Dump the current target image for inspection (indexed along the batch axis).
imwrite(file_name_label, target_np[i,:,:,:])

# Save a checkpoint every 500 steps and at the final step.
# NOTE(review): resolved an unapplied diff hunk here — the old
# `step % 5000` line was marked "-" and the `step % 500` line "+",
# so the new 500-step interval is kept.
if step % 500 == 0 or (step + 1) == FLAGS.max_steps:
    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)