1 from __future__ import print_function
class DummpyData(object):
    """Comparable wrapper around an arbitrary queue payload.

    PriorityQueue entries are ``(priority, DummpyData)`` tuples; when two
    priorities tie, the queue falls back to comparing the second element.
    Raw payloads (e.g. numpy arrays) may not support ordering, so this
    wrapper makes all instances compare as equivalent, letting ties break
    arbitrarily instead of raising.
    """

    def __init__(self, data):
        # NOTE(review): the original body fell outside this chunk; storing
        # the payload matches the `data_dummy.data` read in get_batch().
        self.data = data

    def __cmp__(self, other):
        # Python 2 ordering hook: treat every instance as equal.
        # NOTE(review): original body not visible -- TODO confirm it
        # returned 0 rather than comparing payloads.
        return 0

    # Bug fix: __cmp__ is ignored on Python 3, where a PriorityQueue
    # tie-break would raise TypeError on unorderable payloads. Provide the
    # rich comparison the heap actually uses, with the same semantics.
    def __lt__(self, other):
        return False
def prefetch_job(load_fn, prefetch_queue, data_list, shuffle, prefetch_size):
    """Worker loop: endlessly load items from ``data_list`` into the queue.

    Runs forever (intended to be executed on a daemon thread). Each loaded
    item is wrapped in ``DummpyData`` and pushed as ``(priority, item)``;
    with ``shuffle`` the priority is random so consumers see items in a
    scrambled order, otherwise priorities cycle so arrival order is kept.

    Args:
        load_fn: callable mapping one ``data_list`` entry to a sample, or
            to a list of samples.
        prefetch_queue: a ``PriorityQueue``; ``put`` blocks when full,
            providing back-pressure.
        data_list: sequence of loadable entries, cycled indefinitely.
        shuffle: randomize both traversal and queue priorities.
        prefetch_size: upper bound used for the random priorities.
    """
    # NOTE(review): the loop scaffolding below (counters, `while True`,
    # `if shuffle:` / `else:`) fell outside this chunk and was
    # reconstructed to fit the visible lines -- TODO confirm.
    data_count = 0
    total_count = len(data_list)
    idx = 0
    while True:
        if shuffle:
            # NOTE(review): reshuffling the whole list on every iteration
            # is O(n) per item; kept as-is to preserve sampling behavior.
            random.shuffle(data_list)
            data = load_fn(data_list[data_count])  # Load your data here.
            # Bug fix (idiom): `type(data) is list` rejects list
            # subclasses; isinstance() is the correct check.
            if isinstance(data, list):
                for data_point in data:
                    # randint's upper bound is inclusive, so priorities
                    # span [0, prefetch_size].
                    idx = random.randint(0, prefetch_size)
                    dummy_data = DummpyData(data_point)
                    prefetch_queue.put((idx, dummy_data), block=True)
            else:
                idx = random.randint(0, prefetch_size)
                dummy_data = DummpyData(data)
                prefetch_queue.put((idx, dummy_data), block=True)
        else:
            data = load_fn(data_list[data_count])  # Load your data here.
            dummy_data = DummpyData(data)
            prefetch_queue.put((idx, dummy_data), block=True)
            idx = (idx + 1) % prefetch_size

        data_count = (data_count + 1) % total_count
class PrefetchQueue(object):
    """Threaded data prefetcher backed by a bounded PriorityQueue.

    Spawns ``num_workers`` threads running ``prefetch_job`` that keep the
    queue filled; ``get_batch`` blocks until a full batch is available.
    """

    def __init__(self, load_fn, data_list, batch_size=32, prefetch_size=None,
                 shuffle=True, num_workers=4):
        """Start the worker threads.

        Args:
            load_fn: callable mapping a ``data_list`` entry to a sample.
            data_list: sequence of loadable entries.
            batch_size: samples per ``get_batch`` call.
            prefetch_size: queue capacity; defaults to ``4 * batch_size``.
            shuffle: randomize load order and queue priorities.
            num_workers: number of prefetching threads.
        """
        self.data_list = data_list
        self.shuffle = shuffle
        self.prefetch_size = prefetch_size
        self.load_fn = load_fn
        self.batch_size = batch_size
        if prefetch_size is None:
            self.prefetch_size = 4 * batch_size

        # Start prefetching threads.
        # Bug fix: maxsize previously used the raw `prefetch_size` argument,
        # which is None when callers rely on the default -- use the
        # resolved self.prefetch_size instead.
        self.prefetch_queue = Queue.PriorityQueue(maxsize=self.prefetch_size)
        for _ in range(num_workers):
            t = threading.Thread(
                target=prefetch_job,
                args=(self.load_fn, self.prefetch_queue, self.data_list,
                      self.shuffle, self.prefetch_size))
            # NOTE(review): the thread start-up lines fell outside this
            # chunk; daemonizing keeps the endless workers from blocking
            # interpreter exit -- TODO confirm against the original.
            t.daemon = True
            t.start()

    def get_batch(self):
        """Block until ``batch_size`` items arrive; return them stacked.

        Returns:
            np.ndarray with the batch stacked along a new leading axis.
        """
        # NOTE(review): the `def` line of this method fell outside the
        # chunk; signature reconstructed from the visible body.
        data_list = []
        for _ in range(self.batch_size):
            # `get` blocks until a worker has produced an item; the
            # priority component is only used for ordering and discarded.
            _, data_dummy = self.prefetch_queue.get(block=True)
            data_list.append(np.expand_dims(data_dummy.data, 0))
        return np.concatenate(data_list, axis=0)
if __name__ == '__main__':
    # Simple eval script for usage: times 50 batches read through the
    # prefetch queue against 50 batches read synchronously.
    def load_fn_example(data_file_path):
        # Stand-in for a project-specific loader: read one image from disk.
        return scipy.misc.imread(data_file_path)

    data_path_pattern = '/home/VoxelFlow/dataset/ucf101/*.jpg'
    data_list = glob.glob(data_path_pattern)  # dataset.read_data_list_file()
    load_fn = load_fn_example  # dataset.process_func()

    # NOTE(review): the batch_size / num_workers assignments fell outside
    # this chunk; these match PrefetchQueue's defaults -- TODO confirm.
    batch_size = 32
    num_workers = 4

    p_queue = PrefetchQueue(load_fn, data_list, batch_size,
                            num_workers=num_workers)

    # Timed read through the prefetch queue.
    a = datetime.datetime.now()
    for _ in range(50):
        X = p_queue.get_batch()
    b = datetime.datetime.now()
    delta = b - a
    # Bug fix: total_seconds() yields seconds, but the label claimed
    # milliseconds -- scale by 1000 (and fix the "miliseconds" typo).
    print("%d milliseconds" % int(delta.total_seconds() * 1000))

    # Baseline: synchronous batched reads without prefetching.
    data_list = glob.glob(data_path_pattern)
    a = datetime.datetime.now()
    for k in range(0, 50):
        data_sub_list = data_list[k * batch_size:(k + 1) * batch_size]
        im_list = [np.expand_dims(scipy.misc.imread(file_name), 0)
                   for file_name in data_sub_list]
        X = np.concatenate(im_list, axis=0)
    b = datetime.datetime.now()
    delta = b - a
    print("%d milliseconds" % int(delta.total_seconds() * 1000))