1 """Train a voxel flow model on ucf101 dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from datetime import datetime
from functools import partial
from random import shuffle

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

import dataset
from utils.image_utils import imwrite
from utils.prefetch_queue_shuffle import PrefetchQueue
from voxel_flow_model import Voxel_flow_model
FLAGS = tf.app.flags.FLAGS

# Command-line configuration: output directories, subset selection, and the
# core training hyperparameters (step budget, batch size, learning rate).
flags = tf.app.flags
flags.DEFINE_string('train_dir', './voxel_flow_checkpoints/',
                    """Directory where to write event logs """
                    """and checkpoint.""")
flags.DEFINE_string('train_image_dir', './voxel_flow_train_image/',
                    """Directory where to output images.""")
flags.DEFINE_string('test_image_dir', './voxel_flow_test_image/',
                    """Directory where to output images.""")
flags.DEFINE_string('subset', 'train',
                    """Either 'train' or 'validation'.""")
flags.DEFINE_string('pretrained_model_checkpoint_path', './voxel_flow_checkpoints/',
                    """If specified, restore this pretrained model """
                    """before beginning any training.""")
flags.DEFINE_integer('max_steps', 10000000,
                     """Number of batches to run.""")
flags.DEFINE_integer('batch_size', 32,
                     'The number of samples in each batch.')
flags.DEFINE_float('initial_learning_rate', 0.0003,
                   """Initial learning rate.""")
def _read_image(filename):
  """Load the image at `filename` as RGB float32, scaled to [-1, 1]."""
  raw_bytes = tf.read_file(filename)
  decoded = tf.image.decode_image(raw_bytes, channels=3)
  # decode_image yields an unknown static shape; pin it so downstream ops
  # (batching, the model) see a fixed 256x256x3 tensor.
  decoded.set_shape([256, 256, 3])
  # Map [0, 255] -> [-1, 1].
  return tf.cast(decoded, dtype=tf.float32) / 127.5 - 1.0
def train(dataset_frame1, dataset_frame2, dataset_frame3):
  """Train the voxel flow model on triplets of consecutive frames.

  Frames 1 and 3 (stacked on the channel axis) are the model input; frame 2
  is the interpolation target.

  Args:
    dataset_frame1: dataset object for the first frame of each triplet;
      must expose read_data_list_file().
    dataset_frame2: dataset object for the middle (target) frame.
    dataset_frame3: dataset object for the last frame.
  """
  with tf.Graph().as_default():

    def _build_iterator(frame_dataset):
      """Returns (file_list, initializable batched iterator) for one frame."""
      file_list = frame_dataset.read_data_list_file()
      data = tf.data.Dataset.from_tensor_slices(tf.constant(file_list))
      # The same shuffle seed on all three pipelines keeps frame1/2/3
      # aligned as triplets after shuffling.
      data = data.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
      iterator = data.batch(FLAGS.batch_size).make_initializable_iterator()
      return file_list, iterator

    data_list_frame1, batch_frame1 = _build_iterator(dataset_frame1)
    _, batch_frame2 = _build_iterator(dataset_frame2)
    _, batch_frame3 = _build_iterator(dataset_frame3)

    # Input: frames 1 and 3 concatenated on channels; target: frame 2.
    input_placeholder = tf.concat([batch_frame1.get_next(), batch_frame3.get_next()], 3)
    target_placeholder = batch_frame2.get_next()

    model = Voxel_flow_model()
    prediction, flow = model.inference(input_placeholder)
    reproduction_loss = model.loss(prediction, target_placeholder)
    total_loss = reproduction_loss

    # Fixed learning rate; no decay schedule is applied yet.
    learning_rate = FLAGS.initial_learning_rate
    opt = tf.train.AdamOptimizer(learning_rate)
    grads = opt.compute_gradients(total_loss)
    update_op = opt.apply_gradients(grads)

    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    summaries.append(tf.summary.scalar('total_loss', total_loss))
    summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
    summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3))
    summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3))
    summaries.append(tf.summary.image('Output Image', prediction, 3))
    summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
    summaries.append(tf.summary.image('Flow', flow, 3))

    saver = tf.train.Saver(tf.all_variables())
    summary_op = tf.summary.merge_all()

    sess = tf.Session()

    # BUG FIX: variables were previously re-initialized *after* restoring the
    # checkpoint, which silently discarded the restored weights. Initialize
    # first, then restore on top if a checkpoint exists.
    sess.run(tf.initialize_all_variables())
    if FLAGS.pretrained_model_checkpoint_path:
      ckpt = tf.train.get_checkpoint_state(FLAGS.pretrained_model_checkpoint_path)
      if ckpt and ckpt.model_checkpoint_path:
        restorer = tf.train.Saver()
        restorer.restore(sess, ckpt.model_checkpoint_path)
        print('%s: Pre-trained model restored from %s' %
              (datetime.now(), ckpt.model_checkpoint_path))
    sess.run([batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer])

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=sess.graph)

    data_size = len(data_list_frame1)
    # Guard against batch_size > data_size producing a zero divisor.
    steps_per_epoch = max(1, data_size // FLAGS.batch_size)

    for step in range(0, FLAGS.max_steps):
      # Run a single training step.
      _, loss_value = sess.run([update_op, total_loss])

      if step % 10 == 0:
        print('Epoch Number: %d' % (step // steps_per_epoch))
        print("Loss at step %d: %f" % (step, loss_value))

      if step % 100 == 0:
        # Summaries are gated to avoid running the image ops every step.
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      if step % 500 == 0:
        # Dump one batch of predictions and targets for visual inspection.
        prediction_np, target_np = sess.run([prediction, target_placeholder])
        for i in range(0, prediction_np.shape[0]):
          file_name = FLAGS.train_image_dir + str(i) + '_out.png'
          file_name_label = FLAGS.train_image_dir + str(i) + '_gt.png'
          imwrite(file_name, prediction_np[i, :, :, :])
          imwrite(file_name_label, target_np[i, :, :, :])

      if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
def validate(dataset_frame1, dataset_frame2, dataset_frame3):
  """Performs validation on model.

  Not implemented yet. Takes the same three per-frame dataset objects as
  train() so it can be wired into the 'validation' subset later.
  """
  pass
def test(dataset_frame1, dataset_frame2, dataset_frame3):
  """Perform test on a trained model.

  Feeds each test triplet through the network, writes the predicted and
  ground-truth middle frames to FLAGS.test_image_dir, and prints per-image
  loss plus the average PSNR over the whole test set.

  Args:
    dataset_frame1: dataset object for the first frame (read_data_list_file()
      and process_func() are used).
    dataset_frame2: dataset object for the middle (target) frame.
    dataset_frame3: dataset object for the last frame.
  """
  with tf.Graph().as_default():
    # Create input and target placeholder.
    input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
    target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))

    # BUG FIX: the constructor returns a single model object (not a tuple)
    # and inference() returns (prediction, flow), mirroring train(). Also
    # use is_train=False for evaluation.
    model = Voxel_flow_model(is_train=False)
    prediction, _ = model.inference(input_placeholder)
    reproduction_loss = model.loss(prediction, target_placeholder)
    total_loss = reproduction_loss

    saver = tf.train.Saver(tf.all_variables())
    sess = tf.Session()

    # Restore checkpoint from file.
    if FLAGS.pretrained_model_checkpoint_path:
      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
      ckpt = tf.train.get_checkpoint_state(FLAGS.pretrained_model_checkpoint_path)
      restorer = tf.train.Saver()
      restorer.restore(sess, ckpt.model_checkpoint_path)
      print('%s: Pre-trained model restored from %s' %
            (datetime.now(), ckpt.model_checkpoint_path))

    # Process on test dataset.
    data_list_frame1 = dataset_frame1.read_data_list_file()
    data_list_frame2 = dataset_frame2.read_data_list_file()
    data_list_frame3 = dataset_frame3.read_data_list_file()
    data_size = len(data_list_frame1)

    # The graph is fed fixed-size batches: batch_size - 1 padding images plus
    # the one image under test in the last row. Hoist the padding batch out of
    # the loop (it was previously reloaded from disk on every iteration) and
    # derive its size from batch_size instead of the hard-coded 63.
    pad_count = FLAGS.batch_size - 1
    pad_frame1 = [dataset_frame1.process_func(ll) for ll in data_list_frame1[0:pad_count]]
    pad_frame2 = [dataset_frame2.process_func(ll) for ll in data_list_frame2[0:pad_count]]
    pad_frame3 = [dataset_frame3.process_func(ll) for ll in data_list_frame3[0:pad_count]]

    PSNR = 0.0
    for id_img in range(0, data_size):
      batch_data_frame1 = np.array(pad_frame1 + [dataset_frame1.process_func(data_list_frame1[id_img])])
      batch_data_frame2 = np.array(pad_frame2 + [dataset_frame2.process_func(data_list_frame2[id_img])])
      batch_data_frame3 = np.array(pad_frame3 + [dataset_frame3.process_func(data_list_frame3[id_img])])

      feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                   target_placeholder: batch_data_frame2}
      prediction_np, target_np, loss_value = sess.run(
          [prediction, target_placeholder, total_loss], feed_dict=feed_dict)

      # BUG FIX: the loop variable is id_img; the body referenced an
      # undefined 'i'. Only the last batch row is the real test image.
      print("Loss for image %d: %f" % (id_img, loss_value))
      file_name = FLAGS.test_image_dir + str(id_img) + '_out.png'
      file_name_label = FLAGS.test_image_dir + str(id_img) + '_gt.png'
      imwrite(file_name, prediction_np[-1, :, :, :])
      imwrite(file_name_label, target_np[-1, :, :, :])

      # BUG FIX: PSNR previously summed the squared error over the padding
      # rows as well; restrict it to the actual test image.
      # NOTE(review): this formula assumes pixel values in [0, 255], while
      # _read_image normalizes to [-1, 1] — confirm process_func's range.
      err = np.sum(np.square(prediction_np[-1] - target_np[-1]))
      PSNR += 10 * np.log10(255.0 * 255.0 / err)

    print("Overall PSNR: %f db" % (PSNR / data_size))
if __name__ == '__main__':
  os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  # Both subsets load the same three per-frame file lists; only the subset
  # name in the path and the entry point differ.
  if FLAGS.subset in ('train', 'test'):
    frame_datasets = [
        dataset.Dataset("data_list/ucf101_%s_files_frame%d.txt" % (FLAGS.subset, frame_id))
        for frame_id in (1, 2, 3)
    ]
    if FLAGS.subset == 'train':
      train(*frame_datasets)
    else:
      test(*frame_datasets)