]> git.sesse.net Git - narabu/blob - qdc.cpp
Encoder with 4x4 blocks (using TF switching).
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 //#include "ryg_rans/rans64.h"
9 #include "ryg_rans/rans_byte.h"
10
11 #include <memory>
12
// Frame geometry of the 8-bit grayscale input. main() walks the image in
// 8x8 blocks without edge handling, so both must be multiples of 8.
#define WIDTH 1280
#define HEIGHT 720
// Size of each rANS symbol alphabet; coefficient magnitudes >= ESCAPE_LIMIT
// are coded as the ESCAPE_LIMIT symbol plus an escape value.
#define NUM_SYMS 256
#define ESCAPE_LIMIT (NUM_SYMS - 1)

using namespace std;

static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
// Global quantizer strength, overridable via argv[1]. The quantize/unquantize
// helpers convert it to int, so only its integer part has an effect.
static double quant_scalefac = 5.0;  // whatever?
// Weight of the bit cost in the rate-distortion choice; overridable via argv[2].
static double lambda = 0.1;

// In-place integer forward/inverse DCT on a 64-element (8x8) block; defined
// in another translation unit.
void fdct_int32(short *const In);
void idct_int32(short *const In);

// pix: source image, later overwritten with the chosen reconstruction.
// pix_4x4 / pix_8x8: reconstructions through the 4x4 and 8x8 paths.
unsigned char pix_4x4[WIDTH * HEIGHT], pix_8x8[WIDTH * HEIGHT], pix[WIDTH * HEIGHT];
// Quantized coefficients, stored at the pixel positions of their 8x8 block.
short global_coeff8x8[WIDTH * HEIGHT], global_coeff4x4[WIDTH * HEIGHT];
// Sum of squared reconstruction error per 8x8 block, for each transform choice.
double err_8x8[(WIDTH/8) * (HEIGHT/8)], err_4x4[(WIDTH/8) * (HEIGHT/8)];
30
// Per-coefficient quantizer weights for the 8x8 DCT, row-major (index y*8+x).
// Entry 0 (DC) is never read: quantize8x8() handles DC with dc_scalefac.
static const unsigned char quant_8x8[64] = {
#if 0
	// Standard JPEG luminance table (disabled).
        16,  11,  10,  16,  24,  40,  51,  61,
        12,  12,  14,  19,  26,  58,  60,  55,
        14,  13,  16,  24,  40,  57,  69,  56,
        14,  17,  22,  29,  51,  87,  80,  62,
        18,  22,  37,  56,  68, 109, 103,  77,
        24,  35,  55,  64,  81, 104, 113,  92,
        49,  64,  78,  87, 103, 121, 120, 101,
        72,  92,  95,  98, 112, 100, 103,  99
#elif 1
        // ff_mpeg1_default_intra_matrix
         8, 16, 19, 22, 26, 27, 29, 34,
        16, 16, 22, 24, 27, 29, 34, 37,
        19, 22, 26, 27, 29, 34, 34, 38,
        22, 22, 26, 27, 29, 34, 37, 40,
        22, 26, 27, 29, 32, 35, 40, 48,
        26, 27, 29, 32, 35, 40, 48, 58,
        26, 27, 29, 34, 38, 46, 56, 69,
        27, 29, 35, 38, 46, 56, 69, 83
#endif
};
// Quantizer weights for one 4x4 sub-block produced by TF switching, row-major
// (index (y%4)*4 + (x%4)). Entry 0 (DC) is never read: quantize4x4() uses
// dc_scalefac/2 instead. The commented-out rows are earlier experiments.
static const unsigned char quant_4x4[16] = {
         8, 17, 27, 37,
        17, 27, 37, 43,
        27, 37, 43, 49,
        37, 43, 49, 56
        //8,  8,  8,  8,
        //8,  8,  8,  8,
        //8,  8,  8,  8,
        //8,  8,  8,  8
        // 8, 19, 26, 29,
        //19, 26, 29, 34,
        //22, 27, 32, 40,
        //26, 29, 38, 56,
};
67
// Symbol frequency statistics for one rANS coding context over NUM_SYMS symbols.
struct SymbolStats
{
    uint32_t freqs[NUM_SYMS];          // per-symbol counts, or normalized frequencies after normalize_freqs()
    uint32_t cum_freqs[NUM_SYMS + 1];  // exclusive prefix sums of freqs; cum_freqs[NUM_SYMS] is the total

