git.sesse.net Git - voxel-flow/blob - voxel_flow_model.py
Update for TensorFlow 1.0 API.
[voxel-flow] / voxel_flow_model.py
1 """Implements a voxel flow model."""
2 from __future__ import absolute_import
3 from __future__ import division
4 from __future__ import print_function
5
6 import tensorflow as tf
7 import tensorflow.contrib.slim as slim
8 from utils.loss_utils import l1_loss, l2_loss, vae_loss 
9 from utils.geo_layer_utils import vae_gaussian_layer
10 from utils.geo_layer_utils import bilinear_interp
11 from utils.geo_layer_utils import meshgrid
12
class Voxel_flow_model(object):
  """Voxel-flow model that synthesizes a frame from two input frames.

  The network predicts, per output pixel, a 2-channel flow and a 1-channel
  blending mask. The output is a mask-weighted blend of the first frame
  sampled at (grid + flow) and the second frame sampled at (grid - flow).
  """

  # Spatial size of the synthesized frame. The decoder resizes its feature
  # maps to fixed fractions of this size and the sampling grid is built at
  # this size, so inputs are presumably expected to be this size too.
  _IMAGE_SIZE = 256

  def __init__(self, is_train=True):
    """Create the model.

    Args:
      is_train: Python bool forwarded to slim.batch_norm as `is_training`.
    """
    self.is_train = is_train

  def inference(self, input_images):
    """Build the graph that synthesizes the output frame.

    Args:
      input_images: rank-4 tensor whose last axis holds two 3-channel frames
        concatenated (channels 0:3 and 3:6). Presumably NHWC at
        `_IMAGE_SIZE` x `_IMAGE_SIZE` -- confirm against the data pipeline.

    Returns:
      The synthesized 3-channel frame tensor.
    """
    return self._build_model(input_images)

  def loss(self, predictions, targets):
    """Compute the training loss.

    Args:
      predictions: output of `inference`.
      targets: ground-truth frames, presumably the same shape as
        `predictions`.

    Returns:
      The L1 reproduction loss from utils.loss_utils.l1_loss. The value is
      also stored on `self.reproduction_loss`.
    """
    # A VAE prior term (vae_loss) is not added: _build_model does not
    # produce z_mean / z_logvar.
    self.reproduction_loss = l1_loss(predictions, targets)
    return self.reproduction_loss

  @staticmethod
  def _tile_to_batch(grid, images):
    """Repeat `grid` along axis 0 once per image in `images`.

    Uses the static batch size when the graph knows it, otherwise a
    run-time `tf.shape` value, so any batch size works.
    Assumes `grid` has a leading axis of size 1, as the original fixed
    [32, 1, 1] tiling implied -- confirm against geo_layer_utils.meshgrid.
    """
    batch_size = images.get_shape()[0].value
    if batch_size is None:
      multiples = tf.stack([tf.shape(images)[0], 1, 1])
    else:
      multiples = [batch_size, 1, 1]
    return tf.tile(grid, multiples)

  def _build_model(self, input_images):
    """Build the encoder-decoder, then warp and blend the two input frames.

    Args:
      input_images: see `inference`.

    Returns:
      mask * frame1(grid + flow) + (1 - mask) * frame2(grid - flow), where
      flow and mask are predicted per pixel by the network.
    """
    size = self._IMAGE_SIZE

    with slim.arg_scope([slim.conv2d],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                        weights_regularizer=slim.l2_regularizer(0.0001)):
      batch_norm_params = {
        'decay': 0.9997,
        'epsilon': 0.001,
        'is_training': self.is_train,
      }
      with slim.arg_scope([slim.batch_norm], is_training=self.is_train,
                          updates_collections=None):
        with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params):
          # Encoder: conv + 2x2 max-pool stages.
          net = slim.conv2d(input_images, 64, [5, 5], stride=1, scope='conv1')
          net = slim.max_pool2d(net, [2, 2], scope='pool1')
          net = slim.conv2d(net, 128, [5, 5], stride=1, scope='conv2')
          net = slim.max_pool2d(net, [2, 2], scope='pool2')
          net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv3')
          net = slim.max_pool2d(net, [2, 2], scope='pool3')
          # Decoder: bilinear resizes back up to the full output size.
          net = tf.image.resize_bilinear(net, [size // 4, size // 4])
          net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv4')
          net = tf.image.resize_bilinear(net, [size // 2, size // 2])
          net = slim.conv2d(net, 128, [3, 3], stride=1, scope='conv5')
          net = tf.image.resize_bilinear(net, [size, size])
          net = slim.conv2d(net, 64, [5, 5], stride=1, scope='conv6')
    # NOTE(review): conv7 is outside every arg_scope above, so it gets slim's
    # default weights initializer and no L2 regularizer. The explicit
    # `normalizer_fn=None` hints it may have been meant to sit inside them.
    # Confirm the intent before changing, since it affects training.
    net = slim.conv2d(net, 3, [5, 5], stride=1, activation_fn=tf.tanh,
                      normalizer_fn=None, scope='conv7')

    # tanh output: channels 0-1 are the flow, channel 2 is the blend mask.
    flow = net[:, :, :, 0:2]
    mask = tf.expand_dims(net[:, :, :, 2], 3)

    grid_x, grid_y = meshgrid(size, size)
    grid_x = self._tile_to_batch(grid_x, input_images)
    grid_y = self._tile_to_batch(grid_y, input_images)

    flow = 0.5 * flow

    # Symmetric warp: frame 1 is sampled along +flow, frame 2 along -flow.
    coor_x_1 = grid_x + flow[:, :, :, 0]
    coor_y_1 = grid_y + flow[:, :, :, 1]

    coor_x_2 = grid_x - flow[:, :, :, 0]
    coor_y_2 = grid_y - flow[:, :, :, 1]

    output_1 = bilinear_interp(input_images[:, :, :, 0:3], coor_x_1, coor_y_1, 'interpolate')
    output_2 = bilinear_interp(input_images[:, :, :, 3:6], coor_x_2, coor_y_2, 'interpolate')

    # Map the tanh mask from [-1, 1] to [0, 1], then broadcast to 3 channels.
    mask = 0.5 * (1.0 + mask)
    mask = tf.tile(mask, [1, 1, 1, 3])
    net = tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2)

    return net