]> git.sesse.net Git - voxel-flow/blobdiff - voxel_flow_train.py
Visualize the flow in a summary op.
[voxel-flow] / voxel_flow_train.py
index 19847f16cb56a8e7433f6a7c158410589868c2b8..42f91af8b22e6ba0b7fb96888c35c03d0fa62dd6 100755 (executable)
@@ -74,7 +74,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
 
     # Prepare model.
     model = Voxel_flow_model()
-    prediction = model.inference(input_placeholder)
+    prediction, flow = model.inference(input_placeholder)
     # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
     reproduction_loss = model.loss(prediction, target_placeholder)
     # total_loss = reproduction_loss + prior_loss
@@ -97,6 +97,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
     summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3));
     summaries.append(tf.summary.image('Output Image', prediction, 3))
     summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
+    summaries.append(tf.summary.image('Flow', flow, 3))
 
     # Create a saver.
     saver = tf.train.Saver(tf.all_variables())
@@ -176,7 +177,7 @@ def test(dataset_frame1, dataset_frame2, dataset_frame3):
     # target_resized = tf.image.resize_area(target_placeholder,[128, 128])
 
     # Prepare model.
-    model = Voxel_flow_model(is_train=True)
+    model, flow = Voxel_flow_model(is_train=True)
     prediction = model.inference(input_placeholder)
     # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
     reproduction_loss = model.loss(prediction, target_placeholder)