]> git.sesse.net Git - narabu/blob - qdc.cpp
Change quantization to MPEG-2, some other changes.
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <assert.h>
5 #include <math.h>
6
7 //#include "ryg_rans/rans64.h"
8 #include "ryg_rans/rans_byte.h"
9
10 #include <memory>
11
12 #define WIDTH 1280
13 #define HEIGHT 720
14 #define NUM_SYMS 256
15 #define ESCAPE_LIMIT (NUM_SYMS - 1)
16
17 using namespace std;
18
19 void fdct_int32(short *const In);
20 void idct_int32(short *const In);
21
22 unsigned char pix[WIDTH * HEIGHT];
23 short coeff[WIDTH * HEIGHT];
24
25 static const unsigned char std_luminance_quant_tbl[64] = {
26 #if 0
27         16,  11,  10,  16,  24,  40,  51,  61,
28         12,  12,  14,  19,  26,  58,  60,  55,
29         14,  13,  16,  24,  40,  57,  69,  56,
30         14,  17,  22,  29,  51,  87,  80,  62,
31         18,  22,  37,  56,  68, 109, 103,  77,
32         24,  35,  55,  64,  81, 104, 113,  92,
33         49,  64,  78,  87, 103, 121, 120, 101,
34         72,  92,  95,  98, 112, 100, 103,  99
35 #else
36         // ff_mpeg1_default_intra_matrix
37          8, 16, 19, 22, 26, 27, 29, 34,
38         16, 16, 22, 24, 27, 29, 34, 37,                                                 
39         19, 22, 26, 27, 29, 34, 34, 38,                                                 
40         22, 22, 26, 27, 29, 34, 37, 40,
41         22, 26, 27, 29, 32, 35, 40, 48,
42         26, 27, 29, 32, 35, 40, 48, 58,
43         26, 27, 29, 34, 38, 46, 56, 69,
44         27, 29, 35, 38, 46, 56, 69, 83
45 #endif
46 };
47
48 struct SymbolStats
49 {
50     uint32_t freqs[NUM_SYMS];
51     uint32_t cum_freqs[NUM_SYMS + 1];
52
53     void clear();
54     void count_freqs(uint8_t const* in, size_t nbytes);
55     void calc_cum_freqs();
56     void normalize_freqs(uint32_t target_total);
57 };
58
59 void SymbolStats::clear()
60 {
61     for (int i=0; i < NUM_SYMS; i++)
62         freqs[i] = 0;
63 }
64
65 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
66 {
67     clear();
68
69     for (size_t i=0; i < nbytes; i++)
70         freqs[in[i]]++;
71 }
72
73 void SymbolStats::calc_cum_freqs()
74 {
75     cum_freqs[0] = 0;
76     for (int i=0; i < NUM_SYMS; i++)
77         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
78 }
79
80 void SymbolStats::normalize_freqs(uint32_t target_total)
81 {
82     assert(target_total >= NUM_SYMS);
83
84     calc_cum_freqs();
85     uint32_t cur_total = cum_freqs[NUM_SYMS];
86
87     if (cur_total == 0) return;
88
89     // resample distribution based on cumulative freqs
90     for (int i = 1; i <= NUM_SYMS; i++)
91         cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
92
93     // if we nuked any non-0 frequency symbol to 0, we need to steal
94     // the range to make the frequency nonzero from elsewhere.
95     //
96     // this is not at all optimal, i'm just doing the first thing that comes to mind.
97     for (int i=0; i < NUM_SYMS; i++) {
98         if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
99             // symbol i was set to zero freq
100
101             // find best symbol to steal frequency from (try to steal from low-freq ones)
102             uint32_t best_freq = ~0u;
103             int best_steal = -1;
104             for (int j=0; j < NUM_SYMS; j++) {
105                 uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
106                 if (freq > 1 && freq < best_freq) {
107                     best_freq = freq;
108                     best_steal = j;
109                 }
110             }
111             assert(best_steal != -1);
112
113             // and steal from it!
114             if (best_steal < i) {
115                 for (int j = best_steal + 1; j <= i; j++)
116                     cum_freqs[j]--;
117             } else {
118                 assert(best_steal > i);
119                 for (int j = i + 1; j <= best_steal; j++)
120                     cum_freqs[j]++;
121             }
122         }
123     }
124
125     // calculate updated freqs and make sure we didn't screw anything up
126     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
127     for (int i=0; i < NUM_SYMS; i++) {
128         if (freqs[i] == 0)
129             assert(cum_freqs[i+1] == cum_freqs[i]);
130         else
131             assert(cum_freqs[i+1] > cum_freqs[i]);
132
133         // calc updated freq
134         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
135     }
136 }
137
138 SymbolStats stats[64];
139
140 int pick_stats_for(int y, int x)
141 {
142         //return std::min<int>(hypot(x, y), 7);
143         return std::min<int>(x + y, 7);
144         //if (x + y >= 7) return 7;
145         //return x + y;
146         //return y * 8 + x;
147 #if 0
148         if (y == 0 && x == 0) {
149                 return 0;
150         } else {
151                 return 1;
152         }
153 #endif
154 }
155                 
156
157 void write_varint(int x, FILE *fp)
158 {
159         while (x >= 128) {
160                 putc((x & 0x7f) | 0x80, fp);
161                 x >>= 7;
162         }
163         putc(x, fp);
164 }
165
166 class RansEncoder {
167 public:
168         static constexpr uint32_t prob_bits = 12;
169         static constexpr uint32_t prob_scale = 1 << prob_bits;
170
171         RansEncoder()
172         {
173                 out_buf.reset(new uint8_t[out_max_size]);
174                 sign_buf.reset(new uint8_t[max_num_sign]);
175                 clear();
176         }
177
178         void init_prob(const SymbolStats &s1, const SymbolStats &s2)
179         {
180                 for (int i = 0; i < NUM_SYMS; i++) {
181                         //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits);
182                         RansEncSymbolInit(&esyms[i], s1.cum_freqs[i], s1.freqs[i], prob_bits);
183                 }
184         }
185
186         void clear()
187         {
188                 out_end = out_buf.get() + out_max_size;
189                 sign_end = sign_buf.get() + max_num_sign;
190                 ptr = out_end; // *end* of output buffer
191                 sign_ptr = sign_end; // *end* of output buffer
192                 RansEncInit(&rans);
193                 free_sign_bits = 0;
194         }
195
196         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
197         {
198                 RansEncFlush(&rans, &ptr);
199                 //printf("post-flush = %08x\n", rans);
200
201                 uint32_t num_rans_bytes = out_end - ptr;
202                 write_varint(num_rans_bytes, codedfp);
203                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
204                 fwrite(ptr, 1, num_rans_bytes, codedfp);
205
206                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
207
208                 if (free_sign_bits > 0) {
209                         *sign_ptr <<= free_sign_bits;
210                 }
211
212 #if 1
213                 uint32_t num_sign_bytes = sign_end - sign_ptr;
214                 write_varint((num_sign_bytes << 3) | free_sign_bits, codedfp);
215                 fwrite(sign_ptr, 1, num_sign_bytes, codedfp);
216 #endif
217
218                 clear();
219
220                 //printf("Saving block: %d rANS bytes, %d sign bytes\n", num_rans_bytes, num_sign_bytes);
221                 return num_rans_bytes + num_sign_bytes;
222                 //return num_rans_bytes;
223         }
224
225         void encode_coeff(short signed_k)
226         {
227                 //printf("encoding coeff %d\n", signed_k);
228                 short k = abs(signed_k);
229                 if (k >= ESCAPE_LIMIT) {
230                         // Put the coefficient as a 1/(2^12) symbol _before_
231                         // the 255 coefficient, since the decoder will read the
232                         // 255 coefficient first.
233                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
234                         k = ESCAPE_LIMIT;
235                 }
236                 if (k != 0) {
237 #if 1
238                         if (free_sign_bits == 0) {
239                                 --sign_ptr;
240                                 *sign_ptr = 0;
241                                 free_sign_bits = 8;
242                         }
243                         *sign_ptr <<= 1;
244                         *sign_ptr |= (signed_k < 0);
245                         --free_sign_bits;
246 #else
247                         RansEncPut(&rans, &ptr, (k < 0) ? prob_scale / 2 : 0, prob_scale / 2, prob_bits);
248 #endif
249                 }
250                 RansEncPutSymbol(&rans, &ptr, &esyms[k]);
251         }
252
253 private:
254         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
255         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
256
257         unique_ptr<uint8_t[]> out_buf, sign_buf;
258         uint8_t *out_end, *sign_end;
259         uint8_t *ptr, *sign_ptr;
260         RansState rans;
261         size_t free_sign_bits;
262         RansEncSymbol esyms[NUM_SYMS];
263 };
264
265 static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
266 static constexpr double quant_scalefac = 4.0;  // whatever?
267
268 static inline int quantize(int f, int coeff_idx)
269 {
270         if (coeff_idx == 0) {
271                 return f / dc_scalefac;
272         }
273         if (f == 0) {
274                 return 0;
275         }
276
277         const int w = std_luminance_quant_tbl[coeff_idx];
278         const int s = quant_scalefac;
279         int sign_f = (f > 0) ? 1 : -1;
280         return (32 * f + sign_f * w * s) / (2 * w * s);
281 }
282
283 static inline int unquantize(int qf, int coeff_idx)
284 {
285         if (coeff_idx == 0) {
286                 return qf * dc_scalefac;
287         }
288         if (qf == 0) {
289                 return 0;
290         }
291
292         const int w = std_luminance_quant_tbl[coeff_idx];
293         const int s = quant_scalefac;
294         return (2 * qf * w * s) / 32;
295 }
296
297 int main(void)
298 {
299         FILE *fp = fopen("pic.pgm", "rb");
300         fread(pix, 1, WIDTH * HEIGHT, fp);
301         fclose(fp);
302
303         double sum_sq_err = 0.0;
304
305         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
306                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
307                         // Read one block
308                         short in[64];
309                         for (unsigned y = 0; y < 8; ++y) {
310                                 for (unsigned x = 0; x < 8; ++x) {
311                                         in[y * 8 + x] = pix[(yb + y) * WIDTH + (xb + x)];
312                                 }
313                         }
314
315                         // FDCT it
316                         fdct_int32(in);
317
318                         for (unsigned y = 0; y < 8; ++y) {
319                                 for (unsigned x = 0; x < 8; ++x) {
320                                         int coeff_idx = y * 8 + x;
321                                         int k = quantize(in[coeff_idx], coeff_idx);
322                                         coeff[(yb + y) * WIDTH + (xb + x)] = k;
323
324                                         // Store back for reconstruction / PSNR calculation
325                                         in[coeff_idx] = unquantize(k, coeff_idx);
326                                 }
327                         }
328
329                         idct_int32(in);
330
331                         for (unsigned y = 0; y < 8; ++y) {
332                                 for (unsigned x = 0; x < 8; ++x) {
333                                         int k = in[y * 8 + x];
334                                         if (k < 0) k = 0;
335                                         if (k > 255) k = 255;
336                                         uint8_t *ptr = &pix[(yb + y) * WIDTH + (xb + x)];
337                                         sum_sq_err += (*ptr - k) * (*ptr - k);
338                                         *ptr = k;
339                                 }
340                         }
341                 }
342         }
343         double mse = sum_sq_err / double(WIDTH * HEIGHT);
344         double psnr_db = 20 * log10(255.0 / sqrt(mse));
345         printf("psnr = %.2f dB\n", psnr_db);
346
347         // DC coefficient pred from the right to left
348         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
349                 for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
350                         coeff[yb * WIDTH + xb] -= coeff[yb * WIDTH + (xb + 8)];
351                 }
352         }
353
354         fp = fopen("reconstructed.pgm", "wb");
355         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
356         fwrite(pix, 1, WIDTH * HEIGHT, fp);
357         fclose(fp);
358
359         // For each coefficient, make some tables.
360         size_t extra_bits = 0, sign_bits = 0;
361         for (unsigned i = 0; i < 64; ++i) {
362                 stats[i].clear();
363         }
364         for (unsigned y = 0; y < 8; ++y) {
365                 for (unsigned x = 0; x < 8; ++x) {
366                         SymbolStats &s = stats[pick_stats_for(x, y)];
367
368                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
369                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
370                                         short k = abs(coeff[(yb + y) * WIDTH + (xb + x)]);
371                                         if (k >= ESCAPE_LIMIT) {
372                                                 //printf("coeff (%d,%d) had value %d\n", y, x, k);
373                                                 k = ESCAPE_LIMIT;
374                                                 extra_bits += 12;  // escape this one
375                                         }
376                                         //if (y != 0 || x != 0) ++sign_bits;
377                                         if (k != 0) ++sign_bits;
378                                         ++s.freqs[k];
379                                 }
380                         }
381                 }
382         }
383         for (unsigned i = 0; i < 64; ++i) {
384 #if 0
385                 printf("coeff %i:", i);
386                 for (unsigned j = 0; j <= ESCAPE_LIMIT; ++j) {
387                         printf(" %d", stats[i].freqs[j]);
388                 }
389                 printf("\n");
390 #endif
391                 stats[i].normalize_freqs(RansEncoder::prob_scale);
392         }
393
394         FILE *codedfp = fopen("coded.dat", "wb");
395         if (codedfp == nullptr) {
396                 perror("coded.dat");
397                 exit(1);
398         }
399
400         // TODO: varint or something on the freqs
401         for (unsigned i = 0; i < 64; ++i) {
402                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
403                         continue;
404                 }
405                 printf("writing table %d\n", i);
406 #if 0
407                 for (unsigned j = 0; j <= NUM_SYMS; ++j) {
408                         uint16_t freq = stats[i].cum_freqs[j];
409                         fwrite(&freq, 1, sizeof(freq), codedfp);
410                         printf("%d: %d\n", j, stats[i].freqs[j]);
411                 }
412 #else
413                 // TODO: rather gamma-k or something
414                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
415                         write_varint(stats[i].freqs[j], codedfp);
416                 }
417 #endif
418         }
419
420         RansEncoder rans_encoder;
421
422         size_t tot_bytes = 0;
423         for (unsigned y = 0; y < 8; ++y) {
424                 for (unsigned x = 0; x < 8; ++x) {
425                         SymbolStats &s1 = stats[pick_stats_for(x, y)];
426                         SymbolStats &s2 = stats[pick_stats_for(x, y) + 8];
427
428                         rans_encoder.init_prob(s1, s2);
429
430                         // need to reverse later
431                         rans_encoder.clear();
432                         size_t num_bytes = 0;
433                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
434                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
435                                         int k = coeff[(yb + y) * WIDTH + (xb + x)];
436                                         //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
437                                         rans_encoder.encode_coeff(k);
438                                 }
439                                 if (yb % 16 == 8) {
440                                         num_bytes += rans_encoder.save_block(codedfp);
441                                 }
442                         }
443                         if (HEIGHT % 16 != 0) {
444                                 num_bytes += rans_encoder.save_block(codedfp);
445                         }
446                         tot_bytes += num_bytes;
447                         printf("coeff %d: %ld bytes\n", y * 8 + x, num_bytes);
448                 }
449         }
450         printf("%ld bytes + %ld sign bits (%ld) + %ld escape bits (%ld) = %ld total bytes\n",
451                 tot_bytes - sign_bits / 8 - extra_bits / 8,
452                 sign_bits,
453                 sign_bits / 8,
454                 extra_bits,
455                 extra_bits / 8,
456                 tot_bytes);
457 }