]> git.sesse.net Git - voxel-flow/blob - voxel_flow_train.py
Fix and reenable summary outputs.
[voxel-flow] / voxel_flow_train.py
1 """Train a voxel flow model on ucf101 dataset."""
2 from __future__ import absolute_import
3 from __future__ import division
4 from __future__ import print_function
5
6 import dataset
7 from utils.prefetch_queue_shuffle import PrefetchQueue
8 import numpy as np
9 import os
10 import tensorflow as tf
11 import tensorflow.contrib.slim as slim
12 from datetime import datetime
13 import random
14 from random import shuffle
15 from voxel_flow_model import Voxel_flow_model
16 from utils.image_utils import imwrite
17 from functools import partial
18 import pdb
19
20 FLAGS = tf.app.flags.FLAGS
21
22 # Define necessary FLAGS
23 tf.app.flags.DEFINE_string('train_dir', './voxel_flow_checkpoints/',
24                            """Directory where to write event logs """
25                            """and checkpoint.""")
26 tf.app.flags.DEFINE_string('train_image_dir', './voxel_flow_train_image/',
27                            """Directory where to output images.""")
28 tf.app.flags.DEFINE_string('test_image_dir', './voxel_flow_test_image/',
29                            """Directory where to output images.""")
30 tf.app.flags.DEFINE_string('subset', 'train',
31                            """Either 'train' or 'validation'.""")
32 tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', './voxel_flow_checkpoints/',
33                            """If specified, restore this pretrained model """
34                            """before beginning any training.""")
35 tf.app.flags.DEFINE_integer('max_steps', 10000000,
36                             """Number of batches to run.""")
37 tf.app.flags.DEFINE_integer(
38         'batch_size', 32, 'The number of samples in each batch.')
39 tf.app.flags.DEFINE_float('initial_learning_rate', 0.0003,
40                           """Initial learning rate.""")
41
42
43 def train(dataset_frame1, dataset_frame2, dataset_frame3):
44   """Trains a model."""
45   with tf.Graph().as_default():
46     # Create input and target placeholder.
47     input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
48     target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))
49
50     # input_resized = tf.image.resize_area(input_placeholder, [128, 128])
51     # target_resized = tf.image.resize_area(target_placeholder,[128, 128])
52
53     # Prepare model.
54     model = Voxel_flow_model()
55     prediction = model.inference(input_placeholder)
56     # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
57     reproduction_loss = model.loss(prediction, target_placeholder)
58     # total_loss = reproduction_loss + prior_loss
59     total_loss = reproduction_loss
60     
61     # Perform learning rate scheduling.
62     learning_rate = FLAGS.initial_learning_rate
63
64     # Create an optimizer that performs gradient descent.
65     opt = tf.train.AdamOptimizer(learning_rate)
66     grads = opt.compute_gradients(total_loss)
67     update_op = opt.apply_gradients(grads)
68
69     # Create summaries
70     summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
71     summaries.append(tf.summary.scalar('total_loss', total_loss))
72     summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
73     # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
74     summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3));
75     summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3));
76     summaries.append(tf.summary.image('Output Image', prediction, 3))
77     summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
78
79     # Create a saver.
80     saver = tf.train.Saver(tf.all_variables())
81
82     # Build the summary operation from the last tower summaries.
83     summary_op = tf.summary.merge_all()
84
85     # Build an initialization operation to run below.
86     init = tf.initialize_all_variables()
87     sess = tf.Session()
88     sess.run(init)
89
90     # Summary Writter
91     summary_writer = tf.summary.FileWriter(
92       FLAGS.train_dir,
93       graph=sess.graph)
94
95     # Training loop using feed dict method.
96     data_list_frame1 = dataset_frame1.read_data_list_file()
97     random.seed(1)
98     shuffle(data_list_frame1)
99
100     data_list_frame2 = dataset_frame2.read_data_list_file()
101     random.seed(1)
102     shuffle(data_list_frame2)
103
104     data_list_frame3 = dataset_frame3.read_data_list_file()
105     random.seed(1)
106     shuffle(data_list_frame3)
107
108     data_size = len(data_list_frame1)
109     epoch_num = int(data_size / FLAGS.batch_size)
110
111     # num_workers = 1
112       
113     # load_fn_frame1 = partial(dataset_frame1.process_func)
114     # p_queue_frame1 = PrefetchQueue(load_fn_frame1, data_list_frame1, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
115
116     # load_fn_frame2 = partial(dataset_frame2.process_func)
117     # p_queue_frame2 = PrefetchQueue(load_fn_frame2, data_list_frame2, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
118
119     # load_fn_frame3 = partial(dataset_frame3.process_func)
120     # p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
121
122     for step in range(0, FLAGS.max_steps):
123       batch_idx = step % epoch_num
124       
125       batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
126       batch_data_list_frame2 = data_list_frame2[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
127       batch_data_list_frame3 = data_list_frame3[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
128       
129       # Load batch data.
130       batch_data_frame1 = np.array([dataset_frame1.process_func(line) for line in batch_data_list_frame1])
131       batch_data_frame2 = np.array([dataset_frame2.process_func(line) for line in batch_data_list_frame2])
132       batch_data_frame3 = np.array([dataset_frame3.process_func(line) for line in batch_data_list_frame3])
133
134       # batch_data_frame1 = p_queue_frame1.get_batch()
135       # batch_data_frame2 = p_queue_frame2.get_batch()
136       # batch_data_frame3 = p_queue_frame3.get_batch()
137
138       feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3), target_placeholder: batch_data_frame2}
139      
140       # Run single step update.
141       _, loss_value = sess.run([update_op, total_loss], feed_dict = feed_dict)
142       
143       if batch_idx == 0:
144         # Shuffle data at each epoch.
145         random.seed(1)
146         shuffle(data_list_frame1)
147         random.seed(1)
148         shuffle(data_list_frame2)
149         random.seed(1)
150         shuffle(data_list_frame3)
151         print('Epoch Number: %d' % int(step / epoch_num))
152       
153       if step % 10 == 0:
154         print("Loss at step %d: %f" % (step, loss_value))
155
156       if step % 100 == 0:
157         # Output Summary 
158         summary_str = sess.run(summary_op, feed_dict = feed_dict)
159         summary_writer.add_summary(summary_str, step)
160
161       if step % 500 == 0:
162         # Run a batch of images 
163         prediction_np, target_np = sess.run([prediction, target_placeholder], feed_dict = feed_dict) 
164         for i in range(0,prediction_np.shape[0]):
165           file_name = FLAGS.train_image_dir+str(i)+'_out.png'
166           file_name_label = FLAGS.train_image_dir+str(i)+'_gt.png'
167           imwrite(file_name, prediction_np[i,:,:,:])
168           imwrite(file_name_label, target_np[i,:,:,:])
169
170       # Save checkpoint 
171       if step % 5000 == 0 or (step +1) == FLAGS.max_steps:
172         checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
173         saver.save(sess, checkpoint_path, global_step=step)
174
175 def validate(dataset_frame1, dataset_frame2, dataset_frame3):
176   """Performs validation on model.
177   Args:  
178   """
179   pass
180
181 def test(dataset_frame1, dataset_frame2, dataset_frame3):
182   """Perform test on a trained model."""
183   with tf.Graph().as_default():
184                 # Create input and target placeholder.
185     input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
186     target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))
187     
188     # input_resized = tf.image.resize_area(input_placeholder, [128, 128])
189     # target_resized = tf.image.resize_area(target_placeholder,[128, 128])
190
191     # Prepare model.
192     model = Voxel_flow_model(is_train=True)
193     prediction = model.inference(input_placeholder)
194     # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
195     reproduction_loss = model.loss(prediction, target_placeholder)
196     # total_loss = reproduction_loss + prior_loss
197     total_loss = reproduction_loss
198
199     # Create a saver and load.
200     saver = tf.train.Saver(tf.all_variables())
201     sess = tf.Session()
202
203     # Restore checkpoint from file.
204     if FLAGS.pretrained_model_checkpoint_path:
205       assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
206       ckpt = tf.train.get_checkpoint_state(
207                FLAGS.pretrained_model_checkpoint_path)
208       restorer = tf.train.Saver()
209       restorer.restore(sess, ckpt.model_checkpoint_path)
210       print('%s: Pre-trained model restored from %s' %
211         (datetime.now(), ckpt.model_checkpoint_path))
212     
213     # Process on test dataset.
214     data_list_frame1 = dataset_frame1.read_data_list_file()
215     data_size = len(data_list_frame1)
216     epoch_num = int(data_size / FLAGS.batch_size)
217
218     data_list_frame2 = dataset_frame2.read_data_list_file()
219
220     data_list_frame3 = dataset_frame3.read_data_list_file()
221
222     i = 0 
223     PSNR = 0
224
225     for id_img in range(0, data_size):  
226       # Load single data.
227       line_image_frame1 = dataset_frame1.process_func(data_list_frame1[id_img])
228       line_image_frame2 = dataset_frame2.process_func(data_list_frame2[id_img])
229       line_image_frame3 = dataset_frame3.process_func(data_list_frame3[id_img])
230       
231       batch_data_frame1 = [dataset_frame1.process_func(ll) for ll in data_list_frame1[0:63]]
232       batch_data_frame2 = [dataset_frame2.process_func(ll) for ll in data_list_frame2[0:63]]
233       batch_data_frame3 = [dataset_frame3.process_func(ll) for ll in data_list_frame3[0:63]]
234       
235       batch_data_frame1.append(line_image_frame1)
236       batch_data_frame2.append(line_image_frame2)
237       batch_data_frame3.append(line_image_frame3)
238       
239       batch_data_frame1 = np.array(batch_data_frame1)
240       batch_data_frame2 = np.array(batch_data_frame2)
241       batch_data_frame3 = np.array(batch_data_frame3)
242       
243       feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3),
244                     target_placeholder: batch_data_frame2}
245       # Run single step update.
246       prediction_np, target_np, loss_value = sess.run([prediction,
247                                                       target_placeholder,
248                                                       total_loss],
249                                                       feed_dict = feed_dict)
250       print("Loss for image %d: %f" % (i,loss_value))
251       file_name = FLAGS.test_image_dir+str(i)+'_out.png'
252       file_name_label = FLAGS.test_image_dir+str(i)+'_gt.png'
253       imwrite(file_name, prediction_np[-1,:,:,:])
254       imwrite(file_name_label, target_np[-1,:,:,:])
255       i += 1
256       PSNR += 10*np.log10(255.0*255.0/np.sum(np.square(prediction_np-target_np)))
257     print("Overall PSNR: %f db" % (PSNR/len(data_list)))
258       
259 if __name__ == '__main__':
260   
261   os.environ["CUDA_VISIBLE_DEVICES"] = "0"
262
263   if FLAGS.subset == 'train':
264     
265     data_list_path_frame1 = "data_list/ucf101_train_files_frame1.txt"
266     data_list_path_frame2 = "data_list/ucf101_train_files_frame2.txt"
267     data_list_path_frame3 = "data_list/ucf101_train_files_frame3.txt"
268     
269     ucf101_dataset_frame1 = dataset.Dataset(data_list_path_frame1) 
270     ucf101_dataset_frame2 = dataset.Dataset(data_list_path_frame2) 
271     ucf101_dataset_frame3 = dataset.Dataset(data_list_path_frame3) 
272     
273     train(ucf101_dataset_frame1, ucf101_dataset_frame2, ucf101_dataset_frame3)
274   
275   elif FLAGS.subset == 'test':
276     
277     data_list_path_frame1 = "data_list/ucf101_test_files_frame1.txt"
278     data_list_path_frame2 = "data_list/ucf101_test_files_frame2.txt"
279     data_list_path_frame3 = "data_list/ucf101_test_files_frame3.txt"
280     
281     ucf101_dataset_frame1 = dataset.Dataset(data_list_path_frame1) 
282     ucf101_dataset_frame2 = dataset.Dataset(data_list_path_frame2) 
283     ucf101_dataset_frame3 = dataset.Dataset(data_list_path_frame3) 
284     
285     test(ucf101_dataset_frame1, ucf101_dataset_frame2, ucf101_dataset_frame3)