]> git.sesse.net Git - fjl/blobdiff - bytesource.c
Fix a bug where we could return too much data from the byte source. Add test.
[fjl] / bytesource.c
index 25cea227e399cce151228489642004e6ad49fe7d..fed7551d50721c1e06ed2057c3a23e4dc83c069a 100644 (file)
@@ -1,5 +1,3 @@
-#include <stdio.h>
-#include <stdlib.h>
 #include <stdbool.h>
 #include <string.h>
 #include <assert.h>
@@ -19,16 +17,33 @@ void init_byte_source(struct byte_source* source, raw_input_func_t* input_func,
        source->userdata = userdata;
 }
 
-uint8_t byte_source_read_marker(struct byte_source* source)
+uint8_t byte_source_read_marker(struct byte_source* src)
 {
-       assert(source->bytes_available >= 2);
-       assert(source->bytes[0] == MARKER_CHAR);
-       assert(source->bytes[1] != STUFF_MARKER);
+       // Refill until we have at least two bytes or EOF.
+       while (src->bytes_available < 2) {
+               const unsigned bytes_to_read = 2 - src->bytes_available;
+               const ssize_t bytes_read =
+                       (*src->input_func)(src->userdata,
+                                          src->bytes + src->bytes_available,
+                                          bytes_to_read);
+               assert(bytes_read >= -1);
+               assert(bytes_read <= (ssize_t)bytes_to_read);
+               
+               if (bytes_read == -1 || bytes_read == 0) {
+                       return 0x00;
+               }
 
-       uint8_t ret = source->bytes[1];
+               src->bytes_available += bytes_read;
+       }
+
+       assert(src->bytes_available >= 2);
+       if (src->bytes[0] != MARKER_CHAR || src->bytes[1] == STUFF_MARKER) {
+               return 0x00;
+       }
 
-       memmove(source->bytes, source->bytes + 2, source->bytes_available - 2);
-       source->bytes_available -= 2;
+       uint8_t ret = src->bytes[1];
+       memmove(src->bytes, src->bytes + 2, src->bytes_available - 2);
+       src->bytes_available -= 2;
 
        return ret;
 }
@@ -42,14 +57,15 @@ ssize_t byte_source_input_func(void* source, uint8_t* buf, size_t len)
        while (src->bytes_available == 0 ||
               (src->bytes_available == 1 && src->bytes[0] == MARKER_CHAR)) {
                const unsigned space_left = BYTESOURCE_CHUNK_SIZE - src->bytes_available;
-               const size_t bytes_to_read = (len > space_left ? space_left : len);
+               const unsigned missing_data = len - src->bytes_available;
+               const size_t bytes_to_read = (missing_data > space_left ? space_left : missing_data);
                assert(bytes_to_read <= BYTESOURCE_CHUNK_SIZE);
                const ssize_t bytes_read =
                        (*src->input_func)(src->userdata,
-                                             src->bytes + src->bytes_available,
-                                             bytes_to_read);
+                                          src->bytes + src->bytes_available,
+                                          bytes_to_read);
                assert(bytes_read >= -1);
-               assert(bytes_read <= bytes_to_read);
+               assert(bytes_read <= (ssize_t)bytes_to_read);
                
                if (bytes_read == -1) {
                        return -1;