X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=voxel_flow_train.py;fp=voxel_flow_train.py;h=42f91af8b22e6ba0b7fb96888c35c03d0fa62dd6;hb=e80cd1567ea62da68f3483f8f616d148478adb13;hp=19847f16cb56a8e7433f6a7c158410589868c2b8;hpb=9ca9b0416d25aecce26313f0c9a2a45c61088661;p=voxel-flow diff --git a/voxel_flow_train.py b/voxel_flow_train.py index 19847f1..42f91af 100755 --- a/voxel_flow_train.py +++ b/voxel_flow_train.py @@ -74,7 +74,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): # Prepare model. model = Voxel_flow_model() - prediction = model.inference(input_placeholder) + prediction, flow = model.inference(input_placeholder) # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder) reproduction_loss = model.loss(prediction, target_placeholder) # total_loss = reproduction_loss + prior_loss @@ -97,6 +97,7 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3)); summaries.append(tf.summary.image('Output Image', prediction, 3)) summaries.append(tf.summary.image('Target Image', target_placeholder, 3)) + summaries.append(tf.summary.image('Flow', flow, 3)) # Create a saver. saver = tf.train.Saver(tf.all_variables()) @@ -176,7 +177,7 @@ def test(dataset_frame1, dataset_frame2, dataset_frame3): # target_resized = tf.image.resize_area(target_placeholder,[128, 128]) # Prepare model. - model = Voxel_flow_model(is_train=True) + model, flow = Voxel_flow_model(is_train=True) prediction = model.inference(input_placeholder) # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder) reproduction_loss = model.loss(prediction, target_placeholder)