From: Steinar H. Gunderson Date: Thu, 8 Feb 2018 18:48:31 +0000 (+0100) Subject: Update for TensorFlow 1.0 API. X-Git-Url: https://git.sesse.net/?p=voxel-flow;a=commitdiff_plain;h=c74057bc77fb9bc4eb75668896adfb85ddb99251 Update for TensorFlow 1.0 API. --- diff --git a/utils/geo_layer_utils.py b/utils/geo_layer_utils.py index 5e4bfab..497017b 100755 --- a/utils/geo_layer_utils.py +++ b/utils/geo_layer_utils.py @@ -74,7 +74,7 @@ def bilinear_interp(im, x, y, name): idx_d = base_y1 + x1 # Use indices to look up pixels - im_flat = tf.reshape(im, tf.pack([-1, channels])) + im_flat = tf.reshape(im, tf.stack([-1, channels])) im_flat = tf.to_float(im_flat) pixel_a = tf.gather(im_flat, idx_a) pixel_b = tf.gather(im_flat, idx_b) @@ -91,7 +91,7 @@ def bilinear_interp(im, x, y, name): wd = tf.expand_dims((1.0 - (x1_f - x)) * (1.0 - (y1_f - y)), 1) output = tf.add_n([wa*pixel_a, wb*pixel_b, wc*pixel_c, wd*pixel_d]) - output = tf.reshape(output, shape=tf.pack([num_batch, height, width, channels])) + output = tf.reshape(output, shape=tf.stack([num_batch, height, width, channels])) return output def meshgrid(height, width): @@ -99,14 +99,14 @@ def meshgrid(height, width): """ with tf.variable_scope('meshgrid'): x_t = tf.matmul( - tf.ones(shape=tf.pack([height,1])), + tf.ones(shape=tf.stack([height,1])), tf.transpose( tf.expand_dims( tf.linspace(-1.0,1.0,width),1),[1,0])) y_t = tf.matmul( tf.expand_dims( tf.linspace(-1.0, 1.0, height), 1), - tf.ones(shape=tf.pack([1, width]))) + tf.ones(shape=tf.stack([1, width]))) x_t_flat = tf.reshape(x_t, (1,-1)) y_t_flat = tf.reshape(y_t, (1,-1)) # grid_x = tf.reshape(x_t_flat, [1, height, width, 1]) @@ -121,7 +121,7 @@ def vae_gaussian_layer(network, is_train=True, scope='gaussian_layer'): z_mean, z_logvar = tf.split(3, 2, network) # Split into mean and variance if is_train: eps = tf.random_normal(tf.shape(z_mean)) - z = tf.add(z_mean, tf.mul(eps, tf.exp(tf.mul(0.5, z_logvar)))) + z = tf.add(z_mean, tf.multiply(eps, tf.exp(tf.multiply(0.5, z_logvar)))) else: z = z_mean return z, z_mean, z_logvar diff --git a/voxel_flow_model.py b/voxel_flow_model.py index d419acf..b3f554d 100755 --- a/voxel_flow_model.py +++ b/voxel_flow_model.py @@ -85,6 +85,6 @@ class Voxel_flow_model(object): mask = 0.5 * (1.0 + mask) mask = tf.tile(mask, [1, 1, 1, 3]) - net = tf.mul(mask, output_1) + tf.mul(1.0 - mask, output_2) + net = tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2) return net diff --git a/voxel_flow_train.py b/voxel_flow_train.py index dfe87b1..f6d3abb 100755 --- a/voxel_flow_train.py +++ b/voxel_flow_train.py @@ -68,18 +68,18 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): # Create summaries summaries = tf.get_collection(tf.GraphKeys.SUMMARIES) - summaries.append(tf.scalar_summary('total_loss', total_loss)) - summaries.append(tf.scalar_summary('reproduction_loss', reproduction_loss)) - # summaries.append(tf.scalar_summary('prior_loss', prior_loss)) - summaries.append(tf.image_summary('Input Image', input_placeholder, 3)) - summaries.append(tf.image_summary('Output Image', prediction, 3)) - summaries.append(tf.image_summary('Target Image', target_placeholder, 3)) + summaries.append(tf.summary.scalar('total_loss', total_loss)) + summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss)) + # summaries.append(tf.summary.scalar('prior_loss', prior_loss)) + summaries.append(tf.summary.image('Input Image', input_placeholder, 3)) + summaries.append(tf.summary.image('Output Image', prediction, 3)) + summaries.append(tf.summary.image('Target Image', target_placeholder, 3)) # Create a saver. saver = tf.train.Saver(tf.all_variables()) # Build the summary operation from the last tower summaries. - summary_op = tf.merge_all_summaries() + summary_op = tf.summary.merge_all() # Build an initialization operation to run below. init = tf.initialize_all_variables() @@ -87,7 +87,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): sess.run(init) # Summary Writter - summary_writer = tf.train.SummaryWriter( + summary_writer = tf.summary.FileWriter( FLAGS.train_dir, graph=sess.graph)