]> git.sesse.net Git - voxel-flow/blob - voxel_flow_model.py
Unbreak non-default batch sizes.
[voxel-flow] / voxel_flow_model.py
1 """Implements a voxel flow model."""
2 from __future__ import absolute_import
3 from __future__ import division
4 from __future__ import print_function
5
6 import tensorflow as tf
7 import tensorflow.contrib.slim as slim
8 from utils.loss_utils import l1_loss, l2_loss, vae_loss 
9 from utils.geo_layer_utils import vae_gaussian_layer
10 from utils.geo_layer_utils import bilinear_interp
11 from utils.geo_layer_utils import meshgrid
12
# Module-level handle to TensorFlow's process-wide command-line flags
# (tf.app.flags); flag definitions live outside this module.
FLAGS = tf.app.flags.FLAGS
14
class Voxel_flow_model(object):
  """Voxel flow network.

  From two frames stacked along the channel axis, predicts a per-pixel 2-D
  flow field and a blending mask. It then bilinearly samples the first frame
  at +flow and the second at -flow, and blends the two samples with the mask
  to synthesize the output frame.
  """

  def __init__(self, is_train=True):
    """Create the model.

    Args:
      is_train: Whether the graph is built for training. Forwarded to batch
        normalization as `is_training`.
    """
    self.is_train = is_train

  def inference(self, input_images):
    """Build the inference graph for a batch of frame pairs.

    Args:
      input_images: Rank-4 tensor [batch, height, width, channels] whose
        channels 0:3 and 3:6 hold the two input frames. The graph assumes
        256x256 spatial size (see `_build_model`). Any batch size works.

    Returns:
      The synthesized frame tensor produced by `_build_model`.
    """
    return self._build_model(input_images)

  def loss(self, predictions, targets):
    """Compute the training loss.

    Only the L1 reproduction loss is used. No VAE prior term (`vae_loss`) is
    added, because `_build_model` never produces `z_mean`/`z_logvar`.

    Args:
      predictions: Synthesized frames, as returned by `inference`.
      targets: Ground-truth frames with the same shape as `predictions`.

    Returns:
      The L1 reproduction loss, also stored as `self.reproduction_loss`.
    """
    self.reproduction_loss = l1_loss(predictions, targets)
    return self.reproduction_loss

  def _build_model(self, input_images):
    """Build the flow/mask network and the frame synthesis step.

    Args:
      input_images: Rank-4 tensor; channels 0:3 and 3:6 are the two input
        frames. Spatial size is assumed to be 256x256, because the decoder
        upsamples to 256x256 and the sampling grid is 256x256.

    Returns:
      mask * sample(frame_1, +flow) + (1 - mask) * sample(frame_2, -flow),
      with the mask tiled to 3 channels.
    """
    with slim.arg_scope([slim.conv2d],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                        weights_regularizer=slim.l2_regularizer(0.0001)):
      batch_norm_params = {
        'decay': 0.9997,
        'epsilon': 0.001,
        'is_training': self.is_train,
      }
      with slim.arg_scope([slim.batch_norm], is_training=self.is_train,
                          updates_collections=None):
        with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params):
          # Encoder: three conv + 2x2 max-pool stages.
          net = slim.conv2d(input_images, 64, [5, 5], stride=1, scope='conv1')
          net = slim.max_pool2d(net, [2, 2], scope='pool1')
          net = slim.conv2d(net, 128, [5, 5], stride=1, scope='conv2')
          net = slim.max_pool2d(net, [2, 2], scope='pool2')
          net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv3')
          net = slim.max_pool2d(net, [2, 2], scope='pool3')
          # Decoder: bilinear upsampling back to 256x256.
          net = tf.image.resize_bilinear(net, [64, 64])
          net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv4')
          net = tf.image.resize_bilinear(net, [128, 128])
          net = slim.conv2d(net, 128, [3, 3], stride=1, scope='conv5')
          net = tf.image.resize_bilinear(net, [256, 256])
          net = slim.conv2d(net, 64, [5, 5], stride=1, scope='conv6')
    # NOTE(review): conv7 is outside both arg_scopes, so it does not receive
    # the truncated-normal initializer or the L2 regularizer configured above.
    # The explicit normalizer_fn=None suggests it may have been meant to sit
    # inside them. It is left as-is so training behavior does not change.
    net = slim.conv2d(net, 3, [5, 5], stride=1, activation_fn=tf.tanh,
                      normalizer_fn=None, scope='conv7')

    # tanh output in [-1, 1]: channels 0-1 are the flow, channel 2 the mask.
    flow = net[:, :, :, 0:2]
    mask = tf.expand_dims(net[:, :, :, 2], 3)

    # Base sampling grid for the 256x256 output. Its leading axis is assumed
    # to have size 1 (TODO confirm against utils.geo_layer_utils.meshgrid).
    # It is therefore broadcast against the per-example flow rather than
    # tiled to FLAGS.batch_size. The batch size then comes from the data,
    # so batches of any size (partial final batch, single-example inference)
    # are accepted.
    grid_x, grid_y = meshgrid(256, 256)

    flow = 0.5 * flow

    # Sample the first frame at +flow and the second at -flow.
    coor_x_1 = grid_x + flow[:, :, :, 0]
    coor_y_1 = grid_y + flow[:, :, :, 1]

    coor_x_2 = grid_x - flow[:, :, :, 0]
    coor_y_2 = grid_y - flow[:, :, :, 1]

    output_1 = bilinear_interp(input_images[:, :, :, 0:3], coor_x_1, coor_y_1, 'interpolate')
    output_2 = bilinear_interp(input_images[:, :, :, 3:6], coor_x_2, coor_y_2, 'interpolate')

    # Map the mask from [-1, 1] to [0, 1]; it weights the first sample.
    mask = 0.5 * (1.0 + mask)
    mask = tf.tile(mask, [1, 1, 1, 3])
    net = tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2)

    return net
92     return net