]> git.sesse.net Git - voxel-flow/blob - utils/prefetch_queue_shuffle.py
Initial commit
[voxel-flow] / utils / prefetch_queue_shuffle.py
1 from __future__ import print_function
2
3 import glob
4 import numpy as np
5 import os 
6 import Queue
7 import random
8 import scipy
9 from scipy import misc
10 import threading 
11
12 class DummpyData(object):
13   def __init__(self, data):
14     self.data = data
15   def __cmp__(self, other):
16     return 0
17
18 def prefetch_job(load_fn, prefetch_queue, data_list, shuffle, prefetch_size):
19   """
20   """
21   data_count = 0
22   total_count = len(data_list)
23   idx = 0
24   while True:
25     if shuffle:
26       if data_count == 0:
27         random.shuffle(data_list)
28       data = load_fn(data_list[data_count]) #Load your data here.
29       if type(data) is list:
30         for data_point in data: 
31           idx = random.randint(0, prefetch_size)
32           dummy_data = DummpyData(data_point)
33           prefetch_queue.put((idx, dummy_data), block=True)
34       else:
35         idx = random.randint(0, prefetch_size)
36         dummy_data = DummpyData(data)
37         prefetch_queue.put((idx, dummy_data), block=True)
38     else:
39       data = load_fn(data_list[data_count]) #Load your data here.
40       dummy_data = DummpyData(data)
41       prefetch_queue.put((idx, dummy_data), block=True)
42       idx = (idx + 1) % prefetch_size
43
44     data_count = (data_count + 1) % total_count
45
46 class PrefetchQueue(object):
47   def __init__(self, load_fn, data_list, batch_size=32, prefetch_size=None, shuffle=True, num_workers=4):
48     self.data_list = data_list
49     self.shuffle = shuffle
50     self.prefetch_size = prefetch_size
51     self.load_fn = load_fn
52     self.batch_size = batch_size
53     if prefetch_size is None:
54       self.prefetch_size = 4 * batch_size
55
56     # Start prefetching thread
57     # self.prefetch_queue = Queue.Queue(maxsize=prefetch_size)
58     self.prefetch_queue = Queue.PriorityQueue(maxsize=prefetch_size)
59     for k in range(num_workers):
60       t = threading.Thread(target=prefetch_job,
61         args=(self.load_fn, self.prefetch_queue, self.data_list,
62               self.shuffle, self.prefetch_size))
63       t.daemon = True
64       t.start()
65
66   def get_batch(self):
67     data_list = []
68     for k in range(0, self.batch_size):
69       # if self.prefetch_queue.empty():
70       #   print('Prefetch Queue is empty, waiting for data to be read.')
71       _, data_dummy = self.prefetch_queue.get(block=True)
72       data = data_dummy.data
73       data_list.append(np.expand_dims(data,0))
74     return np.concatenate(data_list, axis=0)
75
76
77 if __name__ == '__main__':
78   # Simple Eval Script For Usage.
79   def load_fn_example(data_file_path):
80     return scipy.misc.imread(data_file_path)
81
82   import time
83   data_path_pattern = '/home/VoxelFlow/dataset/ucf101/*.jpg' 
84   data_list = glob.glob(data_path_pattern)  # dataset.read_data_list_file()
85   load_fn = load_fn_example  # dataset.process_func()
86   num_workers=2
87   batch_size = 32
88   
89   # Prefetch IO.
90   p_queue = PrefetchQueue(load_fn, data_list, batch_size, num_workers=num_workers)
91   time.sleep(5)
92   print('Start') 
93   import datetime
94   a = datetime.datetime.now()
95   for k in range(0,50):
96     time.sleep(0.1)
97     X = p_queue.get_batch()
98   b = datetime.datetime.now()
99   delta = b - a
100   print(delta)
101   print("%d miliseconds" % int(delta.total_seconds()))
102
103   # Naive FILE IO
104   import glob
105   data_list = glob.glob(data_path_pattern) 
106   a = datetime.datetime.now()
107   for k in range(0,50):
108     time.sleep(0.1)
109     data_sub_list = data_list[k*batch_size:(k+1)*batch_size]
110     im_list = [np.expand_dims(scipy.misc.imread(file_name),0) for file_name in data_sub_list]
111     X = np.concatenate(im_list,axis=0)  
112
113   b = datetime.datetime.now()
114   delta = b - a
115   print(delta)
116   print("%d miliseconds" % int(delta.total_seconds()))