]> git.sesse.net Git - narabu/blob - qdd.cpp
8c8196dfdcab8bc537901c27d5ff872d08581898
[narabu] / qdd.cpp
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <stdint.h>
4 #include <assert.h>
5 #include <memory>
6
// Frame and alphabet geometry.
//   WIDTH, HEIGHT - luma frame size in pixels; processed in 8x8 blocks, and the
//                   coefficient stream is chunked every 16 rows, so HEIGHT is
//                   expected to be a multiple of 16.
//   NUM_SYMS      - size of the rANS symbol alphabet per frequency table.
//   ESCAPE_LIMIT  - the last symbol; decoding it means "a raw magnitude follows".
enum {
        WIDTH = 1280,
        HEIGHT = 720,
        NUM_SYMS = 256,
        ESCAPE_LIMIT = NUM_SYMS - 1
};
11
12 #include "ryg_rans/rans_byte.h"
13
14 using namespace std;
15
16 void fdct_int32(short *const In);
17 void idct_int32(short *const In);
18
19 unsigned char pix[WIDTH * HEIGHT];
20 short coeff[WIDTH * HEIGHT];
21
22 struct RansDecoder {
23 };
24
25 static constexpr uint32_t prob_bits = 12;
26 static constexpr uint32_t prob_scale = 1 << prob_bits;
27
28 struct RansDecodeTable {
29         int cum2sym[prob_scale];
30         RansDecSymbol dsyms[NUM_SYMS];
31 };
32 RansDecodeTable decode_tables[16];
33
// 8x8 quantization weights for AC coefficients, row-major
// (index = row * 8 + column).  Despite the name, the table actually compiled
// in is the MPEG-1 default intra matrix; change "#if 0" to "#if 1" to use the
// JPEG (ITU-T T.81 Annex K.1) luminance table instead.  Presumably this must
// agree with the table the encoder used.
static const unsigned char std_luminance_quant_tbl[64] = {
#if 0
        // JPEG Annex K luminance quantization table (disabled).
        /* row 0 */ 16,  11,  10,  16,  24,  40,  51,  61,
        /* row 1 */ 12,  12,  14,  19,  26,  58,  60,  55,
        /* row 2 */ 14,  13,  16,  24,  40,  57,  69,  56,
        /* row 3 */ 14,  17,  22,  29,  51,  87,  80,  62,
        /* row 4 */ 18,  22,  37,  56,  68, 109, 103,  77,
        /* row 5 */ 24,  35,  55,  64,  81, 104, 113,  92,
        /* row 6 */ 49,  64,  78,  87, 103, 121, 120, 101,
        /* row 7 */ 72,  92,  95,  98, 112, 100, 103,  99
#else
        // ff_mpeg1_default_intra_matrix (active).
        /* row 0 */  8,  16,  19,  22,  26,  27,  29,  34,
        /* row 1 */ 16,  16,  22,  24,  27,  29,  34,  37,
        /* row 2 */ 19,  22,  26,  27,  29,  34,  34,  38,
        /* row 3 */ 22,  22,  26,  27,  29,  34,  37,  40,
        /* row 4 */ 22,  26,  27,  29,  32,  35,  40,  48,
        /* row 5 */ 26,  27,  29,  32,  35,  40,  48,  58,
        /* row 6 */ 26,  27,  29,  34,  38,  46,  56,  69,
        /* row 7 */ 27,  29,  35,  38,  46,  56,  69,  83
#endif
};
56
57
// Chooses the statistics context (frequency table index) for coefficient
// (y, x) of an 8x8 block: the anti-diagonal index x + y, saturated at 7 so
// all high-frequency coefficients share one context.  Symmetric in x and y.
int pick_stats_for(int y, int x)
{
        const int diagonal = y + x;
        return diagonal < 7 ? diagonal : 7;
}
63
// Reads one unsigned base-128 varint from fp: 7 payload bits per byte, least
// significant group first, high bit set on every byte except the last.
// Returns the decoded value.  Exits the process with a diagnostic on EOF or
// if the varint needs more than five bytes (it cannot fit in 32 bits).
uint32_t read_varint(FILE *fp)
{
        uint32_t x = 0;
        for (int shift = 0; ; shift += 7) {
                const int ch = getc(fp);
                if (ch == EOF) {
                        fprintf(stderr, "Premature EOF\n");
                        exit(1);
                }
                if (shift >= 32) {
                        // A sixth byte would need a shift >= 32, which is undefined.
                        fprintf(stderr, "Varint too long\n");
                        exit(1);
                }

                // Shift in unsigned arithmetic: in int, (ch & 0x7f) << 28
                // overflows (undefined behavior).  Excess high bits of a
                // fifth byte are simply discarded.
                x |= (uint32_t)(ch & 0x7f) << shift;
                if ((ch & 0x80) == 0) {
                        return x;
                }
        }
}
80
81 static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
82 static constexpr double quant_scalefac = 4.0;  // whatever?
83
84 static inline int unquantize(int qf, int coeff_idx)
85 {
86         if (coeff_idx == 0) {
87                 return qf * dc_scalefac;
88         }
89         if (qf == 0) {
90                 return 0;
91         }
92
93         const int w = std_luminance_quant_tbl[coeff_idx];
94         const int s = quant_scalefac;
95         return (2 * qf * w * s) / 32;
96 }
97
98 int main(void)
99 {
100         FILE *fp = fopen("coded.dat", "rb");
101         if (fp == nullptr) {
102                 perror("coded.dat");
103                 exit(1);
104         }
105
106         for (unsigned table = 0; table < 16; ++table) {
107                 uint32_t cum_freq = 0;
108                 for (unsigned sym = 0; sym < NUM_SYMS; ++sym) {
109                         uint32_t freq = read_varint(fp);
110                         fprintf(stderr, "sym=%u/%u: freq=%u\n", sym, NUM_SYMS, freq);
111                         RansDecSymbolInit(&decode_tables[table].dsyms[sym], cum_freq, freq);
112                         for (uint32_t i = 0; i < freq; ++i) {
113                                 decode_tables[table].cum2sym[cum_freq++] = sym;
114                         }
115                 }
116         }
117
118         // loop over all coefficients
119         for (unsigned y = 0; y < 8; ++y) {
120                 for (unsigned x = 0; x < 8; ++x) {
121                         unsigned tbl = pick_stats_for(x, y);
122                 
123                         RansState rans = 0;
124
125                         //unique_ptr<uint8_t[]> rans_bytes(new uint8_t[num_rans_bytes]);
126                         //unique_ptr<uint8_t[]> sign_bytes(new uint8_t[num_sign_bytes]);
127                         unique_ptr<uint8_t[]> rans_bytes;
128                         unique_ptr<uint8_t[]> sign_bytes;
129                         uint8_t *rans_ptr = nullptr;
130                         uint8_t *sign_ptr = nullptr;  // optimize later
131                         uint32_t sign_buf = 0, sign_bits_left = 0;
132
133                         // loop over all DCT blocks
134                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
135                                 if (yb % 16 == 0) {
136                                         // read a block
137                                         uint32_t num_rans_bytes = read_varint(fp);
138                                         rans_bytes.reset(new uint8_t[num_rans_bytes]);
139                                         fread(rans_bytes.get(), 1, num_rans_bytes, fp);
140
141                                         uint32_t val = read_varint(fp);
142                                         uint8_t free_sign_bits = val & 0x7;
143                                         uint32_t num_sign_bytes = val >> 3;
144                                         sign_bytes.reset(new uint8_t[num_sign_bytes]);
145                                         fread(sign_bytes.get(), 1, num_sign_bytes, fp);
146
147                                         sign_ptr = sign_bytes.get();
148                                         if (free_sign_bits == 0) {
149                                                 sign_buf = *sign_ptr++;
150                                                 sign_bits_left = 8;
151                                         } else {
152                                                 sign_buf = *sign_ptr++ >> free_sign_bits;
153                                                 sign_bits_left = 8 - free_sign_bits;
154                                         }
155
156                                         printf("%d,%d: read %d rANS bytes, %d sign bytes\n", x, y, num_rans_bytes, num_sign_bytes);     
157                                         //printf("first bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", rans_bytes[0], rans_bytes[1], rans_bytes[2], rans_bytes[3], rans_bytes[4], rans_bytes[5], rans_bytes[6], rans_bytes[7]);
158
159
160                                         // init rANS decoder
161                                         rans_ptr = rans_bytes.get();
162                                         RansDecInit(&rans, &rans_ptr);
163                                 }
164                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
165                                         uint32_t k = decode_tables[tbl].cum2sym[RansDecGet(&rans, prob_bits)];
166                                         //printf("reading symbol, rans state = %08x\n", rans);
167                                         RansDecAdvanceSymbol(&rans, &rans_ptr, &decode_tables[tbl].dsyms[k], prob_bits);
168                                         //printf("done reading symbol, rans state = %08x\n", rans);
169                                         assert(k <= ESCAPE_LIMIT);
170                                         if (k == ESCAPE_LIMIT) {
171                                                 k = RansDecGet(&rans, prob_bits);
172                                                 assert(k >= ESCAPE_LIMIT);
173                                                 //printf("reading escape symbol, rans state = %08x\n", rans);
174                                                 RansDecAdvance(&rans, &rans_ptr, k, 1, prob_bits);
175                                         }
176                                         if (k != 0) {
177                                                 if (sign_bits_left == 0) {
178                                                         sign_buf = *sign_ptr++;
179                                                         sign_bits_left = 8;
180                                                 }
181                                                 if (sign_buf & 1) k = -k;
182                                                 --sign_bits_left;
183                                                 sign_buf >>= 1;
184                                         }
185
186                                         // reverse
187                                         int reversed_yb = yb ^ 8;
188                                         int reversed_xb = WIDTH - 8 - xb;
189                                         coeff[(reversed_yb + y) * WIDTH + (reversed_xb + x)] = k;
190                 //                      printf("coeff %d xb,yb=%d,%d: decoded %d\n", y * 8 + x, reversed_xb, reversed_yb, k);
191                                 }
192                         }
193
194 #if 0
195                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
196                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
197 #endif  
198                 }
199         }
200         fclose(fp);
201
202         // DC coefficient pred from the right to left
203         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
204                 for (int xb = WIDTH - 16; xb >= 0; xb -= 8) {
205                         coeff[yb * WIDTH + xb] += coeff[yb * WIDTH + (xb + 8)];
206                 }
207         }
208
209         // IDCT
210         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
211                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
212                         // Read one block
213                         short in[64];
214                         for (unsigned y = 0; y < 8; ++y) {
215                                 for (unsigned x = 0; x < 8; ++x) {
216                                         int k = coeff[(yb + y) * WIDTH + (xb + x)];
217                                         in[y * 8 + x] = unquantize(k, y * 8 + x);
218                                         printf("%3d ", in[y * 8 + x]);
219                                 }
220                                 printf("\n");
221                         }
222                         printf("\n");
223
224                         idct_int32(in);
225
226                         // Clamp and move back
227                         for (unsigned y = 0; y < 8; ++y) {
228                                 for (unsigned x = 0; x < 8; ++x) {
229                                         int k = in[y * 8 + x];
230                                         printf("%3d ", k);
231                                         if (k < 0) k = 0;
232                                         if (k > 255) k = 255;
233                                         pix[(yb + y) * WIDTH + (xb + x)] = k;
234                                 }
235                                 printf("\n");
236                         }
237                         printf("\n");
238                 }
239         }
240
241         fp = fopen("output.pgm", "wb");
242         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
243         fwrite(pix, 1, WIDTH * HEIGHT, fp);
244         fclose(fp);
245 }