1 """Implements a voxel flow model."""
2 from __future__ import absolute_import
3 from __future__ import division
4 from __future__ import print_function
6 import tensorflow as tf
7 import tensorflow.contrib.slim as slim
8 from utils.loss_utils import l1_loss, l2_loss, vae_loss
9 from utils.geo_layer_utils import vae_gaussian_layer
10 from utils.geo_layer_utils import bilinear_interp
11 from utils.geo_layer_utils import meshgrid
# Global TensorFlow flags object; FLAGS.batch_size is read when the sampling
# grid is tiled in Voxel_flow_model._build_model.
FLAGS = tf.app.flags.FLAGS
class Voxel_flow_model(object):
    """Voxel flow network: warp two input frames along a predicted flow and blend them.

    A convolutional encoder-decoder predicts, per pixel, a 2-D flow vector and a
    blending mask. The first frame is sampled at ``grid + flow`` and the second
    at ``grid - flow``; the two warped frames are then combined with the mask.
    """

    def __init__(self, is_train=True):
        """Create the model.

        Args:
          is_train: Passed to the batch-normalization layers as ``is_training``.
        """
        self.is_train = is_train

    def inference(self, input_images):
        """Build the inference graph for a batch of frame pairs.

        Args:
          input_images: Rank-4 NHWC tensor; channels 0:3 hold the first frame
            and channels 3:6 the second.

        Returns:
          The two-element list produced by ``_build_model``:
          ``[blended_frames, raw_network_output]``.
        """
        return self._build_model(input_images)

    def loss(self, predictions, targets):
        """Compute the training loss.

        Args:
          predictions: Predicted tensor.
          targets: Target tensor compared against ``predictions``.

        Returns:
          The L1 loss between ``predictions`` and ``targets``. The same value is
          also stored on ``self.reproduction_loss``.
        """
        # Only the L1 reconstruction term is active. The original kept an L2
        # variant and a VAE prior term as commented-out alternatives.
        self.reproduction_loss = l1_loss(predictions, targets)
        return self.reproduction_loss

    def _build_model(self, input_images):
        """Predict voxel flow and synthesize a frame from two input frames.

        Args:
          input_images: Rank-4 NHWC tensor; channels 0:3 are the first frame and
            channels 3:6 the second. NOTE(review): the decoder always upsamples
            to 256x256 and the sampling grid is 256x256, so inputs are presumably
            256x256. Confirm with callers.

        Returns:
          ``[net, net_copy]``, where:
          - ``net`` is the mask-weighted blend of the two warped frames.
          - ``net_copy`` is the raw 3-channel network output: channels 0:2 are
            the flow and channel 2 is the mask, all in [-1, 1] from the final tanh.

        NOTE(review): the recovered source returned ``net_copy`` without ever
        assigning it. It is taken here to be the raw conv7 output, the only
        value ``net`` held before being overwritten by the blend. Confirm
        against the upstream source.
        """
        net_copy = self._predict_flow_and_mask(input_images)
        net = self._warp_and_blend(input_images, net_copy)
        return [net, net_copy]

    def _predict_flow_and_mask(self, input_images):
        """Run the conv encoder-decoder and return its 3-channel tanh output."""
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                            weights_regularizer=slim.l2_regularizer(0.0001)):
            # NOTE(review): the recovered copy of this file has gaps around this
            # dict; only the 'is_training' entry survived. Confirm the full set of
            # batch-norm parameters against the upstream source.
            batch_norm_params = {
                'is_training': self.is_train,
            }
            # updates_collections=None makes batch norm update its moving
            # statistics in place instead of via the UPDATE_OPS collection.
            with slim.arg_scope([slim.batch_norm], is_training=self.is_train,
                                updates_collections=None):
                with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
                                    normalizer_params=batch_norm_params):
                    # Encoder: three conv + 2x2 max-pool stages.
                    net = slim.conv2d(input_images, 64, [5, 5], stride=1, scope='conv1')
                    net = slim.max_pool2d(net, [2, 2], scope='pool1')
                    net = slim.conv2d(net, 128, [5, 5], stride=1, scope='conv2')
                    net = slim.max_pool2d(net, [2, 2], scope='pool2')
                    net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv3')
                    net = slim.max_pool2d(net, [2, 2], scope='pool3')
                    # Decoder: bilinear upsampling to fixed 64, 128 and 256
                    # resolutions, each followed by a conv.
                    net = tf.image.resize_bilinear(net, [64, 64])
                    net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv4')
                    net = tf.image.resize_bilinear(net, [128, 128])
                    net = slim.conv2d(net, 128, [3, 3], stride=1, scope='conv5')
                    net = tf.image.resize_bilinear(net, [256, 256])
                    net = slim.conv2d(net, 64, [5, 5], stride=1, scope='conv6')
                    # Output layer: no batch norm; tanh bounds all 3 channels
                    # (2 flow + 1 mask) to [-1, 1].
                    net = slim.conv2d(net, 3, [5, 5], stride=1, activation_fn=tf.tanh,
                                      normalizer_fn=None, scope='conv7')
        return net

    def _warp_and_blend(self, input_images, net):
        """Warp both input frames along the predicted flow and blend them by the mask."""
        flow = net[:, :, :, 0:2]
        mask = tf.expand_dims(net[:, :, :, 2], 3)

        # Rank-3 identity sampling grid replicated across the batch. The batch
        # size comes from FLAGS, so the batch dimension of input_images is
        # expected to equal FLAGS.batch_size.
        grid_x, grid_y = meshgrid(256, 256)
        grid_x = tf.tile(grid_x, [FLAGS.batch_size, 1, 1])
        grid_y = tf.tile(grid_y, [FLAGS.batch_size, 1, 1])

        # NOTE(review): the recovered source has a gap here. Confirm that no
        # statement was lost between tiling the grid and computing the sampling
        # coordinates.

        # Symmetric sampling: the first frame at +flow, the second at -flow.
        coor_x_1 = grid_x + flow[:, :, :, 0]
        coor_y_1 = grid_y + flow[:, :, :, 1]
        coor_x_2 = grid_x - flow[:, :, :, 0]
        coor_y_2 = grid_y - flow[:, :, :, 1]

        output_1 = bilinear_interp(input_images[:, :, :, 0:3], coor_x_1, coor_y_1, 'interpolate')
        output_2 = bilinear_interp(input_images[:, :, :, 3:6], coor_x_2, coor_y_2, 'interpolate')

        # Map the tanh output from [-1, 1] to a [0, 1] blend weight, then repeat
        # it across the three color channels.
        mask = 0.5 * (1.0 + mask)
        mask = tf.tile(mask, [1, 1, 1, 3])
        return tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2)