"""Implements various custom TensorFlow layers."""
4 from __future__ import absolute_import
5 from __future__ import division
6 from __future__ import print_function
8 import tensorflow as tf
9 import tensorflow.contrib.slim as slim
def bilinear_interp(im, x, y, name):
    """Perform bilinear sampling on im given x, y coordinates.

    This function implements the differentiable sampling mechanism with a
    bilinear kernel, introduced in https://arxiv.org/abs/1506.02025
    (equation 5).

    x, y are tensors specifying normalized coordinates in [-1, 1] to sample
    from im. (-1, -1) maps to pixel (0, 0) of im (top-left) and (1, 1) maps
    to the bottom-right pixel. Neighbour indices that fall outside the image
    are clamped to the border, so out-of-range coordinates replicate the
    nearest edge pixel.

    The static height, width and depth of im must be known, because they are
    used as Python integers to build the flat gather indices. The output has
    the same spatial size as im, so x and y must each hold
    batch_size * height * width values.

    Args:
      im: Tensor of size [batch_size, height, width, depth].
      x: Tensor of size [batch_size, height, width, 1].
      y: Tensor of size [batch_size, height, width, 1].
      name: String for the name of this op (used as the variable scope).

    Returns:
      Float tensor of size [batch_size, height, width, depth].
    """
    with tf.variable_scope(name):
        x = tf.reshape(x, [-1])
        y = tf.reshape(y, [-1])

        # Constants.
        num_batch = tf.shape(im)[0]
        _, height, width, channels = im.get_shape().as_list()

        height_f = tf.cast(height, 'float32')
        width_f = tf.cast(width, 'float32')
        zero = tf.constant(0, dtype=tf.int32)

        max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
        max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')

        # Map normalized [-1, 1] coordinates to pixel coordinates
        # [0, width - 1] and [0, height - 1].
        x = (x + 1.0) * (width_f - 1.0) / 2.0
        y = (y + 1.0) * (height_f - 1.0) / 2.0

        # Integer corners of the cell surrounding each sample point.
        x0 = tf.cast(tf.floor(x), 'int32')
        x1 = x0 + 1
        y0 = tf.cast(tf.floor(y), 'int32')
        y1 = y0 + 1

        # Clamp the corners to the image so the gathers stay in bounds.
        x0 = tf.clip_by_value(x0, zero, max_x)
        x1 = tf.clip_by_value(x1, zero, max_x)
        y0 = tf.clip_by_value(y0, zero, max_y)
        y1 = tf.clip_by_value(y1, zero, max_y)

        # Flat index into im reshaped to [batch * height * width, channels]:
        # index = batch * (width * height) + row * width + column.
        dim2 = width
        dim1 = width * height
        base = tf.range(num_batch) * dim1
        base = tf.reshape(base, [-1, 1])
        base = tf.tile(base, [1, height * width])
        base = tf.reshape(base, [-1])

        base_y0 = base + y0 * dim2
        base_y1 = base + y1 * dim2
        idx_a = base_y0 + x0  # (x0, y0)
        idx_b = base_y1 + x0  # (x0, y1)
        idx_c = base_y0 + x1  # (x1, y0)
        idx_d = base_y1 + x1  # (x1, y1)

        # Use indices to look up pixels.
        im_flat = tf.reshape(im, tf.stack([-1, channels]))
        im_flat = tf.to_float(im_flat)
        pixel_a = tf.gather(im_flat, idx_a)
        pixel_b = tf.gather(im_flat, idx_b)
        pixel_c = tf.gather(im_flat, idx_c)
        pixel_d = tf.gather(im_flat, idx_d)

        # Interpolate the values. The weights are expressed relative to the
        # (clamped) x1/y1 so that each x-pair and y-pair always sums to one,
        # even when a corner has been clamped to the border.
        x1_f = tf.to_float(x1)
        y1_f = tf.to_float(y1)

        wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1)
        wb = tf.expand_dims((x1_f - x) * (1.0 - (y1_f - y)), 1)
        wc = tf.expand_dims((1.0 - (x1_f - x)) * (y1_f - y), 1)
        wd = tf.expand_dims((1.0 - (x1_f - x)) * (1.0 - (y1_f - y)), 1)

        output = tf.add_n(
            [wa * pixel_a, wb * pixel_b, wc * pixel_c, wd * pixel_d])
        output = tf.reshape(
            output, shape=tf.stack([num_batch, height, width, channels]))
        return output
def meshgrid(height, width):
    """Build normalized x/y coordinate grids spanning [-1, 1].

    Args:
      height: Number of rows in the grid.
      width: Number of columns in the grid.

    Returns:
      A tuple (grid_x, grid_y) of float tensors, each of size
      [1, height, width]. grid_x runs linearly from -1 to 1 along the width
      axis and is constant along the height axis; grid_y runs linearly from
      -1 to 1 along the height axis and is constant along the width axis.
    """
    with tf.variable_scope('meshgrid'):
        # Outer product of a column of ones with the row of x coordinates:
        # every row of x_t is linspace(-1, 1, width).
        x_t = tf.matmul(
            tf.ones(shape=tf.stack([height, 1])),
            tf.transpose(
                tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
        # Outer product of the column of y coordinates with a row of ones:
        # every column of y_t is linspace(-1, 1, height).
        y_t = tf.matmul(
            tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
            tf.ones(shape=tf.stack([1, width])))

        # Add the leading singleton batch axis.
        grid_x = tf.reshape(x_t, [1, height, width])
        grid_y = tf.reshape(y_t, [1, height, width])
        return grid_x, grid_y
def vae_gaussian_layer(network, is_train=True, scope='gaussian_layer'):
    """Implements a gaussian reparameterization vae layer.

    The input is split into two equal halves along axis 3: the first half is
    the mean and the second half the log-variance of a diagonal Gaussian.

    Args:
      network: Tensor whose axis 3 has an even size; it is split along that
        axis into (z_mean, z_logvar).
      is_train: Python bool. When True, z is sampled as
        z_mean + eps * exp(0.5 * z_logvar) with eps ~ N(0, I). When False,
        z is the mean, so the layer is deterministic.
      scope: Variable scope name for the ops created here.

    Returns:
      A tuple (z, z_mean, z_logvar) of tensors with identical shapes.
    """
    with tf.variable_scope(scope):
        # TF >= 1.0 argument order is (value, num_or_size_splits, axis); the
        # pre-1.0 order (axis, num_split, value) fails on current releases.
        z_mean, z_logvar = tf.split(
            value=network, num_or_size_splits=2, axis=3)

        if is_train:
            # Reparameterization trick: the noise is drawn independently of
            # the parameters, so gradients flow through z_mean and z_logvar.
            eps = tf.random_normal(tf.shape(z_mean))
            z = tf.add(
                z_mean,
                tf.multiply(eps, tf.exp(tf.multiply(0.5, z_logvar))))
        else:
            z = z_mean

        return z, z_mean, z_logvar