X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=voxel_flow_train.py;fp=voxel_flow_train.py;h=736255b184f41909de9f3c071acf6027ec1e8823;hb=a36660b4e7c0ddb51c0aa5ad4b2a9aa6cf483349;hp=7d39968fa5d418d0a357862ebce1ce55d1b7c47b;hpb=bb5e196c9163ea49124ef98725c6fac928c67f4b;p=voxel-flow diff --git a/voxel_flow_train.py b/voxel_flow_train.py index 7d39968..736255b 100755 --- a/voxel_flow_train.py +++ b/voxel_flow_train.py @@ -52,14 +52,17 @@ def train(dataset_frame1, dataset_frame2, dataset_frame3): data_list_frame1 = dataset_frame1.read_data_list_file() dataset_frame1 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame1)) dataset_frame1 = dataset_frame1.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image) + dataset_frame1 = dataset_frame1.prefetch(100) data_list_frame2 = dataset_frame2.read_data_list_file() dataset_frame2 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame2)) dataset_frame2 = dataset_frame2.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image) + dataset_frame2 = dataset_frame2.prefetch(100) data_list_frame3 = dataset_frame3.read_data_list_file() dataset_frame3 = tf.data.Dataset.from_tensor_slices(tf.constant(data_list_frame3)) dataset_frame3 = dataset_frame3.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image) + dataset_frame3 = dataset_frame3.prefetch(100) batch_frame1 = dataset_frame1.batch(FLAGS.batch_size).make_initializable_iterator() batch_frame2 = dataset_frame2.batch(FLAGS.batch_size).make_initializable_iterator()