X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=voxel_flow_train.py;h=19847f16cb56a8e7433f6a7c158410589868c2b8;hb=9ca9b0416d25aecce26313f0c9a2a45c61088661;hp=856ba857070fdd007f977eb7f7d76859f22987e0;hpb=bc1177902e6abfae7a65cf68b3172ba96c07ce05;p=voxel-flow

diff --git a/voxel_flow_train.py b/voxel_flow_train.py
index 856ba85..19847f1 100755
--- a/voxel_flow_train.py
+++ b/voxel_flow_train.py
@@ -39,20 +39,42 @@ tf.app.flags.DEFINE_integer(
 tf.app.flags.DEFINE_float('initial_learning_rate', 0.0003,
                           """Initial learning rate.""")
 
+def _read_image(filename):
+  image_string = tf.read_file(filename)
+  image_decoded = tf.image.decode_image(image_string, channels=3)
+  image_decoded.set_shape([256, 256, 3])
+  return tf.cast(image_decoded, dtype=tf.float32) / 127.5 - 1.0
 
 def train(dataset_frame1, dataset_frame2, dataset_frame3):
   """Trains a model."""
   with tf.Graph().as_default():
+    # Create input.
+    data_list_frame1 = dataset_frame1.read_data_list_file()
+    dataset_frame1 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1))
+    dataset_frame1 = dataset_frame1.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+
+    data_list_frame2 = dataset_frame2.read_data_list_file()
+    dataset_frame2 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2))
+    dataset_frame2 = dataset_frame2.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+
+    data_list_frame3 = dataset_frame3.read_data_list_file()
+    dataset_frame3 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3))
+    dataset_frame3 = dataset_frame3.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
+
+    batch_frame1 = dataset_frame1.batch(FLAGS.batch_size).make_initializable_iterator()
+    batch_frame2 = dataset_frame2.batch(FLAGS.batch_size).make_initializable_iterator()
+    batch_frame3 = dataset_frame3.batch(FLAGS.batch_size).make_initializable_iterator()
+
     # Create input and target placeholder.
-    input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
-    target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))
+    input_placeholder = tf.concat([batch_frame1.get_next(), batch_frame3.get_next()], 3)
+    target_placeholder = batch_frame2.get_next()
 
     # input_resized = tf.image.resize_area(input_placeholder, [128, 128])
     # target_resized = tf.image.resize_area(target_placeholder,[128, 128])
 
     # Prepare model.
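Note on the input pipeline above: the three tf.data pipelines stay aligned because each is shuffled with the same buffer_size and the same explicit seed, so the k-th batch of every iterator refers to the same frame triplet. A minimal sketch of the idea, assuming TF 1.x; the helper name and path lists below are illustrative, not from the repository:

    import tensorflow as tf

    def _aligned_iterator(paths, batch_size, seed=1):
      # Identical op-level seeds give identical shuffle orders across
      # pipelines, which is what keeps the frame triplets in step.
      ds = tf.data.Dataset.from_tensor_slices(tf.constant(paths))
      ds = ds.repeat().shuffle(buffer_size=1000000, seed=seed)
      return ds.batch(batch_size).make_initializable_iterator()

    it1 = _aligned_iterator(['a_1.png', 'b_1.png'], batch_size=2)
    it2 = _aligned_iterator(['a_2.png', 'b_2.png'], batch_size=2)
    it3 = _aligned_iterator(['a_3.png', 'b_3.png'], batch_size=2)

An alternative that makes the coupling explicit would be to tf.data.Dataset.zip the three datasets before shuffling, so alignment no longer depends on matching seeds.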
     model = Voxel_flow_model()
-    prediction = model.inference(input_placeholder) d
+    prediction = model.inference(input_placeholder)
     # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
     reproduction_loss = model.loss(prediction, target_placeholder)
     # total_loss = reproduction_loss + prior_loss
@@ -68,96 +90,64 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3):
 
     # Create summaries
     summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
-    summaries.append(tf.scalar_summary('total_loss', total_loss))
-    summaries.append(tf.scalar_summary('reproduction_loss', reproduction_loss))
-    # summaries.append(tf.scalar_summary('prior_loss', prior_loss))
-    summaries.append(tf.image_summary('Input Image', input_placeholder, 3))
-    summaries.append(tf.image_summary('Output Image', prediction, 3))
-    summaries.append(tf.image_summary('Target Image', target_placeholder, 3))
+    summaries.append(tf.summary.scalar('total_loss', total_loss))
+    summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
+    # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
+    summaries.append(tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3))
+    summaries.append(tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3))
+    summaries.append(tf.summary.image('Output Image', prediction, 3))
+    summaries.append(tf.summary.image('Target Image', target_placeholder, 3))
 
     # Create a saver.
     saver = tf.train.Saver(tf.all_variables())
 
     # Build the summary operation from the last tower summaries.
-    summary_op = tf.merge_all_summaries()
+    summary_op = tf.summary.merge_all()
 
-    # Build an initialization operation to run below.
-    init = tf.initialize_all_variables()
-    sess = tf.Session()
-    sess.run(init)
+    # Restore checkpoint from file.
+    if FLAGS.pretrained_model_checkpoint_path:
+      sess = tf.Session()
+      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
+      ckpt = tf.train.get_checkpoint_state(
+          FLAGS.pretrained_model_checkpoint_path)
+      restorer = tf.train.Saver()
+      restorer.restore(sess, ckpt.model_checkpoint_path)
+      print('%s: Pre-trained model restored from %s' %
+            (datetime.now(), ckpt.model_checkpoint_path))
+    else:
+      # Build an initialization operation to run below.
+      init = tf.initialize_all_variables()
+      sess = tf.Session()
+      sess.run([init, batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer])
 
     # Summary Writer
-    summary_writer = tf.train.SummaryWriter(
+    summary_writer = tf.summary.FileWriter(
       FLAGS.train_dir,
       graph=sess.graph)
 
-    # Training loop using feed dict method.
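Note on the restore branch above: in the committed code the iterator initializers are only run on the else (fresh-initialization) path, so when a pretrained checkpoint is restored, batch_frame*.initializer never runs and the first get_next() would fail. A sketch of a session-setup helper that covers both paths; the helper name and checkpoint_dir parameter are hypothetical, assuming TF 1.x:

    import tensorflow as tf

    def setup_session(iterators, checkpoint_dir=None):
      sess = tf.Session()
      if checkpoint_dir:
        # Restore variables from the latest checkpoint in checkpoint_dir.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
      else:
        sess.run(tf.global_variables_initializer())
      # tf.data iterators are not covered by variable initialization or
      # by checkpoints; they must be initialized in either case.
      sess.run([it.initializer for it in iterators])
      return sess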
-    data_list_frame1 = dataset_frame1.read_data_list_file()
-    random.seed(1)
-    shuffle(data_list_frame1)
-
-    data_list_frame2 = dataset_frame2.read_data_list_file()
-    random.seed(1)
-    shuffle(data_list_frame2)
-
-    data_list_frame3 = dataset_frame3.read_data_list_file()
-    random.seed(1)
-    shuffle(data_list_frame3)
-
     data_size = len(data_list_frame1)
     epoch_num = int(data_size / FLAGS.batch_size)
 
-    # num_workers = 1
-
-    # load_fn_frame1 = partial(dataset_frame1.process_func)
-    # p_queue_frame1 = PrefetchQueue(load_fn_frame1, data_list_frame1, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
-    # load_fn_frame2 = partial(dataset_frame2.process_func)
-    # p_queue_frame2 = PrefetchQueue(load_fn_frame2, data_list_frame2, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
-    # load_fn_frame3 = partial(dataset_frame3.process_func)
-    # p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers)
-
-    for step in xrange(0, FLAGS.max_steps):
+    for step in range(0, FLAGS.max_steps):
       batch_idx = step % epoch_num
 
-      batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-      batch_data_list_frame2 = data_list_frame2[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-      batch_data_list_frame3 = data_list_frame3[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
-
-      # Load batch data.
-      batch_data_frame1 = np.array([dataset_frame1.process_func(line) for line in batch_data_list_frame1])
-      batch_data_frame2 = np.array([dataset_frame2.process_func(line) for line in batch_data_list_frame2])
-      batch_data_frame3 = np.array([dataset_frame3.process_func(line) for line in batch_data_list_frame3])
-
-      # batch_data_frame1 = p_queue_frame1.get_batch()
-      # batch_data_frame2 = p_queue_frame2.get_batch()
-      # batch_data_frame3 = p_queue_frame3.get_batch()
-
-      feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3), target_placeholder: batch_data_frame2}
-
       # Run single step update.
-      _, loss_value = sess.run([update_op, total_loss], feed_dict = feed_dict)
+      _, loss_value = sess.run([update_op, total_loss])
 
       if batch_idx == 0:
-        # Shuffle data at each epoch.
-        random.seed(1)
-        shuffle(data_list_frame1)
-        random.seed(1)
-        shuffle(data_list_frame2)
-        random.seed(1)
-        shuffle(data_list_frame3)
         print('Epoch Number: %d' % int(step / epoch_num))
 
-      # Output Summary
       if step % 10 == 0:
-        # summary_str = sess.run(summary_op, feed_dict = feed_dict)
-        # summary_writer.add_summary(summary_str, step)
-        print("Loss at step %d: %f" % (step, loss_value))
+        print("Loss at step %d: %f" % (step, loss_value))
+
+      if step % 100 == 0:
+        # Output Summary
+        summary_str = sess.run(summary_op)
+        summary_writer.add_summary(summary_str, step)
 
       if step % 500 == 0:
         # Run a batch of images
-        prediction_np, target_np = sess.run([prediction, target_placeholder], feed_dict = feed_dict)
+        prediction_np, target_np = sess.run([prediction, target_placeholder])
 
         for i in range(0,prediction_np.shape[0]):
           file_name = FLAGS.train_image_dir+str(i)+'_out.png'
           file_name_label = FLAGS.train_image_dir+str(i)+'_gt.png'
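Note on the training loop above: with the feed_dict path removed, every sess.run() draws fresh batches from the iterators, so the summary written at step % 100 and the images dumped at step % 500 come from different triplets than the gradient update at the same step. (epoch_num here is really steps per epoch; with the infinitely repeated dataset it only drives the logging cadence.) If the same batch should feed both the update and the summary, the fetches can be combined into a single run() call; an illustrative variant, not the committed behavior:

    fetches = {'train': update_op, 'loss': total_loss}
    if step % 100 == 0:
      fetches['summary'] = summary_op
    results = sess.run(fetches)  # one batch serves all fetches
    if 'summary' in results:
      summary_writer.add_summary(results['summary'], step)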