]> git.sesse.net Git - voxel-flow/blobdiff - voxel_flow_train.py
Update for TensorFlow 1.0 API.
[voxel-flow] / voxel_flow_train.py
index dfe87b198e63a3286bb674361b7608b4509a74d6..f6d3abb67a0f062e6750c0b3bd8471174be1b8f6 100755 (executable)
@@ -68,18 +68,18 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
 
     # Create summaries
     summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
-    summaries.append(tf.scalar_summary('total_loss', total_loss))
-    summaries.append(tf.scalar_summary('reproduction_loss', reproduction_loss))
-    # summaries.append(tf.scalar_summary('prior_loss', prior_loss))
-    summaries.append(tf.image_summary('Input Image', input_placeholder, 3))
-    summaries.append(tf.image_summary('Output Image', prediction, 3))
-    summaries.append(tf.image_summary('Target Image', target_placeholder, 3))
+    summaries.append(tf.summary.scalar('total_loss', total_loss))
+    summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
+    # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
+    summaries.append(tf.summary.image('Input Image', input_placeholder, 3))
+    summaries.append(tf.summary.image('Output Image', prediction, 3))
+    summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
 
     # Create a saver.
     saver = tf.train.Saver(tf.all_variables())
 
     # Build the summary operation from the last tower summaries.
-    summary_op = tf.merge_all_summaries()
+    summary_op = tf.summary.merge_all()
 
     # Build an initialization operation to run below.
     init = tf.initialize_all_variables()
@@ -87,7 +87,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
     sess.run(init)
 
     # Summary Writter
-    summary_writer = tf.train.SummaryWriter(
+    summary_writer = tf.summary.FileWriter(
       FLAGS.train_dir,
       graph=sess.graph)