X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=voxel_flow_train.py;h=0ba31722b156a8ccb08d6e3cd0205d15992fd06a;hb=HEAD;hp=19847f16cb56a8e7433f6a7c158410589868c2b8;hpb=9ca9b0416d25aecce26313f0c9a2a45c61088661;p=voxel-flow diff --git a/voxel_flow_train.py b/voxel_flow_train.py index 19847f1..0ba3172 100755 --- a/voxel_flow_train.py +++ b/voxel_flow_train.py @@ -52,14 +52,17 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): data_list_frame1 = dataset_frame1.read_data_list_file() dataset_frame1 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1)) dataset_frame1 = dataset_frame1.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image) + dataset_frame1 = dataset_frame1.prefetch(100) data_list_frame2 = dataset_frame2.read_data_list_file() dataset_frame2 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2)) dataset_frame2 = dataset_frame2.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image) + dataset_frame2 = dataset_frame2.prefetch(100) data_list_frame3 = dataset_frame3.read_data_list_file() dataset_frame3 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3)) dataset_frame3 = dataset_frame3.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image) + dataset_frame3 = dataset_frame3.prefetch(100) batch_frame1 = dataset_frame1.batch(FLAGS.batch_size).make_initializable_iterator() batch_frame2 = dataset_frame2.batch(FLAGS.batch_size).make_initializable_iterator() @@ -74,7 +77,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): # Prepare model. model = Voxel_flow_model() - prediction = model.inference(input_placeholder) + prediction, flow = model.inference(input_placeholder) # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder) reproduction_loss = model.loss(prediction, target_placeholder) # total_loss = reproduction_loss + prior_loss @@ -97,6 +100,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3)); summaries.append(tf.summary.image('Output Image', prediction, 3)) summaries.append(tf.summary.image('Target Image', target_placeholder, 3)) + summaries.append(tf.summary.image('Flow', flow, 3)) # Create a saver. saver = tf.train.Saver(tf.all_variables()) @@ -114,6 +118,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): restorer.restore(sess, ckpt.model_checkpoint_path) print('%s: Pre-trained model restored from %s' % (datetime.now(), ckpt.model_checkpoint_path)) + sess.run([batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer]) else: # Build an initialization operation to run below. init = tf.initialize_all_variables() @@ -155,7 +160,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): imwrite(file_name_label, target_np[i,:,:,:]) # Save checkpoint - if step % 5000 == 0 or (step +1) == FLAGS.max_steps: + if step % 500 == 0 or (step +1) == FLAGS.max_steps: checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt') saver.save(sess, checkpoint_path, global_step=step) @@ -176,7 +181,7 @@ def test(dataset_frame1, dataset_frame2, dataset_frame3): # target_resized = tf.image.resize_area(target_placeholder,[128, 128]) # Prepare model. - model = Voxel_flow_model(is_train=True) + model, flow = Voxel_flow_model(is_train=True) prediction = model.inference(input_placeholder) # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder) reproduction_loss = model.loss(prediction, target_placeholder)