1 """Train a voxel flow model on ucf101 dataset."""
2 from __future__ import absolute_import
3 from __future__ import division
4 from __future__ import print_function
7 from utils.prefetch_queue_shuffle import PrefetchQueue
10 import tensorflow as tf
11 import tensorflow.contrib.slim as slim
12 from datetime import datetime
14 from random import shuffle
15 from voxel_flow_model import Voxel_flow_model
16 from utils.image_utils import imwrite
17 from functools import partial
FLAGS = tf.app.flags.FLAGS

# Define necessary FLAGS.
tf.app.flags.DEFINE_string('train_dir', './voxel_flow_checkpoints/',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_string('train_image_dir', './voxel_flow_train_image/',
                           """Directory where to output training images.""")
tf.app.flags.DEFINE_string('test_image_dir', './voxel_flow_test_image/',
                           """Directory where to output test images.""")
tf.app.flags.DEFINE_string('subset', 'train',
                           """Either 'train' or 'test'.""")
tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', './voxel_flow_checkpoints/',
                           """If specified, restore this pretrained model """
                           """before beginning any training.""")
tf.app.flags.DEFINE_integer('max_steps', 10000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_integer('batch_size', 32,
                            """The number of samples in each batch.""")
tf.app.flags.DEFINE_float('initial_learning_rate', 0.0003,
                          """Initial learning rate.""")
def train(dataset_frame1, dataset_frame2, dataset_frame3):
  """Trains a model."""
  with tf.Graph().as_default():
    # Create input and target placeholders.
    input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
    target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))

    # input_resized = tf.image.resize_area(input_placeholder, [128, 128])
    # target_resized = tf.image.resize_area(target_placeholder, [128, 128])

    # Create the model.
    model = Voxel_flow_model()
    prediction = model.inference(input_placeholder)
    # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
    reproduction_loss = model.loss(prediction, target_placeholder)
    # total_loss = reproduction_loss + prior_loss
    total_loss = reproduction_loss

    # Perform learning rate scheduling.
    learning_rate = FLAGS.initial_learning_rate
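    # The rate is kept constant here; a decayed schedule is a common
    # alternative. A minimal sketch (hypothetical, not part of this setup):
    #   global_step = tf.Variable(0, trainable=False)
    #   learning_rate = tf.train.exponential_decay(
    #       FLAGS.initial_learning_rate, global_step,
    #       decay_steps=100000, decay_rate=0.96, staircase=True)
    #   # ...then pass global_step to opt.apply_gradients() below.
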
    # Create an optimizer that performs gradient descent.
    opt = tf.train.AdamOptimizer(learning_rate)
    grads = opt.compute_gradients(total_loss)
    update_op = opt.apply_gradients(grads)

    # Create summaries.
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    summaries.append(tf.summary.scalar('total_loss', total_loss))
    summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
    # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
    summaries.append(tf.summary.image('Input Image', input_placeholder, 3))
    summaries.append(tf.summary.image('Output Image', prediction, 3))
    summaries.append(tf.summary.image('Target Image', target_placeholder, 3))

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    # Build the summary operation.
    summary_op = tf.summary.merge_all()

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Create a session and initialize variables.
    sess = tf.Session()
    sess.run(init)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=sess.graph)

    # Training loop using the feed_dict method.
    data_list_frame1 = dataset_frame1.read_data_list_file()
    # Reseed identically before each shuffle so the three frame lists
    # receive the same permutation and stay aligned.
    random.seed(1)
    shuffle(data_list_frame1)

    data_list_frame2 = dataset_frame2.read_data_list_file()
    random.seed(1)
    shuffle(data_list_frame2)

    data_list_frame3 = dataset_frame3.read_data_list_file()
    random.seed(1)
    shuffle(data_list_frame3)

    data_size = len(data_list_frame1)
    epoch_num = int(data_size / FLAGS.batch_size)
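    # Note: epoch_num counts the batches per epoch, not the number of epochs.
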
    # num_workers = 1

    # load_fn_frame1 = partial(dataset_frame1.process_func)
    # p_queue_frame1 = PrefetchQueue(load_fn_frame1, data_list_frame1, FLAGS.batch_size, shuffle=False, num_workers=num_workers)

    # load_fn_frame2 = partial(dataset_frame2.process_func)
    # p_queue_frame2 = PrefetchQueue(load_fn_frame2, data_list_frame2, FLAGS.batch_size, shuffle=False, num_workers=num_workers)

    # load_fn_frame3 = partial(dataset_frame3.process_func)
    # p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers)

    for step in range(0, FLAGS.max_steps):
      batch_idx = step % epoch_num

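      # Slice the current mini-batch out of each of the three aligned lists.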
      batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
      batch_data_list_frame2 = data_list_frame2[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
      batch_data_list_frame3 = data_list_frame3[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]

      batch_data_frame1 = np.array([dataset_frame1.process_func(line) for line in batch_data_list_frame1])
      batch_data_frame2 = np.array([dataset_frame2.process_func(line) for line in batch_data_list_frame2])
      batch_data_frame3 = np.array([dataset_frame3.process_func(line) for line in batch_data_list_frame3])

      # batch_data_frame1 = p_queue_frame1.get_batch()
      # batch_data_frame2 = p_queue_frame2.get_batch()
      # batch_data_frame3 = p_queue_frame3.get_batch()

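      # Frames 1 and 3 are stacked along the channel axis (3 + 3 = 6 channels)
      # to form the model input; the middle frame 2 is the prediction target.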
      feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                   target_placeholder: batch_data_frame2}

      # Run a single update step.
      _, loss_value = sess.run([update_op, total_loss], feed_dict=feed_dict)

      # Shuffle data at each epoch.
      if batch_idx == 0 and step > 0:
        random.seed(1)
        shuffle(data_list_frame1)
        random.seed(1)
        shuffle(data_list_frame2)
        random.seed(1)
        shuffle(data_list_frame3)
        print('Epoch Number: %d' % int(step / epoch_num))

      # Log the training loss periodically.
      if step % 10 == 0:
        # summary_str = sess.run(summary_op, feed_dict=feed_dict)
        # summary_writer.add_summary(summary_str, step)
        print("Loss at step %d: %f" % (step, loss_value))

      if step % 500 == 0:
        # Run a batch of images through the model and dump the results.
        prediction_np, target_np = sess.run([prediction, target_placeholder], feed_dict=feed_dict)
        for i in range(0, prediction_np.shape[0]):
          file_name = FLAGS.train_image_dir + str(i) + '_out.png'
          file_name_label = FLAGS.train_image_dir + str(i) + '_gt.png'
          imwrite(file_name, prediction_np[i, :, :, :])
          imwrite(file_name_label, target_np[i, :, :, :])

      # Save the model checkpoint periodically.
      if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)

def validate(dataset_frame1, dataset_frame2, dataset_frame3):
  """Performs validation on model."""
  pass

def test(dataset_frame1, dataset_frame2, dataset_frame3):
  """Perform test on a trained model."""
  with tf.Graph().as_default():
    # Create input and target placeholders.
    input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
    target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))

    # input_resized = tf.image.resize_area(input_placeholder, [128, 128])
    # target_resized = tf.image.resize_area(target_placeholder, [128, 128])

    # Create the model.
    model = Voxel_flow_model(is_train=False)
    prediction = model.inference(input_placeholder)
    # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
    reproduction_loss = model.loss(prediction, target_placeholder)
    # total_loss = reproduction_loss + prior_loss
    total_loss = reproduction_loss

    # Create a saver and load.
    saver = tf.train.Saver(tf.global_variables())

    # Create a session for running operations in the graph.
    sess = tf.Session()

    # Restore checkpoint from file.
    if FLAGS.pretrained_model_checkpoint_path:
      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
      ckpt = tf.train.get_checkpoint_state(
          FLAGS.pretrained_model_checkpoint_path)
      restorer = tf.train.Saver()
      restorer.restore(sess, ckpt.model_checkpoint_path)
      print('%s: Pre-trained model restored from %s' %
            (datetime.now(), ckpt.model_checkpoint_path))

    # Process the test dataset.
    data_list_frame1 = dataset_frame1.read_data_list_file()
    data_size = len(data_list_frame1)
    epoch_num = int(data_size / FLAGS.batch_size)

    data_list_frame2 = dataset_frame2.read_data_list_file()

    data_list_frame3 = dataset_frame3.read_data_list_file()

    PSNR = 0.0

    for id_img in range(0, data_size):
      line_image_frame1 = dataset_frame1.process_func(data_list_frame1[id_img])
      line_image_frame2 = dataset_frame2.process_func(data_list_frame2[id_img])
      line_image_frame3 = dataset_frame3.process_func(data_list_frame3[id_img])

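      # Pad the batch with the first 63 list entries so the graph sees a full
      # 64-image batch; the actual test image is appended last, which is why
      # the outputs are read back at index -1 below.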
      batch_data_frame1 = [dataset_frame1.process_func(ll) for ll in data_list_frame1[0:63]]
      batch_data_frame2 = [dataset_frame2.process_func(ll) for ll in data_list_frame2[0:63]]
      batch_data_frame3 = [dataset_frame3.process_func(ll) for ll in data_list_frame3[0:63]]

      batch_data_frame1.append(line_image_frame1)
      batch_data_frame2.append(line_image_frame2)
      batch_data_frame3.append(line_image_frame3)

      batch_data_frame1 = np.array(batch_data_frame1)
      batch_data_frame2 = np.array(batch_data_frame2)
      batch_data_frame3 = np.array(batch_data_frame3)

      feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                   target_placeholder: batch_data_frame2}

      # Run a single forward pass on the padded batch.
      prediction_np, target_np, loss_value = sess.run(
          [prediction, target_placeholder, total_loss], feed_dict=feed_dict)

      print("Loss for image %d: %f" % (id_img, loss_value))
      file_name = FLAGS.test_image_dir + str(id_img) + '_out.png'
      file_name_label = FLAGS.test_image_dir + str(id_img) + '_gt.png'
      imwrite(file_name, prediction_np[-1, :, :, :])
      imwrite(file_name_label, target_np[-1, :, :, :])

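      # Accumulate per-image PSNR. Note the squared error is summed over the
      # whole padded batch, not just the appended test image, and the peak
      # value of 255 assumes process_func yields pixels on a 0-255 scale.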
      PSNR += 10 * np.log10(255.0 * 255.0 / np.sum(np.square(prediction_np - target_np)))

    print("Overall PSNR: %f dB" % (PSNR / data_size))

if __name__ == '__main__':

  os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  if FLAGS.subset == 'train':
    data_list_path_frame1 = "data_list/ucf101_train_files_frame1.txt"
    data_list_path_frame2 = "data_list/ucf101_train_files_frame2.txt"
    data_list_path_frame3 = "data_list/ucf101_train_files_frame3.txt"

    ucf101_dataset_frame1 = dataset.Dataset(data_list_path_frame1)
    ucf101_dataset_frame2 = dataset.Dataset(data_list_path_frame2)
    ucf101_dataset_frame3 = dataset.Dataset(data_list_path_frame3)

    train(ucf101_dataset_frame1, ucf101_dataset_frame2, ucf101_dataset_frame3)

  elif FLAGS.subset == 'test':
    data_list_path_frame1 = "data_list/ucf101_test_files_frame1.txt"
    data_list_path_frame2 = "data_list/ucf101_test_files_frame2.txt"
    data_list_path_frame3 = "data_list/ucf101_test_files_frame3.txt"

    ucf101_dataset_frame1 = dataset.Dataset(data_list_path_frame1)
    ucf101_dataset_frame2 = dataset.Dataset(data_list_path_frame2)
    ucf101_dataset_frame3 = dataset.Dataset(data_list_path_frame3)

    test(ucf101_dataset_frame1, ucf101_dataset_frame2, ucf101_dataset_frame3)