X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=voxel_flow_train.py;h=f6d3abb67a0f062e6750c0b3bd8471174be1b8f6;hb=c74057bc77fb9bc4eb75668896adfb85ddb99251;hp=856ba857070fdd007f977eb7f7d76859f22987e0;hpb=bc1177902e6abfae7a65cf68b3172ba96c07ce05;p=voxel-flow diff --git a/voxel_flow_train.py b/voxel_flow_train.py index 856ba85..f6d3abb 100755 --- a/voxel_flow_train.py +++ b/voxel_flow_train.py @@ -52,7 +52,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): # Prepare model. model = Voxel_flow_model() - prediction = model.inference(input_placeholder) d + prediction = model.inference(input_placeholder) # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder) reproduction_loss = model.loss(prediction, target_placeholder) # total_loss = reproduction_loss + prior_loss @@ -68,18 +68,18 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): # Create summaries summaries = tf.get_collection(tf.GraphKeys.SUMMARIES) - summaries.append(tf.scalar_summary('total_loss', total_loss)) - summaries.append(tf.scalar_summary('reproduction_loss', reproduction_loss)) - # summaries.append(tf.scalar_summary('prior_loss', prior_loss)) - summaries.append(tf.image_summary('Input Image', input_placeholder, 3)) - summaries.append(tf.image_summary('Output Image', prediction, 3)) - summaries.append(tf.image_summary('Target Image', target_placeholder, 3)) + summaries.append(tf.summary.scalar('total_loss', total_loss)) + summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss)) + # summaries.append(tf.summary.scalar('prior_loss', prior_loss)) + summaries.append(tf.summary.image('Input Image', input_placeholder, 3)) + summaries.append(tf.summary.image('Output Image', prediction, 3)) + summaries.append(tf.summary.image('Target Image', target_placeholder, 3)) # Create a saver. saver = tf.train.Saver(tf.all_variables()) # Build the summary operation from the last tower summaries. - summary_op = tf.merge_all_summaries() + summary_op = tf.summary.merge_all() # Build an initialization operation to run below. init = tf.initialize_all_variables() @@ -87,7 +87,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): sess.run(init) # Summary Writter - summary_writer = tf.train.SummaryWriter( + summary_writer = tf.summary.FileWriter( FLAGS.train_dir, graph=sess.graph)