git.sesse.net Git - voxel-flow/blob - voxel_flow_train.py
Checkpoint a bit more often.
[voxel-flow] / voxel_flow_train.py
1 """Train a voxel flow model on ucf101 dataset."""
2 from __future__ import absolute_import
3 from __future__ import division
4 from __future__ import print_function
5
6 import dataset
7 from utils.prefetch_queue_shuffle import PrefetchQueue
8 import numpy as np
9 import os
10 import tensorflow as tf
11 import tensorflow.contrib.slim as slim
12 from datetime import datetime
13 import random
14 from random import shuffle
15 from voxel_flow_model import Voxel_flow_model
16 from utils.image_utils import imwrite
17 from functools import partial
18 import pdb
19
FLAGS = tf.app.flags.FLAGS

# Command-line flags shared by the training and testing entry points.
tf.app.flags.DEFINE_string('train_dir', './voxel_flow_checkpoints/',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_string('train_image_dir', './voxel_flow_train_image/',
                           """Directory where to output images """
                           """sampled during training.""")
tf.app.flags.DEFINE_string('test_image_dir', './voxel_flow_test_image/',
                           """Directory where to output images """
                           """produced during testing.""")
# The __main__ dispatcher only understands these two values.
tf.app.flags.DEFINE_string('subset', 'train',
                           """Either 'train' or 'test'.""")
# Defaults to the same directory as --train_dir, so training resumes from
# its own latest checkpoint.
tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path',
                           './voxel_flow_checkpoints/',
                           """If non-empty, restore the latest checkpoint """
                           """in this directory before training or testing.""")
