]> git.sesse.net Git - voxel-flow/blobdiff - voxel_flow_train.py
Fix yet more no-movement rejection (we would reject A-A-B pulldown, but not A-B-B).
[voxel-flow] / voxel_flow_train.py
index dfe87b198e63a3286bb674361b7608b4509a74d6..0ba31722b156a8ccb08d6e3cd0205d15992fd06a 100755 (executable)
@@ -39,20 +39,45 @@ tf.app.flags.DEFINE_integer(
 tf.app.flags.DEFINE_float('initial_learning_rate', 0.0003,
                           """Initial learning rate.""")
 
+def _read_image(filename):
+  image_string = tf.read_file(filename)
+  image_decoded = tf.image.decode_image(image_string, channels=3)
+  image_decoded.set_shape([256, 256, 3])
+  return tf.cast(image_decoded, dtype=tf.float32) / 127.5 - 1.0
 
 def train(dataset_frame1, dataset_frame2, dataset_frame3):
   """Trains a model."""
   with tf.Graph().as_default():
+    # Create input.
+    data_list_frame1 = dataset_frame1.read_data_list_file()
+    dataset_frame1 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1))
+    dataset_frame1 = dataset_frame1.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+    dataset_frame1 = dataset_frame1.prefetch(100)
+
+    data_list_frame2 = dataset_frame2.read_data_list_file()
+    dataset_frame2 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2))
+    dataset_frame2 = dataset_frame2.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+    dataset_frame2 = dataset_frame2.prefetch(100)
+
+    data_list_frame3 = dataset_frame3.read_data_list_file()
+    dataset_frame3 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3))
+    dataset_frame3 = dataset_frame3.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+    dataset_frame3 = dataset_frame3.prefetch(100)
+
+    batch_frame1 = dataset_frame1.batch(FLAGS.batch_size).make_initializable_iterator()
+    batch_frame2 = dataset_frame2.batch(FLAGS.batch_size).make_initializable_iterator()
+    batch_frame3 = dataset_frame3.batch(FLAGS.batch_size).make_initializable_iterator()
+
     # Create input and target placeholder.
-    input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
-    target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))
+    input_placeholder = tf.concat([batch_frame1.get_next(), batch_frame3.get_next()], 3)
+    target_placeholder = batch_frame2.get_next()
 
     # input_resized = tf.image.resize_area(input_placeholder, [128, 128])
     # target_resized = tf.image.resize_area(target_placeholder,[128, 128])
 
     # Prepare model.
     model = Voxel_flow_model()
-    prediction = model.inference(input_placeholder)
+    prediction, flow = model.inference(input_placeholder)
     # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
     reproduction_loss = model.loss(prediction, target_placeholder)
     # total_loss = reproduction_loss + prior_loss
@@ -68,96 +93,66 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
 
     # Create summaries
     summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
-    summaries.append(tf.scalar_summary('total_loss', total_loss))
-    summaries.append(tf.scalar_summary('reproduction_loss', reproduction_loss))
-    # summaries.append(tf.scalar_summary('prior_loss', prior_loss))
-    summaries.append(tf.image_summary('Input Image', input_placeholder, 3))
-    summaries.append(tf.image_summary('Output Image', prediction, 3))
-    summaries.append(tf.image_summary('Target Image', target_placeholder, 3))
+    summaries.append(tf.summary.scalar('total_loss', total_loss))
+    summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
+    # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
+    summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3));
+    summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3));
+    summaries.append(tf.summary.image('Output Image', prediction, 3))
+    summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
+    summaries.append(tf.summary.image('Flow', flow, 3))
 
     # Create a saver.
     saver = tf.train.Saver(tf.all_variables())
 
     # Build the summary operation from the last tower summaries.
-    summary_op = tf.merge_all_summaries()
+    summary_op = tf.summary.merge_all()
 
-    # Build an initialization operation to run below.
-    init = tf.initialize_all_variables()
-    sess = tf.Session()
-    sess.run(init)
+    # Restore checkpoint from file.
+    if FLAGS.pretrained_model_checkpoint_path:
+      sess = tf.Session()
+      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
+      ckpt = tf.train.get_checkpoint_state(
+               FLAGS.pretrained_model_checkpoint_path)
+      restorer = tf.train.Saver()
+      restorer.restore(sess, ckpt.model_checkpoint_path)
+      print('%s: Pre-trained model restored from %s' %
+        (datetime.now(), ckpt.model_checkpoint_path))
+      sess.run([batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer])
+    else:
+      # Build an initialization operation to run below.
+      init = tf.initialize_all_variables()
+      sess = tf.Session()
+      sess.run([init, batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer])
 
     # Summary Writter
