]> git.sesse.net Git - narabu/blob - qdc.cpp
Initial checkin.
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <assert.h>
5 #include <math.h>
6
7 //#include "ryg_rans/rans64.h"
8 #include "ryg_rans/rans_byte.h"
9
10 #include <memory>
11
12 #define WIDTH 1280
13 #define HEIGHT 720
14 #define NUM_SYMS 256
15 #define ESCAPE_LIMIT (NUM_SYMS - 1)
16
17 using namespace std;
18
19 void fdct_int32(short *const In);
20 void idct_int32(short *const In);
21
22 unsigned char pix[WIDTH * HEIGHT];
23 short coeff[WIDTH * HEIGHT];
24
// Per-coefficient quantization step sizes for one 8x8 block, stored row by
// row (index = y * 8 + x). main() divides each scaled DCT coefficient by the
// matching entry when quantizing and multiplies by it when reconstructing.
static const unsigned char std_luminance_quant_tbl[64] = {
#if 0
    // Alternative table, currently disabled.
    // NOTE(review): looks like the JPEG example luminance table -- confirm.
    16, 11, 10, 16,  24,  40,  51,  61,
    12, 12, 14, 19,  26,  58,  60,  55,
    14, 13, 16, 24,  40,  57,  69,  56,
    14, 17, 22, 29,  51,  87,  80,  62,
    18, 22, 37, 56,  68, 109, 103,  77,
    24, 35, 55, 64,  81, 104, 113,  92,
    49, 64, 78, 87, 103, 121, 120, 101,
    72, 92, 95, 98, 112, 100, 103,  99
#endif
    /* y = 0 */ 16, 16, 19, 22, 26, 27, 29, 34,
    /* y = 1 */ 16, 16, 22, 24, 27, 29, 34, 37,
    /* y = 2 */ 19, 22, 26, 27, 29, 34, 34, 38,
    /* y = 3 */ 22, 22, 26, 27, 29, 34, 37, 40,
    /* y = 4 */ 22, 26, 27, 29, 32, 35, 40, 48,
    /* y = 5 */ 26, 27, 29, 32, 35, 40, 48, 58,
    /* y = 6 */ 26, 27, 29, 34, 38, 46, 56, 69,
    /* y = 7 */ 27, 29, 35, 38, 46, 56, 69, 83
};
45
46 struct SymbolStats
47 {
48     uint32_t freqs[NUM_SYMS];
49     uint32_t cum_freqs[NUM_SYMS + 1];
50
51     void clear();
52     void count_freqs(uint8_t const* in, size_t nbytes);
53     void calc_cum_freqs();
54     void normalize_freqs(uint32_t target_total);
55 };
56
57 void SymbolStats::clear()
58 {
59     for (int i=0; i < NUM_SYMS; i++)
60         freqs[i] = 0;
61 }
62
63 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
64 {
65     clear();
66
67     for (size_t i=0; i < nbytes; i++)
68         freqs[in[i]]++;
69 }
70
71 void SymbolStats::calc_cum_freqs()
72 {
73     cum_freqs[0] = 0;
74     for (int i=0; i < NUM_SYMS; i++)
75         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
76 }
77
78 void SymbolStats::normalize_freqs(uint32_t target_total)
79 {
80     assert(target_total >= NUM_SYMS);
81
82     calc_cum_freqs();
83     uint32_t cur_total = cum_freqs[NUM_SYMS];
84
85     if (cur_total == 0) return;
86
87     // resample distribution based on cumulative freqs
88     for (int i = 1; i <= NUM_SYMS; i++)
89         cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
90
91     // if we nuked any non-0 frequency symbol to 0, we need to steal
92     // the range to make the frequency nonzero from elsewhere.
93     //
94     // this is not at all optimal, i'm just doing the first thing that comes to mind.
95     for (int i=0; i < NUM_SYMS; i++) {
96         if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
97             // symbol i was set to zero freq
98
99             // find best symbol to steal frequency from (try to steal from low-freq ones)
100             uint32_t best_freq = ~0u;
101             int best_steal = -1;
102             for (int j=0; j < NUM_SYMS; j++) {
103                 uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
104                 if (freq > 1 && freq < best_freq) {
105                     best_freq = freq;
106                     best_steal = j;
107                 }
108             }
109             assert(best_steal != -1);
110
111             // and steal from it!
112             if (best_steal < i) {
113                 for (int j = best_steal + 1; j <= i; j++)
114                     cum_freqs[j]--;
115             } else {
116                 assert(best_steal > i);
117                 for (int j = i + 1; j <= best_steal; j++)
118                     cum_freqs[j]++;
119             }
120         }
121     }
122
123     // calculate updated freqs and make sure we didn't screw anything up
124     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
125     for (int i=0; i < NUM_SYMS; i++) {
126         if (freqs[i] == 0)
127             assert(cum_freqs[i+1] == cum_freqs[i]);
128         else
129             assert(cum_freqs[i+1] > cum_freqs[i]);
130
131         // calc updated freq
132         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
133     }
134 }
135
136 SymbolStats stats[64];
137
// Chooses which statistics table a coefficient at position (x, y) inside an
// 8x8 block uses: the sum x + y, capped at 7. The result is symmetric in
// the two arguments.
//
// Alternatives that were tried and are currently disabled: a hypot(x, y)
// based index, one table per coefficient (y * 8 + x), and a plain split
// between the (0, 0) coefficient and all the others.
int pick_stats_for(int y, int x)
{
        const int diagonal = x + y;
        if (diagonal < 7) {
                return diagonal;
        }
        return 7;
}
153                 
154
// Writes x to fp as a variable-length integer: seven payload bits per byte,
// least significant group first, with the top bit set on every byte except
// the last. Values below 128 therefore take a single byte.
void write_varint(int x, FILE *fp)
{
        for (; x >= 128; x >>= 7) {
                // More groups follow: emit the low seven bits with the
                // continuation flag set.
                putc(0x80 | (x & 0x7f), fp);
        }
        putc(x, fp);
}
163
164 class RansEncoder {
165 public:
166         static constexpr uint32_t prob_bits = 12;
167         static constexpr uint32_t prob_scale = 1 << prob_bits;
168
169         RansEncoder()
170         {
171                 out_buf.reset(new uint8_t[out_max_size]);
172                 sign_buf.reset(new uint8_t[max_num_sign]);
173                 clear();
174         }
175
176         void init_prob(const SymbolStats &s)
177         {
178                 for (int i = 0; i < NUM_SYMS; i++) {
179                         //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits);
180                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits);
181                 }
182         }
183
184         void clear()
185         {
186                 out_end = out_buf.get() + out_max_size;
187                 sign_end = sign_buf.get() + max_num_sign;
188                 ptr = out_end; // *end* of output buffer
189                 sign_ptr = sign_end; // *end* of output buffer
190                 RansEncInit(&rans);
191                 free_sign_bits = 0;
192         }
193
194         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
195         {
196                 RansEncFlush(&rans, &ptr);
197                 //printf("post-flush = %08x\n", rans);
198
199                 uint32_t num_rans_bytes = out_end - ptr;
200                 write_varint(num_rans_bytes, codedfp);
201                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
202                 fwrite(ptr, 1, num_rans_bytes, codedfp);
203
204                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
205
206                 if (free_sign_bits > 0) {
207                         *sign_ptr <<= free_sign_bits;
208                 }
209
210 #if 1
211                 uint32_t num_sign_bytes = sign_end - sign_ptr;
212                 write_varint((num_sign_bytes << 3) | free_sign_bits, codedfp);
213                 fwrite(sign_ptr, 1, num_sign_bytes, codedfp);
214 #endif
215
216                 clear();
217
218                 //printf("Saving block: %d rANS bytes, %d sign bytes\n", num_rans_bytes, num_sign_bytes);
219                 return num_rans_bytes + num_sign_bytes;
220                 //return num_rans_bytes;
221         }
222
223         void encode_coeff(short signed_k)
224         {
225                 //printf("encoding coeff %d\n", signed_k);
226                 short k = abs(signed_k);
227                 if (k >= ESCAPE_LIMIT) {
228                         // Put the coefficient as a 1/(2^12) symbol _before_
229                         // the 255 coefficient, since the decoder will read the
230                         // 255 coefficient first.
231                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
232                         k = ESCAPE_LIMIT;
233                 }
234                 if (k != 0) {
235 #if 1
236                         if (free_sign_bits == 0) {
237                                 --sign_ptr;
238                                 *sign_ptr = 0;
239                                 free_sign_bits = 8;
240                         }
241                         *sign_ptr <<= 1;
242                         *sign_ptr |= (signed_k < 0);
243                         --free_sign_bits;
244 #else
245                         RansEncPut(&rans, &ptr, (k < 0) ? prob_scale / 2 : 0, prob_scale / 2, prob_bits);
246 #endif
247                 }
248                 RansEncPutSymbol(&rans, &ptr, &esyms[k]);
249         }
250
251 private:
252         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
253         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
254
255         unique_ptr<uint8_t[]> out_buf, sign_buf;
256         uint8_t *out_end, *sign_end;
257         uint8_t *ptr, *sign_ptr;
258         RansState rans;
259         size_t free_sign_bits;
260         RansEncSymbol esyms[NUM_SYMS];
261 };
262
263 int main(void)
264 {
265         FILE *fp = fopen("pic.pgm", "rb");
266         fread(pix, 1, WIDTH * HEIGHT, fp);
267         fclose(fp);
268
269         double sum_sq_err = 0.0;
270
271         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
272                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
273                         // Read one block
274                         short in[64];
275                         for (unsigned y = 0; y < 8; ++y) {
276                                 for (unsigned x = 0; x < 8; ++x) {
277                                         in[y * 8 + x] = pix[(yb + y) * WIDTH + (xb + x)];
278                                 }
279                         }
280
281                         // FDCT it
282                         fdct_int32(in);
283
284                         //constexpr int extra_deadzone = 64;
285                         constexpr int extra_deadzone = 4;
286
287                         for (unsigned y = 0; y < 8; ++y) {
288                                 for (unsigned x = 0; x < 8; ++x) {
289                                         short *c = &in[y * 8 + x];
290                                         *c <<= 3;
291                                         *c = copysign(std::max(abs(*c) - extra_deadzone, 0), *c);
292                                         //*c /= std_luminance_quant_tbl[y * 8 + x];
293                                         *c = (int)(double(*c) / std_luminance_quant_tbl[y * 8 + x]);
294 #if 0
295                                         if (x != 0 || y != 0) {
296                                                 int ss = 1;
297                                                 if (::abs(int(*c)) <= ss) {
298                                                         *c = 0; // eeh
299                                                 } else if (*c > 0) {
300                                                         *c -= ss;  // eeh
301                                                 } else {
302                                                         *c += ss;  // eeh
303                                                 }
304                                         }
305 #endif
306                                 }
307                         }
308
309                         // Store it
310                         for (unsigned y = 0; y < 8; ++y) {
311                                 for (unsigned x = 0; x < 8; ++x) {
312                                         coeff[(yb + y) * WIDTH + (xb + x)] = in[y * 8 + x];
313                                 }
314                         }
315
316                         // and back
317                         for (unsigned y = 0; y < 8; ++y) {
318                                 for (unsigned x = 0; x < 8; ++x) {
319                                         in[y * 8 + x] *= std_luminance_quant_tbl[y * 8 + x];
320                                         if (in[y * 8 + x] > 0) {
321                                                 in[y * 8 + x] += extra_deadzone;
322                                         } else if (in[y * 8 + x] < 0) {
323                                                 in[y * 8 + x] -= extra_deadzone;
324                                         }
325                                         in[y * 8 + x] >>= 3;
326                                 }
327                         }
328
329                         idct_int32(in);
330
331                         for (unsigned y = 0; y < 8; ++y) {
332                                 for (unsigned x = 0; x < 8; ++x) {
333                                         int k = in[y * 8 + x];
334                                         if (k < 0) k = 0;
335                                         if (k > 255) k = 255;
336                                         uint8_t *ptr = &pix[(yb + y) * WIDTH + (xb + x)];
337                                         sum_sq_err += (*ptr - k) * (*ptr - k);
338                                         *ptr = k;
339                                 }
340                         }
341                 }
342         }
343         double mse = sum_sq_err / double(WIDTH * HEIGHT);
344         double psnr_db = 20 * log10(255.0 / sqrt(mse));
345         printf("psnr = %.2f dB\n", psnr_db);
346
347         // DC coefficient pred from the right to left
348         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
349                 for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
350                         coeff[yb * WIDTH + xb] -= coeff[yb * WIDTH + (xb + 8)];
351                 }
352         }
353
354         fp = fopen("reconstructed.pgm", "wb");
355         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
356         fwrite(pix, 1, WIDTH * HEIGHT, fp);
357         fclose(fp);
358
359         // For each coefficient, make some tables.
360         size_t extra_bits = 0, sign_bits = 0;
361         for (unsigned i = 0; i < 64; ++i) {
362                 stats[i].clear();
363         }
364         for (unsigned y = 0; y < 8; ++y) {
365                 for (unsigned x = 0; x < 8; ++x) {
366                         SymbolStats &s = stats[pick_stats_for(x, y)];
367
368                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
369                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
370                                         short k = abs(coeff[(yb + y) * WIDTH + (xb + x)]);
371                                         if (k >= ESCAPE_LIMIT) {
372                                                 //printf("coeff (%d,%d) had value %d\n", y, x, k);
373                                                 k = ESCAPE_LIMIT;
374                                                 extra_bits += 12;  // escape this one
375                                         }
376                                         //if (y != 0 || x != 0) ++sign_bits;
377                                         if (k != 0) ++sign_bits;
378                                         ++s.freqs[k];
379                                 }
380                         }
381                 }
382         }
383         for (unsigned i = 0; i < 64; ++i) {
384 #if 0
385                 printf("coeff %i:", i);
386                 for (unsigned j = 0; j <= ESCAPE_LIMIT; ++j) {
387                         printf(" %d", stats[i].freqs[j]);
388                 }
389                 printf("\n");
390 #endif
391                 stats[i].normalize_freqs(RansEncoder::prob_scale);
392         }
393
394         FILE *codedfp = fopen("coded.dat", "wb");
395         if (codedfp == nullptr) {
396                 perror("coded.dat");
397                 exit(1);
398         }
399
400         // TODO: varint or something on the freqs
401         for (unsigned i = 0; i < 64; ++i) {
402                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
403                         continue;
404                 }
405                 printf("writing table %d\n", i);
406 #if 0
407                 for (unsigned j = 0; j <= NUM_SYMS; ++j) {
408                         uint16_t freq = stats[i].cum_freqs[j];
409                         fwrite(&freq, 1, sizeof(freq), codedfp);
410                         printf("%d: %d\n", j, stats[i].freqs[j]);
411                 }
412 #else
413                 // TODO: rather gamma-k or something
414                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
415                 //      write_varint(stats[i].freqs[j], codedfp);
416                 }
417 #endif
418         }
419
420         RansEncoder rans_encoder;
421
422         size_t tot_bytes = 0;
423         for (unsigned y = 0; y < 8; ++y) {
424                 for (unsigned x = 0; x < 8; ++x) {
425                         SymbolStats &s = stats[pick_stats_for(x, y)];
426
427                         rans_encoder.init_prob(s);
428
429                         // need to reverse later
430                         rans_encoder.clear();
431                         size_t num_bytes = 0;
432                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
433                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
434                                         int k = coeff[(yb + y) * WIDTH + (xb + x)];
435                                         //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
436                                         rans_encoder.encode_coeff(k);
437                                 }
438                                 if (yb % 16 == 8) {
439                                         num_bytes += rans_encoder.save_block(codedfp);
440                                 }
441                         }
442                         if (HEIGHT % 16 != 0) {
443                                 num_bytes += rans_encoder.save_block(codedfp);
444                         }
445                         tot_bytes += num_bytes;
446                         printf("coeff %d: %ld bytes\n", y * 8 + x, num_bytes);
447                 }
448         }
449         printf("%ld bytes + %ld sign bits (%ld) + %ld escape bits (%ld) = %ld total bytes\n",
450                 tot_bytes - sign_bits / 8 - extra_bits / 8,
451                 sign_bits,
452                 sign_bits / 8,
453                 extra_bits,
454                 extra_bits / 8,
455                 tot_bytes);
456 }