1 """Implements a voxel flow model."""
2 from __future__ import absolute_import
3 from __future__ import division
4 from __future__ import print_function
6 import tensorflow as tf
7 import tensorflow.contrib.slim as slim
8 from utils.loss_utils import l1_loss, l2_loss, vae_loss
9 from utils.geo_layer_utils import vae_gaussian_layer
10 from utils.geo_layer_utils import bilinear_interp
11 from utils.geo_layer_utils import meshgrid
13 class Voxel_flow_model(object):
def __init__(self, is_train=True):
    """Store the train/eval switch used when the graph is built.

    Args:
      is_train: stored unchanged as ``self.is_train``; ``_build_model``
        forwards it as ``is_training`` to ``slim.batch_norm``.
        NOTE(review): whether a Python bool or a boolean tensor is expected
        is not visible here -- confirm with callers.
    """
    self.is_train = is_train
def inference(self, input_images):
    """Build the forward graph for a batch of input frames.

    Thin public entry point; all layers are created by ``_build_model``.

    Args:
      input_images: image tensor handed unchanged to ``_build_model``,
        which slices channels ``0:3`` and ``3:6`` of the last axis.
        NOTE(review): that suggests two 3-channel frames stacked along the
        channel axis -- confirm with callers.

    Returns:
      Whatever ``_build_model`` returns for ``input_images``.
    """
    built = self._build_model(input_images)
    return built
def loss(self, predictions, targets):
    """Build the training loss comparing ``predictions`` with ``targets``.

    Side effect: the loss is also kept on ``self.reproduction_loss``.

    Args:
      predictions: model output (presumably what ``inference`` returns --
        confirm with callers).
      targets: ground-truth values compared against ``predictions``.
        NOTE(review): assumed to share the layout of ``predictions``;
        confirm against ``utils.loss_utils.l1_loss``.

    Returns:
      The value produced by ``l1_loss(predictions, targets)``.
    """
    # NOTE: disabled alternatives that were commented out here:
    #   * l2_loss(predictions, targets) as the reproduction loss;
    #   * an extra prior term
    #     vae_loss(self.z_mean, self.z_logvar, prior_weight=0.0001),
    #     with [reproduction_loss, prior_loss] returned instead.
    reproduction = l1_loss(predictions, targets)
    self.reproduction_loss = reproduction
    return reproduction
34 def _build_model(self, input_images):
39 with slim.arg_scope([slim.conv2d],
40 activation_fn=tf.nn.relu,
41 weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
42 weights_regularizer=slim.l2_regularizer(0.0001)):
48 'is_training': self.is_train,
50 with slim.arg_scope([slim.batch_norm], is_training = self.is_train, updates_collections=None):
51 with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
52 normalizer_params=batch_norm_params):
53 net = slim.conv2d(input_images, 64, [5, 5], stride=1, scope='conv1')
54 net = slim.max_pool2d(net, [2, 2], scope='pool1')
55 net = slim.conv2d(net, 128, [5, 5], stride=1, scope='conv2')
56 net = slim.max_pool2d(net, [2, 2], scope='pool2')
57 net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv3')
58 net = slim.max_pool2d(net, [2, 2], scope='pool3')
59 net = tf.image.resize_bilinear(net, [64,64])
60 net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv4')
61 net = tf.image.resize_bilinear(net, [128,128])
62 net = slim.conv2d(net, 128, [3, 3], stride=1, scope='conv5')
63 net = tf.image.resize_bilinear(net, [256,256])
64 net = slim.conv2d(net, 64, [5, 5], stride=1, scope='conv6')
65 net = slim.conv2d(net, 3, [5, 5], stride=1, activation_fn=tf.tanh,
66 normalizer_fn=None, scope='conv7')
68 flow = net[:, :, :, 0:2]
69 mask = tf.expand_dims(net[:, :, :, 2], 3)
71 grid_x, grid_y = meshgrid(256, 256)
72 grid_x = tf.tile(grid_x, [32, 1, 1]) # batch_size = 32
73 grid_y = tf.tile(grid_y, [32, 1, 1]) # batch_size = 32
77 coor_x_1 = grid_x + flow[:, :, :, 0]
78 coor_y_1 = grid_y + flow[:, :, :, 1]
80 coor_x_2 = grid_x - flow[:, :, :, 0]
81 coor_y_2 = grid_y - flow[:, :, :, 1]
83 output_1 = bilinear_interp(input_images[:, :, :, 0:3], coor_x_1, coor_y_1, 'interpolate')
84 output_2 = bilinear_interp(input_images[:, :, :, 3:6], coor_x_2, coor_y_2, 'interpolate')
86 mask = 0.5 * (1.0 + mask)
87 mask = tf.tile(mask, [1, 1, 1, 3])
88 net = tf.mul(mask, output_1) + tf.mul(1.0 - mask, output_2)