git.sesse.net Git - voxel-flow/blob - voxel_flow_train.py
Visualize the flow in a summary op.
[voxel-flow] / voxel_flow_train.py
1 """Train a voxel flow model on ucf101 dataset."""
2 from __future__ import absolute_import
3 from __future__ import division
4 from __future__ import print_function
5
6 import dataset
7 from utils.prefetch_queue_shuffle import PrefetchQueue
8 import numpy as np
9 import os
10 import tensorflow as tf
11 import tensorflow.contrib.slim as slim
12 from datetime import datetime
13 import random
14 from random import shuffle
15 from voxel_flow_model import Voxel_flow_model
16 from utils.image_utils import imwrite
17 from functools import partial
18 import pdb
19
FLAGS = tf.app.flags.FLAGS

# Define necessary FLAGS

# Output locations.
tf.app.flags.DEFINE_string('train_dir', './voxel_flow_checkpoints/',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_string('train_image_dir', './voxel_flow_train_image/',
                           """Directory where to output images.""")
tf.app.flags.DEFINE_string('test_image_dir', './voxel_flow_test_image/',
                           """Directory where to output images.""")

# Which entry point to run. The script entry point at the bottom of this file
# only dispatches on 'train' and 'test'.
tf.app.flags.DEFINE_string('subset', 'train',
                           """Either 'train' or 'test'.""")

# The default is non-empty, so training restores from this directory unless
# the flag is explicitly set to an empty string.
tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path',
                           './voxel_flow_checkpoints/',
                           """If specified, restore this pretrained model """
                           """before beginning any training.""")

