]> git.sesse.net Git - voxel-flow/commitdiff
Unbreak non-default batch sizes.
authorSteinar H. Gunderson <sgunderson@bigfoot.com>
Thu, 8 Feb 2018 18:52:37 +0000 (19:52 +0100)
committerSteinar H. Gunderson <sgunderson@bigfoot.com>
Thu, 8 Feb 2018 18:52:37 +0000 (19:52 +0100)
voxel_flow_model.py

index b3f554d59f935ef327a8a1b86d936461988cb779..2cca42054555f85e98998ca7dc67d339e77ad456 100755 (executable)
@@ -10,6 +10,8 @@ from utils.geo_layer_utils import vae_gaussian_layer
 from utils.geo_layer_utils import bilinear_interp
 from utils.geo_layer_utils import meshgrid
 
+FLAGS = tf.app.flags.FLAGS
+
 class Voxel_flow_model(object):
   def __init__(self, is_train=True):
     self.is_train = is_train
@@ -69,8 +71,8 @@ class Voxel_flow_model(object):
     mask = tf.expand_dims(net[:, :, :, 2], 3)
 
     grid_x, grid_y = meshgrid(256, 256)
-    grid_x = tf.tile(grid_x, [32, 1, 1]) # batch_size = 32
-    grid_y = tf.tile(grid_y, [32, 1, 1]) # batch_size = 32
+    grid_x = tf.tile(grid_x, [FLAGS.batch_size, 1, 1])
+    grid_y = tf.tile(grid_y, [FLAGS.batch_size, 1, 1])
 
     flow = 0.5 * flow