    void clear();                                         // zero freqs (cum_freqs is not touched)
    void count_freqs(uint8_t const* in, size_t nbytes);   // histogram a byte buffer
    void calc_cum_freqs();                                // rebuild cum_freqs from freqs
    void normalize_freqs(uint32_t target_total);          // rescale so the total is target_total
};
78
79 void SymbolStats::clear()
80 {
81     for (int i=0; i < NUM_SYMS; i++)
82         freqs[i] = 0;
83 }
84
85 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
86 {
87     clear();
88
89     for (size_t i=0; i < nbytes; i++)
90         freqs[in[i]]++;
91 }
92
93 void SymbolStats::calc_cum_freqs()
94 {
95     cum_freqs[0] = 0;
96     for (int i=0; i < NUM_SYMS; i++)
97         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
98 }
99
// Rescales the symbol counts so that they sum to exactly target_total. Every
// symbol that had a nonzero count keeps a frequency of at least 1. On return,
// cum_freqs holds the rescaled prefix sums and freqs the rescaled frequencies.
// If every count is zero, it returns right after calc_cum_freqs(), leaving
// cum_freqs and freqs all zero.
void SymbolStats::normalize_freqs(uint32_t target_total)
{
    // There must be room to give each possible symbol at least one slot.
    assert(target_total >= NUM_SYMS);

    calc_cum_freqs();
    uint32_t cur_total = cum_freqs[NUM_SYMS];

    // Empty histogram: nothing to rescale.
    if (cur_total == 0) return;

    // resample distribution based on cumulative freqs
    // (64-bit intermediate so target_total * cum_freqs[i] cannot overflow)
    for (int i = 1; i <= NUM_SYMS; i++)
        cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;

    // if we nuked any non-0 frequency symbol to 0, we need to steal
    // the range to make the frequency nonzero from elsewhere.
    //
    // this is not at all optimal, i'm just doing the first thing that comes to mind.
    for (int i=0; i < NUM_SYMS; i++) {
        if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
            // symbol i was set to zero freq

            // find best symbol to steal frequency from (try to steal from low-freq ones)
            uint32_t best_freq = ~0u;
            int best_steal = -1;
            for (int j=0; j < NUM_SYMS; j++) {
                uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
                if (freq > 1 && freq < best_freq) {
                    best_freq = freq;
                    best_steal = j;
                }
            }
            assert(best_steal != -1);

            // and steal from it!
            // Shifting every boundary between the two symbols by one shrinks
            // best_steal's range by 1 and grows symbol i's range by 1; the
            // symbols in between keep their widths.
            if (best_steal < i) {
                for (int j = best_steal + 1; j <= i; j++)
                    cum_freqs[j]--;
            } else {
                assert(best_steal > i);
                for (int j = i + 1; j <= best_steal; j++)
                    cum_freqs[j]++;
            }
        }
    }

    // calculate updated freqs and make sure we didn't screw anything up
    assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
    for (int i=0; i < NUM_SYMS; i++) {
        if (freqs[i] == 0)
            assert(cum_freqs[i+1] == cum_freqs[i]);
        else
            assert(cum_freqs[i+1] > cum_freqs[i]);

        // calc updated freq
        freqs[i] = cum_freqs[i+1] - cum_freqs[i];
    }
}
157
// One statistics table per coding context, indexed by pick_stats_for().
// With its current mapping only entries 0-11 are used.
SymbolStats stats[64];
159
// Maps a coefficient position inside its block to a statistics context
// (an index into stats[]):
//   - 8x8 blocks: context = x + y, clamped to 7   (contexts 0..7)
//   - 4x4 blocks: context = 8 + (x + y clamped to 3) (contexts 8..11)
// Only the diagonal x + y matters, so swapping x and y gives the same result.
int pick_stats_for(int y, int x, bool is_4x4)
{
	const int diagonal = x + y;
	if (!is_4x4) {
		return (diagonal < 7) ? diagonal : 7;
	}
	return 8 + ((diagonal < 3) ? diagonal : 3);
}
178                 
179
// Writes x to fp as a little-endian base-128 varint: 7 payload bits per byte,
// low bits first, with the high bit set on every byte except the last.
// Intended for non-negative x.
void write_varint(int x, FILE *fp)
{
	for (; x >= 128; x >>= 7) {
		const int low_bits = x & 0x7f;
		putc(0x80 | low_bits, fp);
	}
	putc(x, fp);
}
188
189 class RansEncoder {
190 public:
191         static constexpr uint32_t prob_bits = 12;
192         static constexpr uint32_t prob_scale = 1 << prob_bits;
193
194         RansEncoder()
195         {
196                 out_buf.reset(new uint8_t[out_max_size]);
197                 sign_buf.reset(new uint8_t[max_num_sign]);
198                 clear();
199         }
200
201         void init_prob(const SymbolStats &s1, const SymbolStats &s2)
202         {
203                 for (int i = 0; i < NUM_SYMS; i++) {
204                         //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits);
205                         RansEncSymbolInit(&esyms[i], s1.cum_freqs[i], s1.freqs[i], prob_bits);
206                 }
207         }
208
209         void clear()
210         {
211                 out_end = out_buf.get() + out_max_size;
212                 sign_end = sign_buf.get() + max_num_sign;
213                 ptr = out_end; // *end* of output buffer
214                 sign_ptr = sign_end; // *end* of output buffer
215                 RansEncInit(&rans);
216                 free_sign_bits = 0;
217         }
218
219         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
220         {
221                 RansEncFlush(&rans, &ptr);
222                 //printf("post-flush = %08x\n", rans);
223
224                 uint32_t num_rans_bytes = out_end - ptr;
225                 write_varint(num_rans_bytes, codedfp);
226                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
227                 fwrite(ptr, 1, num_rans_bytes, codedfp);
228
229                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
230
231                 if (free_sign_bits > 0) {
232                         *sign_ptr <<= free_sign_bits;
233                 }
234
235 #if 1
236                 uint32_t num_sign_bytes = sign_end - sign_ptr;
237                 write_varint((num_sign_bytes << 3) | free_sign_bits, codedfp);
238                 fwrite(sign_ptr, 1, num_sign_bytes, codedfp);
239 #endif
240
241                 clear();
242
243                 //printf("Saving block: %d rANS bytes, %d sign bytes\n", num_rans_bytes, num_sign_bytes);
244                 return num_rans_bytes + num_sign_bytes;
245                 //return num_rans_bytes;
246         }
247
248         void encode_coeff(short signed_k)
249         {
250                 //printf("encoding coeff %d\n", signed_k);
251                 short k = abs(signed_k);
252                 if (k >= ESCAPE_LIMIT) {
253                         // Put the coefficient as a 1/(2^12) symbol _before_
254                         // the 255 coefficient, since the decoder will read the
255                         // 255 coefficient first.
256                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
257                         k = ESCAPE_LIMIT;
258                 }
259                 if (k != 0) {
260 #if 1
261                         if (free_sign_bits == 0) {
262                                 --sign_ptr;
263                                 *sign_ptr = 0;
264                                 free_sign_bits = 8;
265                         }
266                         *sign_ptr <<= 1;
267                         *sign_ptr |= (signed_k < 0);
268                         --free_sign_bits;
269 #else
270                         RansEncPut(&rans, &ptr, (k < 0) ? prob_scale / 2 : 0, prob_scale / 2, prob_bits);
271 #endif
272                 }
273                 RansEncPutSymbol(&rans, &ptr, &esyms[k]);
274         }
275
276 private:
277         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
278         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
279
280         unique_ptr<uint8_t[]> out_buf, sign_buf;
281         uint8_t *out_end, *sign_end;
282         uint8_t *ptr, *sign_ptr;
283         RansState rans;
284         size_t free_sign_bits;
285         RansEncSymbol esyms[NUM_SYMS];
286 };
287
288 static inline int quantize8x8(int f, int coeff_idx)
289 {
290         if (coeff_idx == 0) {
291                 return f / dc_scalefac;
292         }
293         if (f == 0) {
294                 return 0;
295         }
296
297         const int w = quant_8x8[coeff_idx];
298         const int s = quant_scalefac;
299         int sign_f = (f > 0) ? 1 : -1;
300         return (32 * f + sign_f * w * s) / (2 * w * s);
301 }
302
303 static inline int unquantize8x8(int qf, int coeff_idx)
304 {
305         if (coeff_idx == 0) {
306                 return qf * dc_scalefac;
307         }
308         if (qf == 0) {
309                 return 0;
310         }
311
312         const int w = quant_8x8[coeff_idx];
313         const int s = quant_scalefac;
314         return (2 * qf * w * s) / 32;
315 }
316
317 static inline int quantize4x4(int f, int coeff_idx)
318 {
319         if (coeff_idx == 0) {
320                 return f / (dc_scalefac/2);
321         }
322         if (f == 0) {
323                 return 0;
324         }
325
326         const int w = quant_4x4[coeff_idx];
327         const int s = quant_scalefac;
328         int sign_f = (f > 0) ? 1 : -1;
329         return (64 * f + sign_f * w * s) / (2 * w * s);
330 }
331
332 static inline int unquantize4x4(int qf, int coeff_idx)
333 {
334         if (coeff_idx == 0) {
335                 return qf * (dc_scalefac/2);
336         }
337         if (qf == 0) {
338                 return 0;
339         }
340
341         const int w = quant_4x4[coeff_idx];
342         const int s = quant_scalefac;
343         return (2 * qf * w * s) / 64;
344 }
345
// https://people.xiph.org/~xiphmont/demo/daala/demo3.shtml

