]> git.sesse.net Git - voxel-flow/blobdiff - utils/geo_layer_utils.py
Update for TensorFlow 1.0 API.
[voxel-flow] / utils / geo_layer_utils.py
index 5e4bfaba11bdf394ccc3743f8e4515b2d8c46868..497017b1f52866b6b6140e2b5c9dc41a4eff875c 100755 (executable)
@@ -74,7 +74,7 @@ def bilinear_interp(im, x, y, name):
     idx_d = base_y1 + x1
 
     # Use indices to look up pixels
-    im_flat = tf.reshape(im, tf.pack([-1, channels]))
+    im_flat = tf.reshape(im, tf.stack([-1, channels]))
     im_flat = tf.to_float(im_flat)
     pixel_a = tf.gather(im_flat, idx_a)
     pixel_b = tf.gather(im_flat, idx_b)
@@ -91,7 +91,7 @@ def bilinear_interp(im, x, y, name):
     wd = tf.expand_dims((1.0 - (x1_f - x)) * (1.0 - (y1_f - y)), 1)
 
     output = tf.add_n([wa*pixel_a, wb*pixel_b, wc*pixel_c, wd*pixel_d]) 
-    output = tf.reshape(output, shape=tf.pack([num_batch, height, width, channels]))
+    output = tf.reshape(output, shape=tf.stack([num_batch, height, width, channels]))
     return output
 
 def meshgrid(height, width):
@@ -99,14 +99,14 @@ def meshgrid(height, width):
   """
   with tf.variable_scope('meshgrid'):
     x_t = tf.matmul(
-        tf.ones(shape=tf.pack([height,1])),
+        tf.ones(shape=tf.stack([height,1])),
         tf.transpose(
               tf.expand_dims(
                   tf.linspace(-1.0,1.0,width),1),[1,0]))
     y_t = tf.matmul(
         tf.expand_dims(
             tf.linspace(-1.0, 1.0, height), 1),
-        tf.ones(shape=tf.pack([1, width])))
+        tf.ones(shape=tf.stack([1, width])))
     x_t_flat = tf.reshape(x_t, (1,-1))
     y_t_flat = tf.reshape(y_t, (1,-1))
     # grid_x = tf.reshape(x_t_flat, [1, height, width, 1])
@@ -121,7 +121,7 @@ def vae_gaussian_layer(network, is_train=True, scope='gaussian_layer'):
     z_mean, z_logvar = tf.split(3, 2, network)  # Split into mean and variance
     if is_train:
       eps = tf.random_normal(tf.shape(z_mean))
-      z = tf.add(z_mean, tf.mul(eps, tf.exp(tf.mul(0.5, z_logvar))))
+      z = tf.add(z_mean, tf.multiply(eps, tf.exp(tf.multiply(0.5, z_logvar))))
     else:
       z = z_mean
     return z, z_mean, z_logvar