]> git.sesse.net Git - voxel-flow/blobdiff - voxel_flow_train.py
Fix yet more no-movement rejection (we would reject A-A-B pulldown, but not A-B-B).
[voxel-flow] / voxel_flow_train.py
index 42f91af8b22e6ba0b7fb96888c35c03d0fa62dd6..0ba31722b156a8ccb08d6e3cd0205d15992fd06a 100755 (executable)
@@ -52,14 +52,17 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
     data_list_frame1 = dataset_frame1.read_data_list_file()
     dataset_frame1 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1))
     dataset_frame1 = dataset_frame1.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+    dataset_frame1 = dataset_frame1.prefetch(100)
 
     data_list_frame2 = dataset_frame2.read_data_list_file()
     dataset_frame2 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2))
     dataset_frame2 = dataset_frame2.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+    dataset_frame2 = dataset_frame2.prefetch(100)
 
     data_list_frame3 = dataset_frame3.read_data_list_file()
     dataset_frame3 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3))
     dataset_frame3 = dataset_frame3.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+    dataset_frame3 = dataset_frame3.prefetch(100)
 
     batch_frame1 = dataset_frame1.batch(FLAGS.batch_size).make_initializable_iterator()
     batch_frame2 = dataset_frame2.batch(FLAGS.batch_size).make_initializable_iterator()
@@ -115,6 +118,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
       restorer.restore(sess, ckpt.model_checkpoint_path)
       print('%s: Pre-trained model restored from %s' %
         (datetime.now(), ckpt.model_checkpoint_path))
+      sess.run([batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer])
     else:
       # Build an initialization operation to run below.
       init = tf.initialize_all_variables()
@@ -156,7 +160,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
           imwrite(file_name_label, target_np[i,:,:,:])
 
       # Save checkpoint 
-      if step % 5000 == 0 or (step +1) == FLAGS.max_steps:
+      if step % 500 == 0 or (step +1) == FLAGS.max_steps:
         checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
         saver.save(sess, checkpoint_path, global_step=step)