// In-place integer 2x2 time-frequency switching butterfly on four coefficients
// laid out as
//   a b
//   c d
// Every intermediate result is held as a short. Divisions truncate toward zero.
static inline void tf_switch(short *a, short *b, short *c, short *d)
{
	const short a_in = *a;
	const short d_in = *d;

	const short diff_ab = a_in - *b;
	const short sum_cd = *c + d_in;
	const short half = (sum_cd - diff_ab) / 2;
	const short a_out = a_in + half;
	const short d_out = d_in - half;

	*a = a_out;
	*b = diff_ab - d_out;
	*c = a_out - sum_cd;
	*d = d_out;
}
358
// Forward second TF-switching stage: a chain of integer lifting steps over
// four coefficients (b, d, f, h), applied in place. Each adjacent pair gets
// "left += right / 2; right -= left / 2", going from (b, d) to (f, h).
// tf_switch_second_stage_inv() performs the exact inverse.
static inline void tf_switch_second_stage(short *b, short *d, short *f, short *h)
{
	short *const chain[4] = { b, d, f, h };
	for (int i = 0; i < 3; ++i) {
		short &left = *chain[i];
		short &right = *chain[i + 1];
		left += right / 2;
		right -= left / 2;
	}
}
368
// Inverse of tf_switch_second_stage(): undoes its lifting steps in reverse
// order. Going from (f, h) back to (b, d), each adjacent pair gets
// "right += left / 2; left -= right / 2". Operates in place.
static inline void tf_switch_second_stage_inv(short *b, short *d, short *f, short *h)
{
	short *const chain[4] = { b, d, f, h };
	for (int i = 3; i > 0; --i) {
		short &left = *chain[i - 1];
		short &right = *chain[i];
		right += left / 2;
		left -= right / 2;
	}
}
378
379 static void convert_8x8to4x4(short *c)
380 {
381         for (unsigned x = 0; x < 8; ++x) {
382                 tf_switch_second_stage_inv(&c[1 * 8 + x], &c[3 * 8 + x], &c[5 * 8 + x], &c[7 * 8 + x]);
383         }
384         for (unsigned y = 0; y < 8; ++y) {
385                 tf_switch_second_stage_inv(&c[y * 8 + 1], &c[y * 8 + 3], &c[y * 8 + 5], &c[y * 8 + 7]);
386         }
387         for (unsigned y = 0; y < 4; ++y) {
388                 for (unsigned x = 0; x < 4; ++x) {
389                         tf_switch(&c[(y*2) * 8 + x*2], &c[(y*2) * 8 + (x*2+1)], &c[(y*2+1)*8 + x*2], &c[(y*2+1)*8 + (x*2+1)]);
390                 }
391         }
392         short d[64] = {
393                 c[0*8 + 0], c[0*8 + 2], c[0*8 + 4], c[0*8 + 6], c[0*8 + 1], c[0*8 + 3], c[0*8 + 5], c[0*8 + 7],
394                 c[2*8 + 0], c[2*8 + 2], c[2*8 + 4], c[2*8 + 6], c[2*8 + 1], c[2*8 + 3], c[2*8 + 5], c[2*8 + 7],
395                 c[4*8 + 0], c[4*8 + 2], c[4*8 + 4], c[4*8 + 6], c[4*8 + 1], c[4*8 + 3], c[4*8 + 5], c[4*8 + 7],
396                 c[6*8 + 0], c[6*8 + 2], c[6*8 + 4], c[6*8 + 6], c[6*8 + 1], c[6*8 + 3], c[6*8 + 5], c[6*8 + 7],
397                 c[1*8 + 0], c[1*8 + 2], c[1*8 + 4], c[1*8 + 6], c[1*8 + 1], c[1*8 + 3], c[1*8 + 5], c[1*8 + 7],
398                 c[3*8 + 0], c[3*8 + 2], c[3*8 + 4], c[3*8 + 6], c[3*8 + 1], c[3*8 + 3], c[3*8 + 5], c[3*8 + 7],
399                 c[5*8 + 0], c[5*8 + 2], c[5*8 + 4], c[5*8 + 6], c[5*8 + 1], c[5*8 + 3], c[5*8 + 5], c[5*8 + 7],
400                 c[7*8 + 0], c[7*8 + 2], c[7*8 + 4], c[7*8 + 6], c[7*8 + 1], c[7*8 + 3], c[7*8 + 5], c[7*8 + 7]
401         };
402         memcpy(c, d, sizeof(d));
403 }
404
405 static void convert_4x4to8x8(short *c)
406 {
407         short d[64] = {
408                 c[0*8 + 0], c[0*8 + 4], c[0*8 + 1], c[0*8 + 5], c[0*8 + 2], c[0*8 + 6], c[0*8 + 3], c[0*8 + 7],
409                 c[4*8 + 0], c[4*8 + 4], c[4*8 + 1], c[4*8 + 5], c[4*8 + 2], c[4*8 + 6], c[4*8 + 3], c[4*8 + 7],
410                 c[1*8 + 0], c[1*8 + 4], c[1*8 + 1], c[1*8 + 5], c[1*8 + 2], c[1*8 + 6], c[1*8 + 3], c[1*8 + 7],
411                 c[5*8 + 0], c[5*8 + 4], c[5*8 + 1], c[5*8 + 5], c[5*8 + 2], c[5*8 + 6], c[5*8 + 3], c[5*8 + 7],
412                 c[2*8 + 0], c[2*8 + 4], c[2*8 + 1], c[2*8 + 5], c[2*8 + 2], c[2*8 + 6], c[2*8 + 3], c[2*8 + 7],
413                 c[6*8 + 0], c[6*8 + 4], c[6*8 + 1], c[6*8 + 5], c[6*8 + 2], c[6*8 + 6], c[6*8 + 3], c[6*8 + 7],
414                 c[3*8 + 0], c[3*8 + 4], c[3*8 + 1], c[3*8 + 5], c[3*8 + 2], c[3*8 + 6], c[3*8 + 3], c[3*8 + 7],
415                 c[7*8 + 0], c[7*8 + 4], c[7*8 + 1], c[7*8 + 5], c[7*8 + 2], c[7*8 + 6], c[7*8 + 3], c[7*8 + 7]
416         };
417
418         for (unsigned y = 0; y < 4; ++y) {
419                 for (unsigned x = 0; x < 4; ++x) {
420                         tf_switch(&d[(y*2) * 8 + x*2], &d[(y*2) * 8 + (x*2+1)], &d[(y*2+1)*8 + x*2], &d[(y*2+1)*8 + (x*2+1)]);
421                 }
422         }
423         for (unsigned y = 0; y < 8; ++y) {
424                 tf_switch_second_stage(&d[y * 8 + 1], &d[y * 8 + 3], &d[y * 8 + 5], &d[y * 8 + 7]);
425         }
426         for (unsigned x = 0; x < 8; ++x) {
427                 tf_switch_second_stage(&d[1 * 8 + x], &d[3 * 8 + x], &d[5 * 8 + x], &d[7 * 8 + x]);
428         }
429
430         memcpy(c, d, sizeof(d));
431 }
432
433 int main(int argc, char **argv)
434 {
435         if (argc >= 2) quant_scalefac = atof(argv[1]);
436         if (argc >= 3) lambda = atof(argv[2]);
437
438         FILE *fp = fopen("pic.pgm", "rb");
439         fread(pix, 1, WIDTH * HEIGHT, fp);
440         fclose(fp);
441
442         double sum_sq_err = 0.0;
443
444         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
445                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
446                         // Read one block
447                         short in[64], reconstructed8x8[64];
448                         short reconstructed4x4[64];
449                         for (unsigned y = 0; y < 8; ++y) {
450                                 for (unsigned x = 0; x < 8; ++x) {
451                                         in[y * 8 + x] = pix[(yb + y) * WIDTH + (xb + x)];
452                         //              in[y * 8 + x] = 128;
453                                 }
454                         }
455
456                         // FDCT it
457                         fdct_int32(in);
458
459                         // quant 8x8
460                         for (unsigned y = 0; y < 8; ++y) {
461                                 for (unsigned x = 0; x < 8; ++x) {
462                                         int coeff_idx = y * 8 + x;
463                                         int k = quantize8x8(in[coeff_idx], coeff_idx);
464                                         global_coeff8x8[(yb + y) * WIDTH + (xb + x)] = k;
465
466                                         // Store back for reconstruction / PSNR calculation
467                                         reconstructed8x8[coeff_idx] = unquantize8x8(k, coeff_idx);
468                                 }
469                         }
470
471 #if 0
472                         printf("before TF switch:\n");
473                         for (unsigned y = 0; y < 8; ++y) {
474                                 for (unsigned x = 0; x < 8; ++x) {
475                                         printf("%4d ", in[y * 8 + x]);
476                                 }
477                                 printf("\n");
478                         }
479                         convert_8x8to4x4(in);
480                         printf("after TF switch:\n");
481                         for (unsigned y = 0; y < 8; ++y) {
482                                 for (unsigned x = 0; x < 8; ++x) {
483                                         printf("%4d ", in[y * 8 + x]);
484                                 }
485                                 printf("\n");
486                         }
487                         convert_4x4to8x8(in);
488                         printf("after TF switch and back:\n");
489                         for (unsigned y = 0; y < 8; ++y) {
490                                 for (unsigned x = 0; x < 8; ++x) {
491                                         printf("%4d ", in[y * 8 + x]);
492                                 }
493                                 printf("\n");
494                         }
495 #endif
496
497                         // reconstruct 8x8
498                         idct_int32(reconstructed8x8);
499
500                         double sum_sq_err8x8 = 0.0;
501                         for (unsigned y = 0; y < 8; ++y) {
502                                 for (unsigned x = 0; x < 8; ++x) {
503                                         int k = reconstructed8x8[y * 8 + x];
504                                         if (k < 0) k = 0;
505                                         if (k > 255) k = 255;
506                                         uint8_t *ptr = &pix[(yb + y) * WIDTH + (xb + x)];
507                                         sum_sq_err8x8 += (*ptr - k) * (*ptr - k);
508                                         pix_8x8[(yb + y) * WIDTH + (xb + x)] = k;
509 //                                      *ptr = k;
510                                 }
511                         }
512                         sum_sq_err += sum_sq_err8x8;
513
514                         // now let's try 4x4
515                         convert_8x8to4x4(in);
516                         for (unsigned y = 0; y < 8; ++y) {
517                                 for (unsigned x = 0; x < 8; ++x) {
518                                         int coeff_idx = y * 8 + x;
519                                         int subcoeff_idx = (y%4) * 4 + (x%4);
520                                         int k = quantize4x4(in[coeff_idx], subcoeff_idx);
521                                         global_coeff4x4[(yb + y) * WIDTH + (xb + x)] = k;
522
523                                         // Store back for reconstruction / PSNR calculation
524                                         reconstructed4x4[coeff_idx] = unquantize4x4(k, subcoeff_idx);
525                                 }
526                         }
527
528                         // reconstruct 4x4
529                         convert_4x4to8x8(reconstructed4x4);
530                         idct_int32(reconstructed4x4);
531
532                         double sum_sq_err4x4 = 0.0;
533                         for (unsigned y = 0; y < 8; ++y) {
534                                 for (unsigned x = 0; x < 8; ++x) {
535                                         int k = reconstructed4x4[y * 8 + x];
536                                         if (k < 0) k = 0;
537                                         if (k > 255) k = 255;
538                                         uint8_t *ptr = &pix[(yb + y) * WIDTH + (xb + x)];
539                                         sum_sq_err4x4 += (*ptr - k) * (*ptr - k);
540                                         //*ptr = k;
541                                         pix_4x4[(yb + y) * WIDTH + (xb + x)] = k;
542                                 }
543                         }
544
545                         err_8x8[(yb/8) * (WIDTH/8) + (xb/8)] = sum_sq_err8x8;
546                         err_4x4[(yb/8) * (WIDTH/8) + (xb/8)] = sum_sq_err4x4;
547                         //printf("err 8x8 = %6.2f  err 4x4 = %6.2f  win = %d\n", sum_sq_err8x8, sum_sq_err4x4, sum_sq_err4x4 < sum_sq_err8x8);
548                         //sum_sq_err += sum_sq_err4x4;
549                 }
550         }
551         double mse = sum_sq_err / double(WIDTH * HEIGHT);
552         double psnr_db = 20 * log10(255.0 / sqrt(mse));
553         printf("psnr = %.2f dB\n", psnr_db);
554
555         // DC coefficient pred from the right to left
556         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
557                 for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
558                         global_coeff8x8[yb * WIDTH + xb] -= global_coeff8x8[yb * WIDTH + (xb + 8)];
559                 }
560         }
561         for (unsigned yb = 0; yb < HEIGHT; yb += 4) {
562                 for (unsigned xb = 0; xb < WIDTH - 4; xb += 4) {
563                         global_coeff4x4[yb * WIDTH + xb] -= global_coeff4x4[yb * WIDTH + (xb + 4)];
564                 }
565         }
566
567         // For each coefficient, make some tables.
568         size_t extra_bits = 0, sign_bits = 0;
569         for (unsigned i = 0; i < 64; ++i) {
570                 stats[i].clear();
571         }
572         for (unsigned y = 0; y < 8; ++y) {
573                 for (unsigned x = 0; x < 8; ++x) {
574                         SymbolStats &s = stats[pick_stats_for(x, y, false)];
575
576                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
577                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
578                                         short k = abs(global_coeff8x8[(yb + y) * WIDTH + (xb + x)]);
579                                         if (k >= ESCAPE_LIMIT) {
580                                                 //printf("coeff (%d,%d) had value %d\n", y, x, k);
581                                                 k = ESCAPE_LIMIT;
582                                                 extra_bits += 12;  // escape this one
583                                         }
584                                         //if (y != 0 || x != 0) ++sign_bits;
585                                         if (k != 0) ++sign_bits;
586                                         ++s.freqs[k];
587                                 }
588                         }
589                 }
590         }
591         for (unsigned y = 0; y < 4; ++y) {
592                 for (unsigned x = 0; x < 4; ++x) {
593                         SymbolStats &s = stats[pick_stats_for(x, y, true)];
594
595                         for (unsigned yb = 0; yb < HEIGHT; yb += 4) {
596                                 for (unsigned xb = 0; xb < WIDTH; xb += 4) {
597                                         short k = abs(global_coeff4x4[(yb + y) * WIDTH + (xb + x)]);
598                                         if (k >= ESCAPE_LIMIT) {
599                                                 k = ESCAPE_LIMIT;
600                                                 extra_bits += 12;  // escape this one
601                                         }
602                                         if (k != 0) ++sign_bits;
603                                         ++s.freqs[k];
604                                 }
605                         }
606                 }
607         }
608         for (unsigned i = 0; i < 64; ++i) {
609 #if 0
610                 printf("coeff %i:", i);
611                 for (unsigned j = 0; j <= ESCAPE_LIMIT; ++j) {
612                         printf(" %d", stats[i].freqs[j]);
613                 }
614                 printf("\n");
615 #endif
616                 stats[i].normalize_freqs(RansEncoder::prob_scale);
617         }
618
619         FILE *codedfp = fopen("coded.dat", "wb");
620         if (codedfp == nullptr) {
621                 perror("coded.dat");
622                 exit(1);
623         }
624
625         // TODO: varint or something on the freqs
626         for (unsigned i = 0; i < 64; ++i) {
627                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
628                         continue;
629                 }
630                 printf("writing table %d\n", i);
631 #if 0
632                 for (unsigned j = 0; j <= NUM_SYMS; ++j) {
633                         uint16_t freq = stats[i].cum_freqs[j];
634                         fwrite(&freq, 1, sizeof(freq), codedfp);
635                         printf("%d: %d\n", j, stats[i].freqs[j]);
636                 }
637 #else
638                 // TODO: rather gamma-k or something
639                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
640                         write_varint(stats[i].freqs[j], codedfp);
641                 }
642 #endif
643         }
644
645         RansEncoder rans_encoder_8x8, rans_encoder_4x4;
646
647         double total_bits_8x8 = 0.0, total_bits_4x4 = 0.0, total_bits_chosen = 0.0;
648         int num_chosen = 0, tot_chosen = 0;
649
650         size_t tot_bytes = 0;
651         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
652                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
653                         double bits_8x8 = 0.0, bits_4x4 = 0.0;
654
655                         //rans_encoder.init_prob(s1, s2);
656                         //rans_encoder.clear();
657                         size_t num_bytes = 0;
658                         for (unsigned y = 0; y < 8; ++y) {
659                                 for (unsigned x = 0; x < 8; ++x) {
660                                         SymbolStats &s1 = stats[pick_stats_for(x, y, false)];
661                                         SymbolStats &s2 = stats[pick_stats_for(x%4, y%4, true)];
662
663                                         int k8 = global_coeff8x8[(yb + y) * WIDTH + (xb + x)];
664                                         int k4 = global_coeff4x4[(yb + y) * WIDTH + (xb + x)];
665
666                                         if (k8 != 0) ++bits_8x8;  // sign bits
667                                         if (k4 != 0) ++bits_4x4;
668                                         k8 = abs(k8); k4 = abs(k4);
669         
670                                         if (k8 >= ESCAPE_LIMIT) { k8 = ESCAPE_LIMIT; bits_8x8 += 12.0; }
671                                         if (k4 >= ESCAPE_LIMIT) { k4 = ESCAPE_LIMIT; bits_4x4 += 12.0; }
672                                         
673                                         bits_8x8 -= log2(s1.freqs[k8] / 4096.0);
674                                         bits_4x4 -= log2(s2.freqs[k4] / 4096.0);
675                                 }
676 //                              if (yb % 16 == 8) {
677 //                                      num_bytes += rans_encoder.save_block(codedfp);
678 //                              }
679                         }
680 //                      if (HEIGHT % 16 != 0) {
681 //                              num_bytes += rans_encoder.save_block(codedfp);
682 //                      }
683                         tot_bytes += num_bytes;
684                         total_bits_8x8 += bits_8x8;
685                         total_bits_4x4 += bits_4x4;
686                         auto e8 = err_8x8[(yb/8)*(WIDTH/8) + (xb/8)];
687                         auto e4 = err_4x4[(yb/8)*(WIDTH/8) + (xb/8)];
688                         double rd8 = sqrt(e8) + lambda * bits_8x8;
689                         double rd4 = sqrt(e4) + lambda * bits_4x4;
690                         const unsigned char *spix = (rd4 < rd8) ? pix_4x4 : pix_8x8;
691                         unsigned char col = (rd4 < rd8) ? 255 : 0;
692                         total_bits_chosen += (rd4 < rd8) ? bits_4x4 : bits_8x8;
693                         num_chosen += (rd4 < rd8);
694                         ++tot_chosen;
695                         for (unsigned y = 0; y < 8; ++y) {
696                                 for (unsigned x = 0; x < 8; ++x) {
697                                         pix[(yb + y) * WIDTH + (xb + x)] = spix[(yb + y) * WIDTH + (xb + x)];
698                                         //pix[(yb + y) * WIDTH + (xb + x)] = col;
699                                 }
700                         }       
701                         printf("block (%d,%d): 8x8 %.2f bits [err=%.2f], 4x4 %.2f bits [err=%.2f], win_bits = %d, win_err = %d, win = %d\n",
702                                 yb, xb,
703                                 bits_8x8, sqrt(e8),
704                                 bits_4x4, sqrt(e4),
705                                 bits_4x4 < bits_8x8,
706                                 e4 < e8,
707                                 rd4 < rd8);
708                 }
709         }
710         printf("4x4: %.2f bits, 8x8: %.2f bits, chosen: %d/%d times, %.2f bits (%.0f bytes)\n", total_bits_4x4, total_bits_8x8,
711                 num_chosen, tot_chosen, total_bits_chosen, total_bits_chosen / 8.0);
712         printf("%ld bytes + %ld sign bits (%ld) + %ld escape bits (%ld) = %ld total bytes\n",
713                 tot_bytes - sign_bits / 8 - extra_bits / 8,
714                 sign_bits,
715                 sign_bits / 8,
716                 extra_bits,
717                 extra_bits / 8,
718                 tot_bytes);
719
720         fp = fopen("reconstructed.pgm", "wb");
721         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
722         fwrite(pix, 1, WIDTH * HEIGHT, fp);
723         fclose(fp);
724
725 }