X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=voxel_flow_model.py;h=2cca42054555f85e98998ca7dc67d339e77ad456;hb=9ca9b0416d25aecce26313f0c9a2a45c61088661;hp=d419acf6a4493cd534c881c60fbee73bd2e57bff;hpb=bc1177902e6abfae7a65cf68b3172ba96c07ce05;p=voxel-flow diff --git a/voxel_flow_model.py b/voxel_flow_model.py index d419acf..2cca420 100755 --- a/voxel_flow_model.py +++ b/voxel_flow_model.py @@ -10,6 +10,8 @@ from utils.geo_layer_utils import vae_gaussian_layer from utils.geo_layer_utils import bilinear_interp from utils.geo_layer_utils import meshgrid +FLAGS = tf.app.flags.FLAGS + class Voxel_flow_model(object): def __init__(self, is_train=True): self.is_train = is_train @@ -69,8 +71,8 @@ class Voxel_flow_model(object): mask = tf.expand_dims(net[:, :, :, 2], 3) grid_x, grid_y = meshgrid(256, 256) - grid_x = tf.tile(grid_x, [32, 1, 1]) # batch_size = 32 - grid_y = tf.tile(grid_y, [32, 1, 1]) # batch_size = 32 + grid_x = tf.tile(grid_x, [FLAGS.batch_size, 1, 1]) + grid_y = tf.tile(grid_y, [FLAGS.batch_size, 1, 1]) flow = 0.5 * flow @@ -85,6 +87,6 @@ class Voxel_flow_model(object): mask = 0.5 * (1.0 + mask) mask = tf.tile(mask, [1, 1, 1, 3]) - net = tf.mul(mask, output_1) + tf.mul(1.0 - mask, output_2) + net = tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2) return net