-    summary_writer = tf.train.SummaryWriter(
+    summary_writer = tf.summary.FileWriter(
       FLAGS.train_dir,
       graph=sess.graph)
 
-    # Training loop using feed dict method.
-    data_list_frame1 = dataset_frame1.read_data_list_file()
-    random.seed(1)
-    shuffle(data_list_frame1)
-
-    data_list_frame2 = dataset_frame2.read_data_list_file()
-    random.seed(1)
-    shuffle(data_list_frame2)
-
-    data_list_frame3 = dataset_frame3.read_data_list_file()
-    random.seed(1)
-    shuffle(data_list_frame3)
-
     data_size = len(data_list_frame1)
     epoch_num = int(data_size / FLAGS.batch_size)
 
-    # num_workers = 1
-      
-    # load_fn_frame1 = partial(dataset_frame1.process_func)
-    # p_queue_frame1 = PrefetchQueue(load_fn_frame1, data_list_frame1, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
-    # load_fn_frame2 = partial(dataset_frame2.process_func)
-    # p_queue_frame2 = PrefetchQueue(load_fn_frame2, data_list_frame2, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
-    # load_fn_frame3 = partial(dataset_frame3.process_func)
-    # p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
-    for step in xrange(0, FLAGS.max_steps):
+    for step in range(0, FLAGS.max_steps):
       batch_idx = step % epoch_num
       
-      batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-      batch_data_list_frame2 = data_list_frame2[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-      batch_data_list_frame3 = data_list_frame3[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-      
-      # Load batch data.
-      batch_data_frame1 = np.array([dataset_frame1.process_func(line) for line in batch_data_list_frame1])
-      batch_data_frame2 = np.array([dataset_frame2.process_func(line) for line in batch_data_list_frame2])
-      batch_data_frame3 = np.array([dataset_frame3.process_func(line) for line in batch_data_list_frame3])
-
-      # batch_data_frame1 = p_queue_frame1.get_batch()
-      # batch_data_frame2 = p_queue_frame2.get_batch()
-      # batch_data_frame3 = p_queue_frame3.get_batch()
-
-      feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3), target_placeholder: batch_data_frame2}
-     
       # Run single step update.
-      _, loss_value = sess.run([update_op, total_loss], feed_dict = feed_dict)
+      _, loss_value = sess.run([update_op, total_loss])
       
       if batch_idx == 0:
-        # Shuffle data at each epoch.
-        random.seed(1)
-        shuffle(data_list_frame1)
-        random.seed(1)
-        shuffle(data_list_frame2)
-        random.seed(1)
-        shuffle(data_list_frame3)
         print('Epoch Number: %d' % int(step / epoch_num))
       
-      # Output Summary 
       if step % 10 == 0:
-        # summary_str = sess.run(summary_op, feed_dict = feed_dict)
-        # summary_writer.add_summary(summary_str, step)
-             print("Loss at step %d: %f" % (step, loss_value))
+        print("Loss at step %d: %f" % (step, loss_value))
+
+      if step % 100 == 0:
+        # Output Summary 
+        summary_str = sess.run(summary_op)
+        summary_writer.add_summary(summary_str, step)
 
       if step % 500 == 0:
         # Run a batch of images        
-        prediction_np, target_np = sess.run([prediction, target_placeholder], feed_dict = feed_dict) 
+        prediction_np, target_np = sess.run([prediction, target_placeholder])
         for i in range(0,prediction_np.shape[0]):
           file_name = FLAGS.train_image_dir+str(i)+'_out.png'
           file_name_label = FLAGS.train_image_dir+str(i)+'_gt.png'
@@ -165,7 +160,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
           imwrite(file_name_label, target_np[i,:,:,:])
 
       # Save checkpoint 
-      if step % 5000 == 0 or (step +1) == FLAGS.max_steps:
+      if step % 500 == 0 or (step +1) == FLAGS.max_steps:
         checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
         saver.save(sess, checkpoint_path, global_step=step)
 
@@ -186,7 +181,7 @@ def test(dataset_frame1, dataset_frame2, dataset_frame3):
     # target_resized = tf.image.resize_area(target_placeholder,[128, 128])
 
     # Prepare model.
-    model = Voxel_flow_model(is_train=True)
+    model, flow = Voxel_flow_model(is_train=True)
     prediction = model.inference(input_placeholder)
     # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
     reproduction_loss = model.loss(prediction, target_placeholder)