]> git.sesse.net Git - voxel-flow/commitdiff
Move to tf.data, for much more efficient data loading with less code.
authorSteinar H. Gunderson <sgunderson@bigfoot.com>
Thu, 8 Feb 2018 23:16:39 +0000 (00:16 +0100)
committerSteinar H. Gunderson <sgunderson@bigfoot.com>
Thu, 8 Feb 2018 23:16:39 +0000 (00:16 +0100)
dataset.py
voxel_flow_train.py

index 5957f079755579bf653f396e473413b5fda30981..bdd88b8b5744d989e31b6e28a56d22911a68dbf9 100755 (executable)
@@ -1,5 +1,4 @@
 """Implements a dataset class for handling image data"""
-from utils.image_utils import imread, imsave
 
 DATA_PATH_BASE = '/home/VoxelFlow/dataset/ucf101_triplets/'
 
@@ -9,10 +8,6 @@ class Dataset(object):
       Args:
     """
     self.data_list_file = data_list_file
-    if process_func:
-      self.process_func = process_func
-    else:
-      self.process_func = self.process_func
  
   def read_data_list_file(self):
     """Reads the data list_file into python list
@@ -21,10 +16,3 @@ class Dataset(object):
     data_list =  [DATA_PATH_BASE+line.rstrip() for line in f]
     self.data_list = data_list
     return data_list
-
-  def process_func(self, example_line):
-    """Process the single example line and return data 
-      Default behavior, assumes each line is the path to a single image.
-      This is used to train a VAE.
-    """
-    return imread(example_line)
index e1fa96f2c3102f226148b18283eb1f5855cec4eb..19847f16cb56a8e7433f6a7c158410589868c2b8 100755 (executable)
@@ -39,13 +39,35 @@ tf.app.flags.DEFINE_integer(
 tf.app.flags.DEFINE_float('initial_learning_rate', 0.0003,
                           """Initial learning rate.""")
 
+def _read_image(filename):
+  image_string = tf.read_file(filename)
+  image_decoded = tf.image.decode_image(image_string, channels=3)
+  image_decoded.set_shape([256, 256, 3])
+  return tf.cast(image_decoded, dtype=tf.float32) / 127.5 - 1.0
 
 def train(dataset_frame1, dataset_frame2, dataset_frame3):
   """Trains a model."""
   with tf.Graph().as_default():
+    # Create input.
+    data_list_frame1 = dataset_frame1.read_data_list_file()
+    dataset_frame1 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1))
+    dataset_frame1 = dataset_frame1.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+
+    data_list_frame2 = dataset_frame2.read_data_list_file()
+    dataset_frame2 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2))
+    dataset_frame2 = dataset_frame2.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+
+    data_list_frame3 = dataset_frame3.read_data_list_file()
+    dataset_frame3 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3))
+    dataset_frame3 = dataset_frame3.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+
+    batch_frame1 = dataset_frame1.batch(FLAGS.batch_size).make_initializable_iterator()
+    batch_frame2 = dataset_frame2.batch(FLAGS.batch_size).make_initializable_iterator()
+    batch_frame3 = dataset_frame3.batch(FLAGS.batch_size).make_initializable_iterator()
+
     # Create input and target placeholder.
-    input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
-    target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))
+    input_placeholder = tf.concat([batch_frame1.get_next(), batch_frame3.get_next()], 3)
+    target_placeholder = batch_frame2.get_next()
 
     # input_resized = tf.image.resize_area(input_placeholder, [128, 128])
     # target_resized = tf.image.resize_area(target_placeholder,[128, 128])
@@ -96,69 +118,23 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
       # Build an initialization operation to run below.
       init = tf.initialize_all_variables()
       sess = tf.Session()
-      sess.run(init)
+      sess.run([init, batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer])
 
     # Summary Writter
     summary_writer = tf.summary.FileWriter(
       FLAGS.train_dir,
       graph=sess.graph)
 
-    # Training loop using feed dict method.
-    data_list_frame1 = dataset_frame1.read_data_list_file()
-    random.seed(1)
-    shuffle(data_list_frame1)
-
-    data_list_frame2 = dataset_frame2.read_data_list_file()
-    random.seed(1)
-    shuffle(data_list_frame2)
-
-    data_list_frame3 = dataset_frame3.read_data_list_file()
-    random.seed(1)
-    shuffle(data_list_frame3)
-
     data_size = len(data_list_frame1)
     epoch_num = int(data_size / FLAGS.batch_size)
 
-    # num_workers = 1
-      
-    # load_fn_frame1 = partial(dataset_frame1.process_func)
-    # p_queue_frame1 = PrefetchQueue(load_fn_frame1, data_list_frame1, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
-    # load_fn_frame2 = partial(dataset_frame2.process_func)
-    # p_queue_frame2 = PrefetchQueue(load_fn_frame2, data_list_frame2, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
-    # load_fn_frame3 = partial(dataset_frame3.process_func)
-    # p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
     for step in range(0, FLAGS.max_steps):
       batch_idx = step % epoch_num
       
-      batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-      batch_data_list_frame2 = data_list_frame2[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-      batch_data_list_frame3 = data_list_frame3[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-      
-      # Load batch data.
-      batch_data_frame1 = np.array([dataset_frame1.process_func(line) for line in batch_data_list_frame1])
-      batch_data_frame2 = np.array([dataset_frame2.process_func(line) for line in batch_data_list_frame2])
-      batch_data_frame3 = np.array([dataset_frame3.process_func(line) for line in batch_data_list_frame3])
-
-      # batch_data_frame1 = p_queue_frame1.get_batch()
-      # batch_data_frame2 = p_queue_frame2.get_batch()
-      # batch_data_frame3 = p_queue_frame3.get_batch()
-
-      feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3), target_placeholder: batch_data_frame2}
-     
       # Run single step update.
-      _, loss_value = sess.run([update_op, total_loss], feed_dict = feed_dict)
+      _, loss_value = sess.run([update_op, total_loss])
       
       if batch_idx == 0:
-        # Shuffle data at each epoch.
-        random.seed(1)
-        shuffle(data_list_frame1)
-        random.seed(1)
-        shuffle(data_list_frame2)
-        random.seed(1)
-        shuffle(data_list_frame3)
         print('Epoch Number: %d' % int(step / epoch_num))
       
       if step % 10 == 0:
@@ -166,12 +142,12 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
 
       if step % 100 == 0:
         # Output Summary 
-        summary_str = sess.run(summary_op, feed_dict = feed_dict)
+        summary_str = sess.run(summary_op)
         summary_writer.add_summary(summary_str, step)
 
       if step % 500 == 0:
         # Run a batch of images        
-        prediction_np, target_np = sess.run([prediction, target_placeholder], feed_dict = feed_dict) 
+        prediction_np, target_np = sess.run([prediction, target_placeholder])
         for i in range(0,prediction_np.shape[0]):
           file_name = FLAGS.train_image_dir+str(i)+'_out.png'
           file_name_label = FLAGS.train_image_dir+str(i)+'_gt.png'