X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=voxel_flow_train.py;fp=voxel_flow_train.py;h=e1fa96f2c3102f226148b18283eb1f5855cec4eb;hb=e46b852bfceb21baf504f2dbe39cf1fd31f47cc3;hp=6b2bcb440a7598a5cdda25ca76ab8a1e70b4b44b;hpb=d55e5f71d6a84b7f39ab143bc5b56bb24d35cba1;p=voxel-flow diff --git a/voxel_flow_train.py b/voxel_flow_train.py index 6b2bcb4..e1fa96f 100755 --- a/voxel_flow_train.py +++ b/voxel_flow_train.py @@ -82,10 +82,21 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): # Build the summary operation from the last tower summaries. summary_op = tf.summary.merge_all() - # Build an initialization operation to run below. - init = tf.initialize_all_variables() - sess = tf.Session() - sess.run(init) + # Restore checkpoint from file. + if FLAGS.pretrained_model_checkpoint_path: + sess = tf.Session() + assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path) + ckpt = tf.train.get_checkpoint_state( + FLAGS.pretrained_model_checkpoint_path) + restorer = tf.train.Saver() + restorer.restore(sess, ckpt.model_checkpoint_path) + print('%s: Pre-trained model restored from %s' % + (datetime.now(), ckpt.model_checkpoint_path)) + else: + # Build an initialization operation to run below. + init = tf.initialize_all_variables() + sess = tf.Session() + sess.run(init) # Summary Writter summary_writer = tf.summary.FileWriter(