]> git.sesse.net Git - narabu/blobdiff - qdd.cpp
More fixes of hard-coded values.
[narabu] / qdd.cpp
diff --git a/qdd.cpp b/qdd.cpp
index 8c8196dfdcab8bc537901c27d5ff872d08581898..d9b788c6c080ee586bdccbdee0cdf0caed22bc2a 100644 (file)
--- a/qdd.cpp
+++ b/qdd.cpp
@@ -6,8 +6,15 @@
 
 #define WIDTH 1280
 #define HEIGHT 720
+#define WIDTH_BLOCKS (WIDTH/8)
+#define WIDTH_BLOCKS_CHROMA (WIDTH/16)
+#define HEIGHT_BLOCKS (HEIGHT/8)
+#define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
+#define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)
+
 #define NUM_SYMS 256
 #define ESCAPE_LIMIT (NUM_SYMS - 1)
+#define BLOCKS_PER_STREAM 320
 
 #include "ryg_rans/rans_byte.h"
 
@@ -29,7 +36,8 @@ struct RansDecodeTable {
        int cum2sym[prob_scale];
        RansDecSymbol dsyms[NUM_SYMS];
 };
-RansDecodeTable decode_tables[16];
+#define NUM_TABLES 8
+RansDecodeTable decode_tables[NUM_TABLES];
 
 static const unsigned char std_luminance_quant_tbl[64] = {
 #if 0
@@ -55,10 +63,34 @@ static const unsigned char std_luminance_quant_tbl[64] = {
 };
 
 
-int pick_stats_for(int y, int x)
+const int luma_mapping[64] = {
+       0, 0, 1, 1, 2, 2, 3, 3,
+       0, 0, 1, 2, 2, 2, 3, 3,
+       1, 1, 2, 2, 2, 3, 3, 3,
+       1, 1, 2, 2, 2, 3, 3, 3,
+       1, 2, 2, 2, 2, 3, 3, 3,
+       2, 2, 2, 2, 3, 3, 3, 3,
+       2, 2, 3, 3, 3, 3, 3, 3,
+       3, 3, 3, 3, 3, 3, 3, 3,
+};
+const int chroma_mapping[64] = {
+       0, 1, 1, 2, 2, 2, 3, 3,
+       1, 1, 2, 2, 2, 3, 3, 3,
+       2, 2, 2, 2, 3, 3, 3, 3,
+       2, 2, 2, 3, 3, 3, 3, 3,
+       2, 3, 3, 3, 3, 3, 3, 3,
+       3, 3, 3, 3, 3, 3, 3, 3,
+       3, 3, 3, 3, 3, 3, 3, 3,
+       3, 3, 3, 3, 3, 3, 3, 3,
+};
+
+int pick_stats_for(int x, int y, bool is_chroma)
 {
-       if (x + y >= 7) return 7;
-       return x + y;
+       if (is_chroma) {
+               return chroma_mapping[y * 8 + x] + 4;
+       } else {
+               return luma_mapping[y * 8 + x];
+       }
 }
 
 uint32_t read_varint(FILE *fp)
@@ -103,32 +135,32 @@ int main(void)
                exit(1);
        }
 