tf.app.flags.DEFINE_integer('max_steps', 10000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_integer('batch_size', 32,
                            """The number of samples in each batch.""")
tf.app.flags.DEFINE_float('initial_learning_rate', 0.0003,
                          """Learning rate for the Adam optimizer """
                          """(constant; no schedule is applied).""")
41
def _read_image(filename):
  """Decode one image file into a float32 tensor scaled to [-1, 1].

  Args:
    filename: scalar string tensor holding the path of the image to read.

  Returns:
    A float32 tensor with static shape [256, 256, 3]. Each 8-bit pixel value
    v becomes v / 127.5 - 1.0.
  """
  pixels = tf.image.decode_image(tf.read_file(filename), channels=3)
  # decode_image does not set a static shape; pin it so downstream ops
  # (batching, the model) see a fully-defined image shape.
  pixels.set_shape([256, 256, 3])
  return tf.cast(pixels, dtype=tf.float32) / 127.5 - 1.0
47
def _make_triplet_iterator(dataset_frame1, dataset_frame2, dataset_frame3,
                           batch_size):
  """Build an input pipeline over shuffled, batched frame triplets.

  The three file lists are zipped *before* shuffling, so every element
  carries its frame-1, frame-2 and frame-3 images together. (Shuffling three
  separate datasets and relying on a shared seed to keep them in step is
  fragile, and keeps three large shuffle buffers alive.)

  Args:
    dataset_frame1: object whose read_data_list_file() returns the frame-1
      image paths.
    dataset_frame2: same, for frame 2.
    dataset_frame3: same, for frame 3.
    batch_size: number of triplets per batch.

  Returns:
    A tuple (iterator, num_triplets): an initializable iterator whose
    get_next() yields (frame1, frame2, frame3) batches decoded by _read_image,
    and the number of triplets in one pass over the lists.

  Raises:
    ValueError: if the lists are empty or have different lengths.
  """
  paths1 = dataset_frame1.read_data_list_file()
  paths2 = dataset_frame2.read_data_list_file()
  paths3 = dataset_frame3.read_data_list_file()
  if len(paths1) == 0:
    raise ValueError('Training data list is empty.')
  if not len(paths1) == len(paths2) == len(paths3):
    raise ValueError('Frame lists differ in length: %d, %d, %d.'
                     % (len(paths1), len(paths2), len(paths3)))

  triplets = tf.data.Dataset.from_tensor_slices(
      (tf.constant(paths1), tf.constant(paths2), tf.constant(paths3)))
  triplets = triplets.repeat().shuffle(buffer_size=1000000, seed=1)
  # Tuple elements are passed to the map function as positional arguments;
  # a few parallel calls keep decoding throughput close to the original three
  # independent pipelines. Parallel map preserves element order.
  triplets = triplets.map(
      lambda path1, path2, path3: (
          _read_image(path1), _read_image(path2), _read_image(path3)),
      num_parallel_calls=3)
  triplets = triplets.prefetch(100)
  iterator = triplets.batch(batch_size).make_initializable_iterator()
  return iterator, len(paths1)


def _restore_or_initialize(sess, saver):
  """Restore the newest checkpoint, or initialize all variables if none exists.

  Looks in --pretrained_model_checkpoint_path. That flag defaults to the
  training output directory, so a first run finds no checkpoint there; it
  then starts from freshly initialized variables instead of crashing.

  Args:
    sess: session to restore into / initialize.
    saver: tf.train.Saver covering the model's variables.
  """
  ckpt_dir = FLAGS.pretrained_model_checkpoint_path
  ckpt = None
  if ckpt_dir and tf.gfile.Exists(ckpt_dir):
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('%s: Pre-trained model restored from %s' %
          (datetime.now(), ckpt.model_checkpoint_path))
  else:
    if ckpt_dir:
      print('%s: No checkpoint found in %s; training from scratch.' %
            (datetime.now(), ckpt_dir))
    sess.run(tf.global_variables_initializer())


def train(dataset_frame1, dataset_frame2, dataset_frame3):
  """Train the voxel flow model to synthesize frame 2 from frames 1 and 3.

  Frames 1 and 3 are concatenated along the channel axis as input and frame 2
  is the target. Optimization uses Adam with a constant
  --initial_learning_rate for --max_steps steps. The loss is printed every 10
  steps and summaries go to --train_dir every 100 steps. Every 500 steps,
  sample outputs are written to --train_image_dir and a checkpoint to
  --train_dir; a checkpoint is also written on the final step.

  Args:
    dataset_frame1: frame-1 dataset; only read_data_list_file() is used.
    dataset_frame2: frame-2 (target) dataset.
    dataset_frame3: frame-3 dataset.

  Raises:
    ValueError: if the frame lists are empty or differ in length.
  """
  with tf.Graph().as_default():
    iterator, data_size = _make_triplet_iterator(
        dataset_frame1, dataset_frame2, dataset_frame3, FLAGS.batch_size)
    frame1, frame2, frame3 = iterator.get_next()

    # Model input: frames 1 and 3 stacked on the channel axis; target: frame 2.
    input_placeholder = tf.concat([frame1, frame3], 3)
    target_placeholder = frame2

    model = Voxel_flow_model()
    prediction, flow = model.inference(input_placeholder)
    reproduction_loss = model.loss(prediction, target_placeholder)
    total_loss = reproduction_loss

    # Constant learning rate; no schedule is applied.
    opt = tf.train.AdamOptimizer(FLAGS.initial_learning_rate)
    grads = opt.compute_gradients(total_loss)
    update_op = opt.apply_gradients(grads)

    # Summaries register themselves in the SUMMARIES collection, so
    # merge_all() also picks up any the model created.
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('reproduction_loss', reproduction_loss)
    tf.summary.image('Input Image (before)', frame1, 3)
    tf.summary.image('Input Image (after)', frame3, 3)
    tf.summary.image('Output Image', prediction, 3)
    tf.summary.image('Target Image', target_placeholder, 3)
    tf.summary.image('Flow', flow, 3)
    summary_op = tf.summary.merge_all()

    # Created after the optimizer so Adam's slot variables are saved too.
    saver = tf.train.Saver(tf.global_variables())

    # Guard against datasets smaller than one batch (would divide by zero).
    # Epoch boundaries are approximate: shuffling happens after repeat().
    batches_per_epoch = max(1, data_size // FLAGS.batch_size)
    if FLAGS.train_image_dir:
      tf.gfile.MakeDirs(FLAGS.train_image_dir)
    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')

    with tf.Session() as sess:
      _restore_or_initialize(sess, saver)
      sess.run(iterator.initializer)
      summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=sess.graph)
      try:
        # NOTE(review): the step counter restarts at 0 after a restore, so
        # checkpoint suffixes and summary steps restart as well; consider
        # persisting a global step.
        for step in range(FLAGS.max_steps):
          _, loss_value = sess.run([update_op, total_loss])

          if step % batches_per_epoch == 0:
            print('Epoch Number: %d' % (step // batches_per_epoch))

          if step % 10 == 0:
            print("Loss at step %d: %f" % (step, loss_value))

          if step % 100 == 0:
            # Evaluated on a fresh batch, not the one just trained on.
            summary_writer.add_summary(sess.run(summary_op), step)

          if step % 500 == 0:
            prediction_np, target_np = sess.run(
                [prediction, target_placeholder])
            for i in range(prediction_np.shape[0]):
              imwrite(os.path.join(FLAGS.train_image_dir, str(i) + '_out.png'),
                      prediction_np[i])
              imwrite(os.path.join(FLAGS.train_image_dir, str(i) + '_gt.png'),
                      target_np[i])

          if step % 500 == 0 or step + 1 == FLAGS.max_steps:
            saver.save(sess, checkpoint_path, global_step=step)
      finally:
        summary_writer.close()
166
def validate(dataset_frame1, dataset_frame2, dataset_frame3):
  """Validate a model on held-out data (not implemented yet).

  Args:
    dataset_frame1: frame-1 dataset (currently unused).
    dataset_frame2: frame-2 dataset (currently unused).
    dataset_frame3: frame-3 dataset (currently unused).

  Returns:
    None; this function is a stub.
  """
  del dataset_frame1, dataset_frame2, dataset_frame3  # Unused until implemented.
  return None
172
def _psnr(prediction, target, peak=255.0):
  """Return the peak signal-to-noise ratio, in dB, between two arrays.

  Args:
    prediction: array of predicted pixel values.
    target: array of reference pixel values with the same shape.
    peak: largest possible pixel value on the scale the inputs use.

  Returns:
    10 * log10(peak**2 / MSE), or +inf when the arrays are identical.
  """
  diff = (np.asarray(prediction, dtype=np.float64) -
          np.asarray(target, dtype=np.float64))
  mse = np.mean(np.square(diff))
  if mse == 0:
    return float('inf')
  return 10.0 * np.log10(peak * peak / mse)


def test(dataset_frame1, dataset_frame2, dataset_frame3):
  """Run a trained model over the test triplets and report the mean PSNR.

  Each test image is evaluated inside a batch of up to 64 images: the first
  63 triplets of the list act as fixed padding and the image under test is
  placed last. Its loss is printed, its prediction and ground truth are
  written to --test_image_dir, and its PSNR contributes to the overall mean.

  Args:
    dataset_frame1: frame-1 dataset (read_data_list_file / process_func).
    dataset_frame2: frame-2 (target) dataset.
    dataset_frame3: frame-3 dataset.

  Raises:
    ValueError: if no checkpoint is found, or the lists are empty or differ
      in length.
  """
  with tf.Graph().as_default():
    # Input: frames 1 and 3 stacked on the channel axis; target: frame 2.
    input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
    target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))

    # NOTE(review): is_train=True is kept from the original code even though
    # this is evaluation; confirm whether is_train=False is intended.
    model = Voxel_flow_model(is_train=True)
    # inference() returns (prediction, flow), as used in train().
    prediction, _ = model.inference(input_placeholder)
    total_loss = model.loss(prediction, target_placeholder)

    saver = tf.train.Saver(tf.global_variables())

    # Testing is meaningless without trained weights, so fail up front rather
    # than hitting uninitialized variables inside the loop.
    ckpt_dir = FLAGS.pretrained_model_checkpoint_path
    ckpt = None
    if ckpt_dir and tf.gfile.Exists(ckpt_dir):
      ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
      raise ValueError('No checkpoint found in '
                       '--pretrained_model_checkpoint_path=%r; testing '
                       'needs a trained model.' % ckpt_dir)

    paths1 = dataset_frame1.read_data_list_file()
    paths2 = dataset_frame2.read_data_list_file()
    paths3 = dataset_frame3.read_data_list_file()
    data_size = len(paths1)
    if data_size == 0:
      raise ValueError('Test data list is empty.')
    if not data_size == len(paths2) == len(paths3):
      raise ValueError('Frame lists differ in length: %d, %d, %d.'
                       % (data_size, len(paths2), len(paths3)))

    # Padding images shared by every batch, loaded once instead of per image.
    # NOTE(review): assumes process_func is deterministic — confirm.
    num_padding = 63
    pad1 = [dataset_frame1.process_func(p) for p in paths1[:num_padding]]
    pad2 = [dataset_frame2.process_func(p) for p in paths2[:num_padding]]
    pad3 = [dataset_frame3.process_func(p) for p in paths3[:num_padding]]

    if FLAGS.test_image_dir:
      tf.gfile.MakeDirs(FLAGS.test_image_dir)

    psnr_sum = 0.0
    with tf.Session() as sess:
      saver.restore(sess, ckpt.model_checkpoint_path)
      print('%s: Pre-trained model restored from %s' %
            (datetime.now(), ckpt.model_checkpoint_path))

      for id_img in range(data_size):
        batch1 = np.array(pad1 + [dataset_frame1.process_func(paths1[id_img])])
        batch2 = np.array(pad2 + [dataset_frame2.process_func(paths2[id_img])])
        batch3 = np.array(pad3 + [dataset_frame3.process_func(paths3[id_img])])

        feed_dict = {input_placeholder: np.concatenate((batch1, batch3), 3),
                     target_placeholder: batch2}
        prediction_np, target_np, loss_value = sess.run(
            [prediction, target_placeholder, total_loss], feed_dict=feed_dict)
        # Loss over the whole padded batch, as in the original code.
        print("Loss for image %d: %f" % (id_img, loss_value))

        # The image under test is the last one in the batch.
        imwrite(os.path.join(FLAGS.test_image_dir, str(id_img) + '_out.png'),
                prediction_np[-1])
        imwrite(os.path.join(FLAGS.test_image_dir, str(id_img) + '_gt.png'),
                target_np[-1])

        # PSNR of the image under test only, mapped onto a 0..255 scale.
        # NOTE(review): assumes process_func yields pixels in [-1, 1], the
        # range _read_image produces for training — confirm.
        psnr_sum += _psnr((prediction_np[-1] + 1.0) * 127.5,
                          (target_np[-1] + 1.0) * 127.5)

    print("Overall PSNR: %f db" % (psnr_sum / data_size))
250       
if __name__ == '__main__':
  # Default to the first GPU, but respect a CUDA_VISIBLE_DEVICES value the
  # caller has already exported.
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

  # Map each supported --subset to its entry point; anything else is an error
  # rather than a silent no-op.
  runners = {'train': train, 'test': test}
  if FLAGS.subset not in runners:
    raise ValueError("Unknown --subset=%r; expected 'train' or 'test'."
                     % FLAGS.subset)

  # e.g. data_list/ucf101_train_files_frame1.txt ... frame3.txt
  ucf101_datasets = [
      dataset.Dataset("data_list/ucf101_%s_files_frame%d.txt"
                      % (FLAGS.subset, frame))
      for frame in (1, 2, 3)]
  runners[FLAGS.subset](*ucf101_datasets)