Added functions for bit-level reading.
authorSteinar H. Gunderson <sesse@debian.org>
Fri, 2 Jan 2009 15:45:04 +0000 (16:45 +0100)
committerSteinar H. Gunderson <sesse@debian.org>
Fri, 2 Jan 2009 15:45:04 +0000 (16:45 +0100)
input.c [new file with mode: 0644]
input.h [new file with mode: 0644]
input_test.c [new file with mode: 0644]

diff --git a/input.c b/input.c
new file mode 100644 (file)
index 0000000..b7f7039
--- /dev/null
+++ b/input.c
@@ -0,0 +1,55 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "input.h"
+
+void init_data_source(struct data_source* source, input_func_t* input_func, void* userdata)
+{
+       memset(source, 0, sizeof(*source));
+       source->bytes = (uint8_t*)malloc(bytereservoir_size);
+       source->input_func = input_func;
+       source->userdata = userdata;
+}
+
+void possibly_refill_slow_path(struct data_source* source, unsigned num_bits)
+{
+       // First, make sure there's stuff in the byte reservoir if we can.
+       assert(source->bytes_available <= bytereservoir_size);
+
+       // Read data from the source until we have enough to satisfy the request.
+       while (source->bits_available + 8 * source->bytes_available < num_bits) {
+               const size_t bytes_to_read = bytereservoir_size - source->bytes_available;
+               const ssize_t bytes_read =
+                       (*source->input_func)(source->userdata,
+                                             source->bytes + source->bytes_available,
+                                             bytes_to_read);
+               assert(bytes_read <= bytes_to_read);
+               assert(bytes_read >= (ssize_t)-1);
+
+               // TODO: We need better error handling here. setjmp()/longjmp()
+               // should hopefully do the trick, but we need to take care for
+               // suspension.
+               if (bytes_read == (ssize_t)-1) {
+                       fprintf(stderr, "Input function returned error\n");
+                       exit(1);
+               }
+               if (bytes_read == 0) {
+                       fprintf(stderr, "Premature EOF\n");
+                       exit(1);
+               }
+               
+               source->bytes_available += bytes_read;
+       }
+
+       // Fill the bit reservoir one by one byte until we have enough.
+       while (source->bits_available < num_bits) {
+               assert(source->bytes_available > 0);
+               assert(source->bits_available + 8 <= bitreservoir_size);
+               uint8_t byte = *(source->bytes);
+               ++source->bytes;
+               --source->bytes_available;
+               source->bits |= ((bitreservoir_t)byte << (bitreservoir_size - source->bits_available - 8));
+               source->bits_available += 8;
+       }
+}
diff --git a/input.h b/input.h
new file mode 100644 (file)
index 0000000..28c5f0b
--- /dev/null
+++ b/input.h
@@ -0,0 +1,88 @@
+#ifndef _INPUT_H
+#define _INPUT_H 1
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <sys/types.h>
+#include <arpa/inet.h>
+
+// Optimize for 64 bits. We might want to replace this for 32-bit machines
+// (benchmark later).
+typedef uint64_t bitreservoir_t;
+typedef uint32_t bitreservoir_fill_t;
+
+static inline bitreservoir_fill_t read_bitreservoir_fill(uint8_t* source)
+{
+       return ntohl(*(bitreservoir_fill_t*)(source));
+}
+
+static const unsigned bitreservoir_size = 8 * sizeof(bitreservoir_t);
+static const unsigned bitreservoir_fill_size = 8 * sizeof(bitreservoir_fill_t);
+static const unsigned bytereservoir_size = 4096;
+
+// A function to read bytes from some input source.
+// A return value of -1 indicates error, a return value of 0 indicates EOF.
+typedef ssize_t (input_func_t)(void*, uint8_t*, size_t);
+
+struct data_source {
+       // Short-term bit reservoir; holds up to 64 bits. When it's empty,
+       // it needs to get refilled from the medium-term bit reservoir.
+       bitreservoir_t bits;
+       unsigned bits_available;
+       
+       // Medium-term bit reservoir; holds a few kilobytes of spare data.
+       // When this is empty, it needs to be refilled from the input
+       // stream.
+       uint8_t* bytes;
+       unsigned bytes_available;
+
+       // Data source.
+       input_func_t* input_func;
+       void* userdata;
+};
+       
+void init_data_source(struct data_source* source, input_func_t* input_func, void* userdata);
+
+// Internal function. Do not use.
+void possibly_refill_slow_path(struct data_source* source, unsigned num_bits);
+
+// Make sure there's at least NUM_BITS available in the short-term bit reservoir.
+// You usually want to call this before read_bits(). The reason it's separate
+// is that if you want two reads and you know the size of both, it's faster to
+// refill A+B, read A, read B than refill A, read A, refill B, read B.
+static inline void possibly_refill(struct data_source* source, unsigned num_bits)
+{
+       assert(num_bits <= bitreservoir_fill_size + 1);
+
+       if (source->bits_available >= num_bits) {
+               // Fast path (~90% of invocations?)
+               return;
+       }
+
+       // Slower path (~99% of remaining invocations?)
+       assert(source->bits_available + bitreservoir_fill_size < bitreservoir_size);
+       if (source->bytes_available >= sizeof(bitreservoir_fill_t)) {
+               bitreservoir_fill_t fill = read_bitreservoir_fill(source->bytes);
+               source->bytes += sizeof(bitreservoir_fill_t);
+               source->bytes_available -= sizeof(bitreservoir_fill_t);
+               source->bits |= (bitreservoir_t)fill << (bitreservoir_size - bitreservoir_fill_size - source->bits_available);
+               source->bits_available += bitreservoir_fill_size;
+               return;
+       }
+
+       // Slow path: Refill from data source.
+       // Should not be inlined, so split into a separate function.
+       possibly_refill_slow_path(source, num_bits);
+}
+
+static inline unsigned read_bits(struct data_source* source, unsigned num_bits)
+{
+       assert(source->bits_available >= num_bits);
+       unsigned ret = (source->bits >> (bitreservoir_size - num_bits));
+       source->bits <<= num_bits;
+       source->bits_available -= num_bits;
+       return ret;
+}
+
+#endif /* !defined(_INPUT_H) */
diff --git a/input_test.c b/input_test.c
new file mode 100644 (file)
index 0000000..3f003d2
--- /dev/null
@@ -0,0 +1,130 @@
+#include <stdio.h>
+#include <string.h>
+#include <assert.h>
+#include <time.h>
+#include <sys/time.h>
+
+#include "input.h"
+
+struct custom_read_userdata {
+       uint8_t* bytes;
+       unsigned bytes_left;
+};
+
+ssize_t custom_read(void* userdata, uint8_t* buf, size_t count)
+{
+       struct custom_read_userdata* ud = (struct custom_read_userdata*)userdata;
+       size_t num_to_read = (ud->bytes_left > count ? count : ud->bytes_left);
+       memcpy(buf, ud->bytes, num_to_read);
+       ud->bytes += num_to_read;
+       ud->bytes_left -= num_to_read;
+       return num_to_read;     
+}
+
+ssize_t custom_read_slow(void* userdata, uint8_t* buf, size_t count)
+{
+       struct custom_read_userdata* ud = (struct custom_read_userdata*)userdata;
+       size_t num_to_read = (count > 0 ? 1 : 0);
+       memcpy(buf, ud->bytes, num_to_read);
+       ud->bytes += num_to_read;
+       ud->bytes_left -= num_to_read;
+       return num_to_read;
+}
+
+// Read 6 bits at a time. We should get 0b101010 every time.
+void test_basic_reading()
+{
+       uint8_t bytes[] = { 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa };
+       struct custom_read_userdata ud;
+       ud.bytes = bytes;
+       ud.bytes_left = sizeof(bytes);
+
+       struct data_source source;
+       init_data_source(&source, custom_read, &ud);
+
+       for (int i = 0; i < sizeof(bytes) * 8 / 6; ++i) {
+               possibly_refill(&source, 6);
+               unsigned ret = read_bits(&source, 6);
+               assert(ret == 0x2a);
+       }
+
+       assert(ud.bytes_left == 0);
+}
+
+// Same, but with an input source that gives back only one byte at a time.
+void test_slow_source()
+{
+       uint8_t bytes[] = { 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa };
+       struct custom_read_userdata ud;
+       ud.bytes = bytes;
+       ud.bytes_left = sizeof(bytes);
+
+       struct data_source source;
+       init_data_source(&source, custom_read_slow, &ud);
+
+       for (int i = 0; i < sizeof(bytes) * 8 / 6; ++i) {
+               possibly_refill(&source, 6);
+               unsigned ret = read_bits(&source, 6);
+               assert(ret == 0x2a);
+       }
+
+       assert(ud.bytes_left == 0);
+}
+
+// Read a few different bit sizes.
+void test_variable_size()
+{
+       uint8_t bytes[] = { 0x12, 0x34, 0x56, 0x78, 0xff };
+       struct custom_read_userdata ud;
+       ud.bytes = bytes;
+       ud.bytes_left = sizeof(bytes);
+
+       struct data_source source;
+       init_data_source(&source, custom_read, &ud);
+
+       {
+               possibly_refill(&source, 4);
+               unsigned ret = read_bits(&source, 1);
+               assert(ret == 0x0);
+               ret = read_bits(&source, 1);
+               assert(ret == 0x0);
+               ret = read_bits(&source, 1);
+               assert(ret == 0x0);
+               ret = read_bits(&source, 1);
+               assert(ret == 0x1);
+       }
+       {
+               possibly_refill(&source, 4);
+               unsigned ret = read_bits(&source, 4);
+               assert(ret == 0x2);
+       }
+       {
+               possibly_refill(&source, 12);
+               unsigned ret = read_bits(&source, 12);
+               assert(ret == 0x345);
+       }
+       {
+               possibly_refill(&source, 20);
+               unsigned ret = read_bits(&source, 16);
+               assert(ret == 0x678f);
+               ret = read_bits(&source, 4);
+               assert(ret == 0xf);
+       }
+
+       assert(ud.bytes_left == 0);
+}
+
+int main(void)
+{
+       printf("test_basic_reading()\n");
+       test_basic_reading();
+       
+       printf("test_slow_source()\n");
+       test_slow_source();
+       
+       printf("test_variable_size()\n");
+       test_variable_size();
+       
+       printf("All tests pass.\n");
+       return 0;
+}