]> git.sesse.net Git - narabu/blobdiff - qdd.cpp
Add some parallel slicing code (not really a win).
[narabu] / qdd.cpp
diff --git a/qdd.cpp b/qdd.cpp
index 8c8196dfdcab8bc537901c27d5ff872d08581898..a4fa907a297549f03d0860bb016d41fb7bf83420 100644 (file)
--- a/qdd.cpp
+++ b/qdd.cpp
@@ -103,16 +103,21 @@ int main(void)
                exit(1);
        }
 
+       uint32_t sign_bias[16];
        for (unsigned table = 0; table < 16; ++table) {
                uint32_t cum_freq = 0;
                for (unsigned sym = 0; sym < NUM_SYMS; ++sym) {
                        uint32_t freq = read_varint(fp);
                        fprintf(stderr, "sym=%u/%u: freq=%u\n", sym, NUM_SYMS, freq);
-                       RansDecSymbolInit(&decode_tables[table].dsyms[sym], cum_freq, freq);
+                       RansDecSymbolInit(&decode_tables[table].dsyms[(sym + 1) & 255], cum_freq, freq);
                        for (uint32_t i = 0; i < freq; ++i) {
-                               decode_tables[table].cum2sym[cum_freq++] = sym;
+                               if (cum_freq < prob_scale)
+                                       decode_tables[table].cum2sym[cum_freq] = (sym + 1) & 255;
+                               ++cum_freq;
                        }
                }
+               sign_bias[table] = cum_freq;
+               printf("sign_bias=%u (of %d)\n", sign_bias[table], prob_scale * 2);
        }
 
        // loop over all coefficients
@@ -122,13 +127,8 @@ int main(void)
                
                        RansState rans = 0;
 
-                       //unique_ptr<uint8_t[]> rans_bytes(new uint8_t[num_rans_bytes]);
-                       //unique_ptr<uint8_t[]> sign_bytes(new uint8_t[num_sign_bytes]);
                        unique_ptr<uint8_t[]> rans_bytes;
-                       unique_ptr<uint8_t[]> sign_bytes;
                        uint8_t *rans_ptr = nullptr;
-                       uint8_t *sign_ptr = nullptr;  // optimize later
-                       uint32_t sign_buf = 0, sign_bits_left = 0;
 
                        // loop over all DCT blocks
                        for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
@@ -138,22 +138,7 @@ int main(void)
                                        rans_bytes.reset(new uint8_t[num_rans_bytes]);
                                        fread(rans_bytes.get(), 1, num_rans_bytes, fp);
 
-                                       uint32_t val = read_varint(fp);
-                                       uint8_t free_sign_bits = val & 0x7;
-                                       uint32_t num_sign_bytes = val >> 3;
-                                       sign_bytes.reset(new uint8_t[num_sign_bytes]);
-                                       fread(sign_bytes.get(), 1, num_sign_bytes, fp);
-
-                                       sign_ptr = sign_bytes.get();
-                                       if (free_sign_bits == 0) {
-                                               sign_buf = *sign_ptr++;
-                                               sign_bits_left = 8;
-                                       } else {
-                                               sign_buf = *sign_ptr++ >> free_sign_bits;
-                                               sign_bits_left = 8 - free_sign_bits;
-                                       }
-
-                                       printf("%d,%d: read %d rANS bytes, %d sign bytes\n", x, y, num_rans_bytes, num_sign_bytes);     
+                                       printf("%d,%d: read %d rANS bytes\n", x, y, num_rans_bytes);
                                        //printf("first bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", rans_bytes[0], rans_bytes[1], rans_bytes[2], rans_bytes[3], rans_bytes[4], rans_bytes[5], rans_bytes[6], rans_bytes[7]);
 
 
@@ -162,25 +147,24 @@ int main(void)
                                        RansDecInit(&rans, &rans_ptr);
                                }
                                for (unsigned xb = 0; xb < WIDTH; xb += 8) {
-                                       uint32_t k = decode_tables[tbl].cum2sym[RansDecGet(&rans, prob_bits)];
-                                       //printf("reading symbol, rans state = %08x\n", rans);
-                                       RansDecAdvanceSymbol(&rans, &rans_ptr, &decode_tables[tbl].dsyms[k], prob_bits);
-                                       //printf("done reading symbol, rans state = %08x\n", rans);
+                                       uint32_t bottom_bits = RansDecGet(&rans, prob_bits + 1);
+                                       uint32_t sign = 0;
+                                       if (bottom_bits >= sign_bias[tbl]) {
+                                               bottom_bits -= sign_bias[tbl];
+                                               rans -= sign_bias[tbl];
+                                               sign = 1;
+                                       }
+                                       uint32_t k = decode_tables[tbl].cum2sym[std::min(bottom_bits, prob_scale - 1)];
+                                       RansDecAdvanceSymbol(&rans, &rans_ptr, &decode_tables[tbl].dsyms[k], prob_bits + 1);
                                        assert(k <= ESCAPE_LIMIT);
                                        if (k == ESCAPE_LIMIT) {
                                                k = RansDecGet(&rans, prob_bits);
                                                assert(k >= ESCAPE_LIMIT);
-                                               //printf("reading escape symbol, rans state = %08x\n", rans);
                                                RansDecAdvance(&rans, &rans_ptr, k, 1, prob_bits);
                                        }
-                                       if (k != 0) {
-                                               if (sign_bits_left == 0) {
-                                                       sign_buf = *sign_ptr++;
-                                                       sign_bits_left = 8;
-                                               }
-                                               if (sign_buf & 1) k = -k;
-                                               --sign_bits_left;
-                                               sign_buf >>= 1;
+                                       if (sign) {
+                                               assert(k != 0);
+                                               k = -k;
                                        }
 
                                        // reverse