]> git.sesse.net Git - voxel-flow/blob - utils/geo_layer_utils.py
Update for TensorFlow 1.0 API.
[voxel-flow] / utils / geo_layer_utils.py
1 """Implements various custom tensorflow layer.
2 """
3
4 from __future__ import absolute_import
5 from __future__ import division
6 from __future__ import print_function
7
8 import tensorflow as tf
9 import tensorflow.contrib.slim as slim
10
11 def bilinear_interp(im, x, y, name):
12   """Perform bilinear sampling on im given x, y coordinates
13   
14   This function implements the differentiable sampling mechanism with
15   bilinear kernel. Introduced in https://arxiv.org/abs/1506.02025, equation
16   (5).
17  
18   x,y are tensors specfying normalized coorindates [-1,1] to sample from im.
19   (-1,1) means (0,0) coordinate in im. (1,1) means the most bottom right pixel.
20
21   Args:
22     im: Tensor of size [batch_size, height, width, depth] 
23     x: Tensor of size [batch_size, height, width, 1]
24     y: Tensor of size [batch_size, height, width, 1]
25     name: String for the name for this opt.
26   Returns:
27     Tensor of size [batch_size, height, width, depth]
28   """
29   with tf.variable_scope(name):
30     x = tf.reshape(x, [-1]) 
31     y = tf.reshape(y, [-1]) 
32
33     # constants
34     num_batch = tf.shape(im)[0]
35     _, height, width, channels = im.get_shape().as_list()
36
37     x = tf.to_float(x)
38     y = tf.to_float(y)
39
40     height_f = tf.cast(height, 'float32')
41     width_f = tf.cast(width, 'float32')
42     zero = tf.constant(0, dtype=tf.int32)
43
44     max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
45     max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
46     x = (x + 1.0) * (width_f - 1.0) / 2.0
47     y = (y + 1.0) * (height_f - 1.0) / 2.0
48
49     # Sampling
50     x0 = tf.cast(tf.floor(x), 'int32')
51     x1 = x0 + 1
52     y0 = tf.cast(tf.floor(y), 'int32')
53     y1 = y0 + 1
54
55     x0 = tf.clip_by_value(x0, zero, max_x)
56     x1 = tf.clip_by_value(x1, zero, max_x)
57     y0 = tf.clip_by_value(y0, zero, max_y)
58     y1 = tf.clip_by_value(y1, zero, max_y)
59
60     dim2 = width 
61     dim1 = width * height
62
63     # Create base index
64     base = tf.range(num_batch) * dim1
65     base = tf.reshape(base, [-1, 1])
66     base = tf.tile(base, [1, height * width])
67     base = tf.reshape(base, [-1])
68
69     base_y0 = base + y0 * dim2 
70     base_y1 = base + y1 * dim2 
71     idx_a = base_y0 + x0 
72     idx_b = base_y1 + x0
73     idx_c = base_y0 + x1
74     idx_d = base_y1 + x1
75
76     # Use indices to look up pixels
77     im_flat = tf.reshape(im, tf.stack([-1, channels]))
78     im_flat = tf.to_float(im_flat)
79     pixel_a = tf.gather(im_flat, idx_a)
80     pixel_b = tf.gather(im_flat, idx_b)
81     pixel_c = tf.gather(im_flat, idx_c)
82     pixel_d = tf.gather(im_flat, idx_d)
83
84     # Interpolate the values 
85     x1_f = tf.to_float(x1)
86     y1_f = tf.to_float(y1)
87
88     wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1)
89     wb = tf.expand_dims((x1_f - x) * (1.0 - (y1_f - y)), 1)
90     wc = tf.expand_dims((1.0 - (x1_f - x)) * (y1_f - y), 1)
91     wd = tf.expand_dims((1.0 - (x1_f - x)) * (1.0 - (y1_f - y)), 1)
92
93     output = tf.add_n([wa*pixel_a, wb*pixel_b, wc*pixel_c, wd*pixel_d]) 
94     output = tf.reshape(output, shape=tf.stack([num_batch, height, width, channels]))
95     return output
96
97 def meshgrid(height, width):
98   """Tensorflow meshgrid function.
99   """
100   with tf.variable_scope('meshgrid'):
101     x_t = tf.matmul(
102         tf.ones(shape=tf.stack([height,1])),
103         tf.transpose(
104               tf.expand_dims(
105                   tf.linspace(-1.0,1.0,width),1),[1,0]))
106     y_t = tf.matmul(
107         tf.expand_dims(
108             tf.linspace(-1.0, 1.0, height), 1),
109         tf.ones(shape=tf.stack([1, width])))
110     x_t_flat = tf.reshape(x_t, (1,-1))
111     y_t_flat = tf.reshape(y_t, (1,-1))
112     # grid_x = tf.reshape(x_t_flat, [1, height, width, 1])
113     # grid_y = tf.reshape(y_t_flat, [1, height, width, 1])
114     grid_x = tf.reshape(x_t_flat, [1, height, width])
115     grid_y = tf.reshape(y_t_flat, [1, height, width])
116     return grid_x, grid_y 
117
118 def vae_gaussian_layer(network, is_train=True, scope='gaussian_layer'):
119   """Implements a gaussian reparameterization vae layer"""
120   with tf.variable_scope(scope):
121     z_mean, z_logvar = tf.split(3, 2, network)  # Split into mean and variance
122     if is_train:
123       eps = tf.random_normal(tf.shape(z_mean))
124       z = tf.add(z_mean, tf.multiply(eps, tf.exp(tf.multiply(0.5, z_logvar))))
125     else:
126       z = z_mean
127     return z, z_mean, z_logvar