# Optimization settings.
tf.app.flags.DEFINE_integer('max_steps', 10000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_integer(
    'batch_size', 32, 'The number of samples in each batch.')
tf.app.flags.DEFINE_float('initial_learning_rate', 0.0003,
                          """Initial learning rate.""")
41
def _read_image(filename):
  """Reads an image file and returns it as a rescaled float32 tensor.

  Args:
    filename: Path of the image file to read (passed to `tf.read_file`).

  Returns:
    A float32 tensor with static shape [256, 256, 3] whose values are the
    decoded pixel values divided by 127.5, minus 1.0.
  """
  raw_bytes = tf.read_file(filename)
  pixels = tf.image.decode_image(raw_bytes, channels=3)
  # Pin the static shape so that downstream ops see fixed dimensions.
  pixels.set_shape([256, 256, 3])
  pixels_float = tf.cast(pixels, dtype=tf.float32)
  return pixels_float / 127.5 - 1.0
47
def _make_frame_iterator(data_list):
  """Builds an initializable iterator over batches of decoded images.

  Args:
    data_list: List of image file names for one frame position.

  Returns:
    An initializable `tf.data` iterator yielding batches of
    `FLAGS.batch_size` images decoded by `_read_image`.
  """
  frames = tf.data.Dataset.from_tensor_slices(tf.constant(data_list))
  # NOTE(review): every frame stream is shuffled with the same buffer size and
  # seed; presumably this is what keeps the three streams aligned so that each
  # batch holds matching frame triplets -- confirm before changing either one.
  frames = frames.repeat().shuffle(buffer_size=1000000, seed=1)
  frames = frames.map(_read_image)
  return frames.batch(FLAGS.batch_size).make_initializable_iterator()


def train(dataset_frame1, dataset_frame2, dataset_frame3):
  """Trains a model.

  Frames 1 and 3 are concatenated along the channel axis and fed to the
  model; frame 2 is the target the prediction is compared against.

  Args:
    dataset_frame1: Dataset object whose `read_data_list_file()` returns the
      image file names of the first frame of every sample.
    dataset_frame2: Same, for the middle (target) frame.
    dataset_frame3: Same, for the last frame.

  Raises:
    ValueError: If `FLAGS.pretrained_model_checkpoint_path` is set but the
      path does not exist or holds no checkpoint.
  """
  with tf.Graph().as_default():
    # Create input.
    data_list_frame1 = dataset_frame1.read_data_list_file()
    data_list_frame2 = dataset_frame2.read_data_list_file()
    data_list_frame3 = dataset_frame3.read_data_list_file()

    iterators = [_make_frame_iterator(data_list)
                 for data_list in (data_list_frame1,
                                   data_list_frame2,
                                   data_list_frame3)]
    frame1, frame2, frame3 = [iterator.get_next() for iterator in iterators]

    # Create input and target tensors.
    input_batch = tf.concat([frame1, frame3], 3)
    target_batch = frame2

    # Prepare model.
    model = Voxel_flow_model()
    prediction, flow = model.inference(input_batch)
    reproduction_loss = model.loss(prediction, target_batch)
    total_loss = reproduction_loss

    # Create an optimizer that performs gradient descent.
    opt = tf.train.AdamOptimizer(FLAGS.initial_learning_rate)
    grads = opt.compute_gradients(total_loss)
    update_op = opt.apply_gradients(grads)

    # Create summaries; merge_all() below picks them up from the graph.
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('reproduction_loss', reproduction_loss)
    tf.summary.image('Input Image (before)', input_batch[:, :, :, 0:3], 3)
    tf.summary.image('Input Image (after)', input_batch[:, :, :, 3:6], 3)
    tf.summary.image('Output Image', prediction, 3)
    tf.summary.image('Target Image', target_batch, 3)
    tf.summary.image('Flow', flow, 3)
    summary_op = tf.summary.merge_all()

    # Create a saver (used both for restoring and for checkpointing).
    saver = tf.train.Saver(tf.global_variables())

    with tf.Session() as sess:
      checkpoint_dir = FLAGS.pretrained_model_checkpoint_path
      if checkpoint_dir:
        # Restore checkpoint from file.
        if not tf.gfile.Exists(checkpoint_dir):
          raise ValueError(
              'Checkpoint path does not exist: %s' % checkpoint_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
          raise ValueError('No checkpoint found in %s' % checkpoint_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('%s: Pre-trained model restored from %s' %
              (datetime.now(), ckpt.model_checkpoint_path))
      else:
        sess.run(tf.global_variables_initializer())

      # The input iterators must be initialized on both paths; restoring a
      # checkpoint does not do it.
      sess.run([iterator.initializer for iterator in iterators])

      # Summary writer.
      summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=sess.graph)

      # Make sure the sample-image directory exists before writing into it.
      tf.gfile.MakeDirs(FLAGS.train_image_dir)

      data_size = len(data_list_frame1)
      # Clamp to one so that a data list shorter than a single batch cannot
      # cause a division by zero below.
      batches_per_epoch = max(1, data_size // FLAGS.batch_size)

      try:
        for step in range(FLAGS.max_steps):
          # Run single step update.
          _, loss_value = sess.run([update_op, total_loss])

          if step % batches_per_epoch == 0:
            print('Epoch Number: %d' % (step // batches_per_epoch))

          if step % 10 == 0:
            print("Loss at step %d: %f" % (step, loss_value))

          if step % 100 == 0:
            # Output summary.
            summary_writer.add_summary(sess.run(summary_op), step)

          if step % 500 == 0:
            # Run a batch of images and dump prediction / ground truth pairs.
            prediction_np, target_np = sess.run([prediction, target_batch])
            for i in range(prediction_np.shape[0]):
              imwrite(os.path.join(FLAGS.train_image_dir, '%d_out.png' % i),
                      prediction_np[i, :, :, :])
              imwrite(os.path.join(FLAGS.train_image_dir, '%d_gt.png' % i),
                      target_np[i, :, :, :])

          # Save checkpoint.
          if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)
      finally:
        summary_writer.close()
162
def validate(dataset_frame1, dataset_frame2, dataset_frame3):
  """Placeholder for model validation; currently does nothing.

  Args:
    dataset_frame1: Unused.
    dataset_frame2: Unused.
    dataset_frame3: Unused.

  Returns:
    None.
  """
  return None
168
def test(dataset_frame1, dataset_frame2, dataset_frame3):
  """Perform test on a trained model.

  Each test sample is evaluated inside a batch padded with the first 63
  entries of the data lists; the sample itself is the last batch element, and
  only that element is written to `FLAGS.test_image_dir`.

  Args:
    dataset_frame1: Dataset object providing `read_data_list_file()` and
      `process_func()` for the first frame of every sample.
    dataset_frame2: Same, for the middle (target) frame.
    dataset_frame3: Same, for the last frame.

  Raises:
    ValueError: If no checkpoint can be restored from
      `FLAGS.pretrained_model_checkpoint_path`, or if the data list is empty.
  """
  # Number of leading data-list entries used to pad every evaluated batch.
  padding_size = 63

  with tf.Graph().as_default():
    # Create input and target placeholder.
    input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
    target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))

    # Prepare model. inference() returns (prediction, flow), as in train();
    # the flow is not needed here.
    # NOTE(review): the model is built with is_train=True even for testing;
    # presumably that is why every sample is padded to a full batch -- confirm.
    model = Voxel_flow_model(is_train=True)
    prediction, _ = model.inference(input_placeholder)
    reproduction_loss = model.loss(prediction, target_placeholder)
    total_loss = reproduction_loss

    restorer = tf.train.Saver()

    with tf.Session() as sess:
      # Restore checkpoint from file; testing without weights is meaningless.
      checkpoint_dir = FLAGS.pretrained_model_checkpoint_path
      if not checkpoint_dir or not tf.gfile.Exists(checkpoint_dir):
        raise ValueError(
            'Checkpoint path does not exist: %s' % checkpoint_dir)
      ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
      if ckpt is None or not ckpt.model_checkpoint_path:
        raise ValueError('No checkpoint found in %s' % checkpoint_dir)
      restorer.restore(sess, ckpt.model_checkpoint_path)
      print('%s: Pre-trained model restored from %s' %
            (datetime.now(), ckpt.model_checkpoint_path))

      # Process on test dataset.
      data_list_frame1 = dataset_frame1.read_data_list_file()
      data_list_frame2 = dataset_frame2.read_data_list_file()
      data_list_frame3 = dataset_frame3.read_data_list_file()

      data_size = len(data_list_frame1)
      if data_size == 0:
        raise ValueError('Test data list is empty.')

      # Make sure the output directory exists before writing into it.
      tf.gfile.MakeDirs(FLAGS.test_image_dir)

      # The padding images are the same for every sample, so decode them once.
      # Assumes process_func is deterministic -- TODO confirm.
      padding_frame1 = [dataset_frame1.process_func(path)
                        for path in data_list_frame1[:padding_size]]
      padding_frame2 = [dataset_frame2.process_func(path)
                        for path in data_list_frame2[:padding_size]]
      padding_frame3 = [dataset_frame3.process_func(path)
                        for path in data_list_frame3[:padding_size]]

      psnr_sum = 0

      for id_img in range(data_size):
        # Load single data and append it as the last element of the batch.
        batch_data_frame1 = np.array(
            padding_frame1 +
            [dataset_frame1.process_func(data_list_frame1[id_img])])
        batch_data_frame2 = np.array(
            padding_frame2 +
            [dataset_frame2.process_func(data_list_frame2[id_img])])
        batch_data_frame3 = np.array(
            padding_frame3 +
            [dataset_frame3.process_func(data_list_frame3[id_img])])

        feed_dict = {
            input_placeholder: np.concatenate(
                (batch_data_frame1, batch_data_frame3), 3),
            target_placeholder: batch_data_frame2,
        }

        # Run single evaluation step.
        prediction_np, target_np, loss_value = sess.run(
            [prediction, target_placeholder, total_loss],
            feed_dict=feed_dict)
        print("Loss for image %d: %f" % (id_img, loss_value))

        # Only the last batch element is the sample under test.
        imwrite(os.path.join(FLAGS.test_image_dir, '%d_out.png' % id_img),
                prediction_np[-1, :, :, :])
        imwrite(os.path.join(FLAGS.test_image_dir, '%d_gt.png' % id_img),
                target_np[-1, :, :, :])

        # NOTE(review): this uses the summed (not mean) squared error over the
        # whole padded batch and a fixed peak of 255, which does not look like
        # a standard PSNR. Left unchanged because the value range produced by
        # process_func is not visible here -- confirm and fix.
        psnr_sum += 10*np.log10(
            255.0*255.0/np.sum(np.square(prediction_np-target_np)))

      print("Overall PSNR: %f db" % (psnr_sum/data_size))
246       
if __name__ == '__main__':

  os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  # Entry point for every supported value of --subset.
  runners = {'train': train, 'test': test}

  if FLAGS.subset not in runners:
    # Fail loudly instead of silently doing nothing on an unsupported value.
    raise ValueError(
        "Unsupported subset %r; expected 'train' or 'test'." % FLAGS.subset)

  # One data list file per frame position, e.g.
  # data_list/ucf101_train_files_frame1.txt.
  ucf101_frame_datasets = [
      dataset.Dataset(
          "data_list/ucf101_%s_files_frame%d.txt" % (FLAGS.subset, frame_idx))
      for frame_idx in (1, 2, 3)
  ]

  runners[FLAGS.subset](*ucf101_frame_datasets)
273     test(ucf101_dataset_frame1, ucf101_dataset_frame2, ucf101_dataset_frame3)