-       for (unsigned table = 0; table < 16; ++table) {
+       uint32_t sign_bias[NUM_TABLES];
+       for (unsigned table = 0; table < NUM_TABLES; ++table) {
                uint32_t cum_freq = 0;
                for (unsigned sym = 0; sym < NUM_SYMS; ++sym) {
                        uint32_t freq = read_varint(fp);
                        fprintf(stderr, "sym=%u/%u: freq=%u\n", sym, NUM_SYMS, freq);
-                       RansDecSymbolInit(&decode_tables[table].dsyms[sym], cum_freq, freq);
+                       RansDecSymbolInit(&decode_tables[table].dsyms[(sym + 1) & 255], cum_freq, freq);
                        for (uint32_t i = 0; i < freq; ++i) {
-                               decode_tables[table].cum2sym[cum_freq++] = sym;
+                               if (cum_freq < prob_scale)
+                                       decode_tables[table].cum2sym[cum_freq] = (sym + 1) & 255;
+                               ++cum_freq;
                        }
                }
+               sign_bias[table] = cum_freq;
+               printf("sign_bias=%u (of %d)\n", sign_bias[table], prob_scale * 2);
        }
 
        // loop over all coefficients
        for (unsigned y = 0; y < 8; ++y) {
                for (unsigned x = 0; x < 8; ++x) {
-                       unsigned tbl = pick_stats_for(x, y);
+                       unsigned tbl = pick_stats_for(x, y, false);
                
                        RansState rans = 0;
 
-                       //unique_ptr<uint8_t[]> rans_bytes(new uint8_t[num_rans_bytes]);
-                       //unique_ptr<uint8_t[]> sign_bytes(new uint8_t[num_sign_bytes]);
                        unique_ptr<uint8_t[]> rans_bytes;
-                       unique_ptr<uint8_t[]> sign_bytes;
                        uint8_t *rans_ptr = nullptr;
-                       uint8_t *sign_ptr = nullptr;  // optimize later
-                       uint32_t sign_buf = 0, sign_bits_left = 0;
 
                        // loop over all DCT blocks
                        for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
@@ -138,22 +170,7 @@ int main(void)
                                        rans_bytes.reset(new uint8_t[num_rans_bytes]);
                                        fread(rans_bytes.get(), 1, num_rans_bytes, fp);
 
-                                       uint32_t val = read_varint(fp);
-                                       uint8_t free_sign_bits = val & 0x7;
-                                       uint32_t num_sign_bytes = val >> 3;
-                                       sign_bytes.reset(new uint8_t[num_sign_bytes]);
-                                       fread(sign_bytes.get(), 1, num_sign_bytes, fp);
-
-                                       sign_ptr = sign_bytes.get();
-                                       if (free_sign_bits == 0) {
-                                               sign_buf = *sign_ptr++;
-                                               sign_bits_left = 8;
-                                       } else {
-                                               sign_buf = *sign_ptr++ >> free_sign_bits;
-                                               sign_bits_left = 8 - free_sign_bits;
-                                       }
-
-                                       printf("%d,%d: read %d rANS bytes, %d sign bytes\n", x, y, num_rans_bytes, num_sign_bytes);     
+                                       printf("%d,%d: read %d rANS bytes\n", x, y, num_rans_bytes);
                                        //printf("first bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", rans_bytes[0], rans_bytes[1], rans_bytes[2], rans_bytes[3], rans_bytes[4], rans_bytes[5], rans_bytes[6], rans_bytes[7]);
 
 
@@ -162,25 +179,24 @@ int main(void)
                                        RansDecInit(&rans, &rans_ptr);
                                }
                                for (unsigned xb = 0; xb < WIDTH; xb += 8) {
-                                       uint32_t k = decode_tables[tbl].cum2sym[RansDecGet(&rans, prob_bits)];
-                                       //printf("reading symbol, rans state = %08x\n", rans);
-                                       RansDecAdvanceSymbol(&rans, &rans_ptr, &decode_tables[tbl].dsyms[k], prob_bits);
-                                       //printf("done reading symbol, rans state = %08x\n", rans);
+                                       uint32_t bottom_bits = RansDecGet(&rans, prob_bits + 1);
+                                       uint32_t sign = 0;
+                                       if (bottom_bits >= sign_bias[tbl]) {
+                                               bottom_bits -= sign_bias[tbl];
+                                               rans -= sign_bias[tbl];
+                                               sign = 1;
+                                       }
+                                       uint32_t k = decode_tables[tbl].cum2sym[std::min(bottom_bits, prob_scale - 1)];
+                                       RansDecAdvanceSymbol(&rans, &rans_ptr, &decode_tables[tbl].dsyms[k], prob_bits + 1);
                                        assert(k <= ESCAPE_LIMIT);
                                        if (k == ESCAPE_LIMIT) {
                                                k = RansDecGet(&rans, prob_bits);
                                                assert(k >= ESCAPE_LIMIT);
-                                               //printf("reading escape symbol, rans state = %08x\n", rans);
                                                RansDecAdvance(&rans, &rans_ptr, k, 1, prob_bits);
                                        }
-                                       if (k != 0) {
-                                               if (sign_bits_left == 0) {
-                                                       sign_buf = *sign_ptr++;
-                                                       sign_bits_left = 8;
-                                               }
-                                               if (sign_buf & 1) k = -k;
-                                               --sign_bits_left;
-                                               sign_buf >>= 1;
+                                       if (sign) {
+                                               assert(k != 0);
+                                               k = -k;
                                        }
 
                                        // reverse
@@ -199,10 +215,17 @@ int main(void)
        }
        fclose(fp);
 
-       // DC coefficient pred from the right to left
-       for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
-               for (int xb = WIDTH - 16; xb >= 0; xb -= 8) {
-                       coeff[yb * WIDTH + xb] += coeff[yb * WIDTH + (xb + 8)];
+       // DC coefficient pred from the right to left (within each slice)
+       for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; block_idx += BLOCKS_PER_STREAM) {
+               int prev_k = 128;
+
+               for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
+                       unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS;
+                       unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS;
+                       int k = coeff[(yb * 8) * WIDTH + (xb * 8)];
+
+                       prev_k += k;
+                       coeff[(yb * 8) * WIDTH + (xb * 8)] = prev_k;
                }
        }