]> git.sesse.net Git - narabu/blob - qdc.cpp
Add color support.
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <assert.h>
5 #include <math.h>
6
7 //#include "ryg_rans/rans64.h"
8 #include "ryg_rans/rans_byte.h"
9 #include "ryg_rans/renormalize.h"
10
11 #include <memory>
12
// Fixed frame geometry: readpix() loads exactly WIDTH * HEIGHT RGB triplets,
// and every pixel/coefficient buffer below is sized from these.
13 #define WIDTH 1280
14 #define HEIGHT 720
// Alphabet size of the entropy coder (one SymbolStats entry per symbol).
15 #define NUM_SYMS 256
// Coefficient magnitudes >= ESCAPE_LIMIT are sent as this escape symbol plus
// the raw magnitude (see RansEncoder::encode_coeff).
16 #define ESCAPE_LIMIT (NUM_SYMS - 1)
17
// Symbol probabilities are quantized to prob_bits bits, so normalized
// frequency tables sum to prob_scale (4096).
18 static constexpr uint32_t prob_bits = 12;
19 static constexpr uint32_t prob_scale = 1 << prob_bits;
20
21 using namespace std;
22
// 8x8 integer DCT / IDCT, operating in place on 64 shorts; defined outside
// this file (presumably in another translation unit — not shown here).
23 void fdct_int32(short *const In);
24 void idct_int32(short *const In);
25
// Interleaved 8-bit RGB input frame.
26 unsigned char rgb[WIDTH * HEIGHT * 3];
// Full-resolution luma; overwritten in place with the reconstruction by main().
27 unsigned char pix_y[WIDTH * HEIGHT];
// Horizontally 2:1 subsampled (4:2:2) chroma planes.
28 unsigned char pix_cb[(WIDTH/2) * HEIGHT];
29 unsigned char pix_cr[(WIDTH/2) * HEIGHT];
// Chroma before subsampling, one sample per pixel.
30 unsigned char full_pix_cb[WIDTH * HEIGHT];
31 unsigned char full_pix_cr[WIDTH * HEIGHT];
// Quantized DCT coefficients, stored in the same 8x8 block layout as the
// corresponding pixel plane (block row y, column x at (yb + y) * stride + xb + x).
32 short coeff_y[WIDTH * HEIGHT], coeff_cb[(WIDTH/2) * HEIGHT], coeff_cr[(WIDTH/2) * HEIGHT];
33
// Saturates x to the 8-bit sample range [0, 255].
int clamp(int x)
{
	if (x > 255)
		return 255;
	return (x < 0) ? 0 : x;
}
40
// 8x8 quantization weights, indexed row-major (coeff_idx = y * 8 + x).
//
// Despite the name, the active (#else) branch is the MPEG-1 default intra
// matrix, and quantize()/unquantize() use it for chroma as well as luma.
// The disabled branch holds the JPEG standard luminance table.
// Entry 0 (DC) is never read: quantize()/unquantize() scale DC by
// dc_scalefac instead.
41 static const unsigned char std_luminance_quant_tbl[64] = {
42 #if 0
43         16,  11,  10,  16,  24,  40,  51,  61,
44         12,  12,  14,  19,  26,  58,  60,  55,
45         14,  13,  16,  24,  40,  57,  69,  56,
46         14,  17,  22,  29,  51,  87,  80,  62,
47         18,  22,  37,  56,  68, 109, 103,  77,
48         24,  35,  55,  64,  81, 104, 113,  92,
49         49,  64,  78,  87, 103, 121, 120, 101,
50         72,  92,  95,  98, 112, 100, 103,  99
51 #else
52         // ff_mpeg1_default_intra_matrix
53          8, 16, 19, 22, 26, 27, 29, 34,
54         16, 16, 22, 24, 27, 29, 34, 37,
55         19, 22, 26, 27, 29, 34, 34, 38,
56         22, 22, 26, 27, 29, 34, 37, 40,
57         22, 26, 27, 29, 32, 35, 40, 48,
58         26, 27, 29, 32, 35, 40, 48, 58,
59         26, 27, 29, 34, 38, 46, 56, 69,
60         27, 29, 35, 38, 46, 56, 69, 83
61 #endif
62 };
63
// Symbol histogram for one coding context, plus its cumulative table.
64 struct SymbolStats
65 {
    // Count (or, after normalize_freqs(), scaled frequency) of each symbol.
66     uint32_t freqs[NUM_SYMS];
    // Prefix sums of freqs: cum_freqs[0] == 0 and cum_freqs[i + 1] is the
    // sum of freqs[0..i]; after normalize_freqs() the last entry equals the
    // requested target total.
67     uint32_t cum_freqs[NUM_SYMS + 1];
68
    // Zeroes freqs (cum_freqs is not touched).
69     void clear();
    // Replaces freqs with a histogram of the nbytes symbols at in.
70     void count_freqs(uint8_t const* in, size_t nbytes);
    // Recomputes cum_freqs from freqs.
71     void calc_cum_freqs();
    // Rescales freqs/cum_freqs to sum to target_total, keeping every
    // occurring symbol nonzero.
72     void normalize_freqs(uint32_t target_total);
73 };
74
75 void SymbolStats::clear()
76 {
77     for (int i=0; i < NUM_SYMS; i++)
78         freqs[i] = 0;
79 }
80
81 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
82 {
83     clear();
84
85     for (size_t i=0; i < nbytes; i++)
86         freqs[in[i]]++;
87 }
88
89 void SymbolStats::calc_cum_freqs()
90 {
91     cum_freqs[0] = 0;
92     for (int i=0; i < NUM_SYMS; i++)
93         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
94 }
95
// Memo table for find_optimal_cost(): cache[sym_to][slots] holds the best
// cost found for that subproblem, or a negative value if not computed yet
// (the two-argument find_optimal_cost() resets every entry to -1.0).
96 static double cache[NUM_SYMS + 1][prob_scale + 1];
// log2cache[k] = -log2(k / prob_scale): bits spent per occurrence of a symbol
// that was given k of the prob_scale slots (+inf for k == 0).
97 static double log2cache[prob_scale + 1];
// Number of subproblems solved in the current search (reported by the
// two-argument find_optimal_cost()).
98 static int64_t cachefill = 0;
99
100 double find_optimal_cost(const uint32_t *cum_freqs, int sym_to, int available_slots)
101 {
102         assert(sym_to >= 0);
103
104         while (sym_to > 0 && cum_freqs[sym_to] == cum_freqs[sym_to - 1]) { --sym_to; }
105         if (cache[sym_to][available_slots] >= 0.0) {
106                 //printf("CACHE: %d,%d\n", sym_to, available_slots);
107                 return cache[sym_to][available_slots];
108         }
109         if (sym_to == 0) {
110                 return 0.0;
111         }
112         if (sym_to == 1) {
113                 return cum_freqs[0] * log2cache[available_slots];
114         }
115         if (available_slots == 1) {
116                 return cum_freqs[0] * log2cache[1] + find_optimal_cost(cum_freqs, sym_to - 1, 0);
117         }
118
119 //      printf("UNCACHE: %d,%d\n", sym_to, available_slots);
120 #if 0
121         // ok, test all possible options for the last symbol (TODO: save the choice)
122         double best_so_far = HUGE_VAL;
123         //for (int i = num_syms - 1; i < available_slots; ++i) {
124         double f = freqs[sym_to - 1];
125         for (int i = available_slots; i --> 0; ) {
126                 double cost1 = f * log2cache[available_slots - i];
127                 double cost2 = find_optimal_cost(freqs, sym_to - 1, i);
128
129                 if (sym_to == 3 && available_slots == 838) {
130                         printf("%d %f\n", i, cost1 + cost2);
131                 } else
132                 if (cost1 + cost2 > best_so_far) {
133                         break;
134                 }
135                 best_so_far = cost1 + cost2;
136         }
137 #elif 1
138         // Minimize the number of total bits spent as a function of how many slots
139         // we assign to this symbol.
140         //
141         // The cost function is convex (I don't know how to prove it, but it makes
142         // intuitively a lot of sense). Find a reasonable guess and see what way
143         // we should search, then iterate until we either hit the end or we start
144         // increasing again.
145         double f = cum_freqs[sym_to - 1] - cum_freqs[sym_to - 2];
146         double start = lrint(available_slots * f / cum_freqs[sym_to - 1]);
147
148         int x1 = std::max<int>(floor(start), 1);
149         int x2 = x1 + 1;
150
151         double f1 = f * log2cache[x1] + find_optimal_cost(cum_freqs, sym_to - 1, available_slots - x1);
152         double f2 = f * log2cache[x2] + find_optimal_cost(cum_freqs, sym_to - 1, available_slots - x2);
153
154         int x, direction;  // -1 or +1
155         double best_so_far = std::min(f1, f2);
156         if (isinf(f1) && isinf(f2)) {
157                 // The cost isn't infinite due to the first term, so we need to go downwards
158                 // to give the second term more room to breathe.
159                 x = x1;
160                 direction = -1;
161         } else if (f1 < f2) {
162                 x = x1;
163                 direction = -1;
164         } else {
165                 x = x2;
166                 direction = 1;
167         }
168
169         //printf("[%d,%d] freq=%ld cumfreq=%d From %d and %d, chose %d [%f] and direction=%d\n",
170         //      sym_to, available_slots, freqs[sym_to - 1], cum_freqs[sym_to - 1], x1, x2, x, best_so_far, direction);
171
172         while ((x + direction) > 0 && (x + direction) <= available_slots) {
173                 x += direction;
174                 double fn = f * log2cache[x] + find_optimal_cost(cum_freqs, sym_to - 1, available_slots - x);
175         //      printf("[%d,%d] %d is %f\n", sym_to, available_slots, x, fn);
176                 if (fn > best_so_far) {
177                         break;
178                 }
179                 best_so_far = fn;
180         }
181 #endif
182         if (++cachefill % 131072 == 0) {
183         //      printf("%d,%d = %f (cachefill = %.2f%%)\n", sym_to, available_slots, best_so_far,
184         //              100.0 * (cachefill / double((NUM_SYMS + 1) * (prob_scale + 1))));
185         }
186         assert(best_so_far >= 0.0);
187         assert(cache[sym_to][available_slots] < 0.0);
188         cache[sym_to][available_slots] = best_so_far;
189         return best_so_far;
190 }
191
192 double find_optimal_cost(const uint32_t *cum_freqs, const uint64_t *freqs)
193 {
194         for (int j = 0; j <= NUM_SYMS; ++j) {
195                 for (unsigned k = 0; k <= prob_scale; ++k) {
196                         cache[j][k] = -1.0;
197                 }
198         }
199         for (unsigned k = 0; k <= prob_scale; ++k) {
200                 log2cache[k] = -log2(k * (1.0 / prob_scale));
201                 //printf("log2cache[%d] = %f\n", k, log2cache[k]);
202         }
203         cachefill = 0;
204         double ret = find_optimal_cost(cum_freqs, NUM_SYMS, prob_scale);
205         printf("Used %ld function invocations\n", cachefill);
206         return ret;
207 }
208
209 void SymbolStats::normalize_freqs(uint32_t target_total)
210 {
211     uint64_t real_freq[NUM_SYMS + 1];  // hack
212
213     assert(target_total >= NUM_SYMS);
214
215     calc_cum_freqs();
216     uint32_t cur_total = cum_freqs[NUM_SYMS];
217
218     if (cur_total == 0) return;
219
220     double ideal_cost = 0.0;
221     for (int i = 1; i <= NUM_SYMS; i++)
222     {
223       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
224       if (real_freq[i] > 0)
225         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
226     }
227
228     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
229
230 #if 0
231     double optimal_cost = find_optimal_cost(cum_freqs + 1, real_freq + 1);
232
233     // resample distribution based on cumulative freqs
234     for (int i = 1; i <= NUM_SYMS; i++)
235         //cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
236         cum_freqs[i] = lrint(cum_freqs[i] * double(target_total) / cur_total);
237
238     // if we nuked any non-0 frequency symbol to 0, we need to steal
239     // the range to make the frequency nonzero from elsewhere.
240     //
241     // this is not at all optimal, i'm just doing the first thing that comes to mind.
242     for (int i=0; i < NUM_SYMS; i++) {
243         if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
244             // symbol i was set to zero freq
245
246             // find best symbol to steal frequency from (try to steal from low-freq ones)
247             uint32_t best_freq = ~0u;
248             int best_steal = -1;
249             for (int j=0; j < NUM_SYMS; j++) {
250                 uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
251                 if (freq > 1 && freq < best_freq) {
252                     best_freq = freq;
253                     best_steal = j;
254                 }
255             }
256             assert(best_steal != -1);
257
258             // and steal from it!
259             if (best_steal < i) {
260                 for (int j = best_steal + 1; j <= i; j++)
261                     cum_freqs[j]--;
262             } else {
263                 assert(best_steal > i);
264                 for (int j = i + 1; j <= best_steal; j++)
265                     cum_freqs[j]++;
266             }
267         }
268     }
269 #endif
270
271     // calculate updated freqs and make sure we didn't screw anything up
272     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
273     for (int i=0; i < NUM_SYMS; i++) {
274         if (freqs[i] == 0)
275             assert(cum_freqs[i+1] == cum_freqs[i]);
276         else
277             assert(cum_freqs[i+1] > cum_freqs[i]);
278
279         // calc updated freq
280         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
281     }
282
283     double calc_cost = 0.0;
284     for (int i = 1; i <= NUM_SYMS; i++)
285     {
286       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
287       if (real_freq[i] > 0)
288         calc_cost -= real_freq[i] * log2(freq / double(target_total));
289     }
290
291     static double total_loss = 0.0;
292     total_loss += calc_cost - ideal_cost;
293     static double total_loss_with_dp = 0.0;
294         double optimal_cost = 0.0;
295     //total_loss_with_dp += optimal_cost - ideal_cost;
296     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
297                 ideal_cost, optimal_cost,
298                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
299 }
300
// One symbol histogram per coding context. pick_stats_for() currently only
// returns contexts 0..7; the remaining entries are reserved for other
// context mappings (e.g. the commented-out one-per-coefficient variant).
301 SymbolStats stats[64];
302
// Chooses the coding context for coefficient (y, x) of an 8x8 block: its
// anti-diagonal index x + y, saturated at 7 so all higher frequencies share
// the last context.
int pick_stats_for(int y, int x)
{
	const int diagonal = x + y;
	return (diagonal < 7) ? diagonal : 7;
}
318                 
319
// Writes x to fp as a little-endian base-128 varint: 7 payload bits per byte,
// least-significant group first, with the high bit set on every byte except
// the last.
//
// x is encoded as its 32-bit unsigned bit pattern. Callers pass uint32_t
// byte/bit counts through this int parameter. Reinterpreting the value lets
// counts above INT_MAX (which arrive negative) round-trip, instead of being
// truncated to a single byte as the old signed loop did.
void write_varint(int x, FILE *fp)
{
	uint32_t v = static_cast<uint32_t>(x);
	while (v >= 0x80) {
		putc(static_cast<int>((v & 0x7f) | 0x80), fp);
		v >>= 7;
	}
	putc(static_cast<int>(v), fp);
}
328
329 class RansEncoder {
330 public:
331         RansEncoder()
332         {
333                 out_buf.reset(new uint8_t[out_max_size]);
334                 sign_buf.reset(new uint8_t[max_num_sign]);
335                 clear();
336         }
337
338         void init_prob(const SymbolStats &s)
339         {
340                 for (int i = 0; i < NUM_SYMS; i++) {
341                         printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits);
342                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits);
343                 }
344         }
345
346         void clear()
347         {
348                 out_end = out_buf.get() + out_max_size;
349                 sign_end = sign_buf.get() + max_num_sign;
350                 ptr = out_end; // *end* of output buffer
351                 sign_ptr = sign_end; // *end* of output buffer
352                 RansEncInit(&rans);
353                 free_sign_bits = 0;
354         }
355
356         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
357         {
358                 RansEncFlush(&rans, &ptr);
359                 //printf("post-flush = %08x\n", rans);
360
361                 uint32_t num_rans_bytes = out_end - ptr;
362                 write_varint(num_rans_bytes, codedfp);
363                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
364                 fwrite(ptr, 1, num_rans_bytes, codedfp);
365
366                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
367
368                 if (free_sign_bits > 0) {
369                         *sign_ptr <<= free_sign_bits;
370                 }
371
372 #if 1
373                 uint32_t num_sign_bytes = sign_end - sign_ptr;
374                 write_varint((num_sign_bytes << 3) | free_sign_bits, codedfp);
375                 fwrite(sign_ptr, 1, num_sign_bytes, codedfp);
376 #endif
377
378                 clear();
379
380                 printf("Saving block: %d rANS bytes, %d sign bytes\n", num_rans_bytes, num_sign_bytes);
381                 return num_rans_bytes + num_sign_bytes;
382                 //return num_rans_bytes;
383         }
384
385         void encode_coeff(short signed_k)
386         {
387                 printf("encoding coeff %d\n", signed_k);
388                 short k = abs(signed_k);
389                 if (k >= ESCAPE_LIMIT) {
390                         // Put the coefficient as a 1/(2^12) symbol _before_
391                         // the 255 coefficient, since the decoder will read the
392                         // 255 coefficient first.
393                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
394                         k = ESCAPE_LIMIT;
395                 }
396                 if (k != 0) {
397 #if 1
398                         if (free_sign_bits == 0) {
399                                 --sign_ptr;
400                                 *sign_ptr = 0;
401                                 free_sign_bits = 8;
402                         }
403                         *sign_ptr <<= 1;
404                         *sign_ptr |= (signed_k < 0);
405                         --free_sign_bits;
406 #else
407                         RansEncPut(&rans, &ptr, (k < 0) ? prob_scale / 2 : 0, prob_scale / 2, prob_bits);
408 #endif
409                 }
410                 RansEncPutSymbol(&rans, &ptr, &esyms[k]);
411         }
412
413 private:
414         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
415         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
416
417         unique_ptr<uint8_t[]> out_buf, sign_buf;
418         uint8_t *out_end, *sign_end;
419         uint8_t *ptr, *sign_ptr;
420         RansState rans;
421         size_t free_sign_bits;
422         RansEncSymbol esyms[NUM_SYMS];
423 };
424
// Scale factors shared by quantize() and unquantize().
425 static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
426 static constexpr double quant_scalefac = 4.0;  // Global multiplier on the AC quant table (larger = coarser). NOTE: quantize()/unquantize() convert it to int, so any fractional part is truncated.
427
428 static inline int quantize(int f, int coeff_idx)
429 {
430         if (coeff_idx == 0) {
431                 return f / dc_scalefac;
432         }
433         if (f == 0) {
434                 return 0;
435         }
436
437         const int w = std_luminance_quant_tbl[coeff_idx];
438         const int s = quant_scalefac;
439         int sign_f = (f > 0) ? 1 : -1;
440         return (32 * f + sign_f * w * s) / (2 * w * s);
441 }
442
443 static inline int unquantize(int qf, int coeff_idx)
444 {
445         if (coeff_idx == 0) {
446                 return qf * dc_scalefac;
447         }
448         if (qf == 0) {
449                 return 0;
450         }
451
452         const int w = std_luminance_quant_tbl[coeff_idx];
453         const int s = quant_scalefac;
454         return (2 * qf * w * s) / 32;
455 }
456
457 void readpix(unsigned char *ptr, const char *filename)
458 {
459         FILE *fp = fopen(filename, "rb");
460         if (fp == nullptr) {
461                 perror(filename);
462                 exit(1);
463         }
464
465         fseek(fp, 0, SEEK_END);
466         long len = ftell(fp);
467         assert(len >= WIDTH * HEIGHT * 3);
468         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
469
470         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
471         fclose(fp);
472 }
473
474 void convert_ycbcr()
475 {
476         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
477         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
478         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
479
480         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
481         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
482         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
483                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
484                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
485                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
486                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
487                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
488                         double cb = (b - y) * cb_fac + 128.0;
489                         double cr = (r - y) * cr_fac + 128.0;
490                         pix_y[(yb * WIDTH) + xb] = lrint(y);
491                         temp_cb[(yb * WIDTH) + xb] = cb;
492                         temp_cr[(yb * WIDTH) + xb] = cr;
493                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
494                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
495                 }
496         }
497
498         // Simple 4:2:2 subsampling with left convention.
499         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
500                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
501                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
502                         int c1 = yb * WIDTH + xb * 2;
503                         int c2 = yb * WIDTH + xb * 2 + 1;
504                         
505                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
506                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
507                         cb = std::min(std::max(cb, 0.0), 255.0);
508                         cr = std::min(std::max(cr, 0.0), 255.0);
509                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
510                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
511                 }
512         }
513 }
514
515 int main(int argc, char **argv)
516 {
517         if (argc >= 2)
518                 readpix(rgb, argv[1]);
519         else
520                 readpix(rgb, "color.pnm");
521         convert_ycbcr();
522
523         double sum_sq_err = 0.0;
524         //double last_cb_cfl_fac = 0.0;
525         //double last_cr_cfl_fac = 0.0;
526
527         // DCT and quantize luma
528         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
529                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
530                         // Read one block
531                         short in_y[64];
532                         for (unsigned y = 0; y < 8; ++y) {
533                                 for (unsigned x = 0; x < 8; ++x) {
534                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
535                                 }
536                         }
537
538                         // FDCT it
539                         fdct_int32(in_y);
540
541                         for (unsigned y = 0; y < 8; ++y) {
542                                 for (unsigned x = 0; x < 8; ++x) {
543                                         int coeff_idx = y * 8 + x;
544                                         int k = quantize(in_y[coeff_idx], coeff_idx);
545                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
546
547                                         // Store back for reconstruction / PSNR calculation
548                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
549                                 }
550                         }
551
552                         idct_int32(in_y);
553
554                         for (unsigned y = 0; y < 8; ++y) {
555                                 for (unsigned x = 0; x < 8; ++x) {
556                                         int k = clamp(in_y[y * 8 + x]);
557                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
558                                         sum_sq_err += (*ptr - k) * (*ptr - k);
559                                         *ptr = k;
560                                 }
561                         }
562                 }
563         }
564         double mse = sum_sq_err / double(WIDTH * HEIGHT);
565         double psnr_db = 20 * log10(255.0 / sqrt(mse));
566         printf("psnr = %.2f dB\n", psnr_db);
567
568         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
569
570         // DCT and quantize chroma
571         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
572         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
573                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
574 #if 0
575                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
576                         printf("in blocks:\n");
577                         for (unsigned y = 0; y < 8; ++y) {
578                                 for (unsigned x = 0; x < 8; ++x) {
579                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
580                                         printf(" %4d", a);
581                                 }
582                                 printf(" | ");
583                                 for (unsigned x = 0; x < 8; ++x) {
584                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
585                                         printf(" %4d", b);
586                                 }
587                                 printf("\n");
588                         }
589
590                         short in_y[64];
591                         for (unsigned y = 0; y < 8; ++y) {
592                                 for (unsigned x = 0; x < 4; ++x) {
593                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
594                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
595                                         b = a - b;
596                                         a = 2 * a - b;
597                                         in_y[y * 8 + x * 2 + 0] = a;
598                                         in_y[y * 8 + x * 2 + 1] = b;
599                                 }
600                         }
601
602                         printf("tf-ed block:\n");
603                         for (unsigned y = 0; y < 8; ++y) {
604                                 for (unsigned x = 0; x < 8; ++x) {
605                                         short a = in_y[y * 8 + x];
606                                         printf(" %4d", a);
607                                 }
608                                 printf("\n");
609                         }
610 #else
611                         // Read Y block with no tf switch (from reconstructed luma)
612                         short in_y[64];
613                         for (unsigned y = 0; y < 8; ++y) {
614                                 for (unsigned x = 0; x < 8; ++x) {
615                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
616                                 }
617                         }
618                         fdct_int32(in_y);
619 #endif
620
621                         // Read one block
622                         short in_cb[64], in_cr[64];
623                         for (unsigned y = 0; y < 8; ++y) {
624                                 for (unsigned x = 0; x < 8; ++x) {
625                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
626                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
627                                 }
628                         }
629
630                         // FDCT it
631                         fdct_int32(in_cb);
632                         fdct_int32(in_cr);
633
634 #if 0
635                         // Chroma from luma
636                         double x0 = in_y[1];
637                         double x1 = in_y[8];
638                         double x2 = in_y[9];
639                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
640                         //double denom = (x1 * x1);
641         
642                         double y0 = in_cb[1];
643                         double y1 = in_cb[8];
644                         double y2 = in_cb[9];
645                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
646                         //double cb_cfl_fac = (x1 * y1) / denom;
647
648                         for (unsigned y = 0; y < 8; ++y) {
649                                 for (unsigned x = 0; x < 8; ++x) {
650                                         short a = in_y[y * 8 + x];
651                                         printf(" %4d", a);
652                                 }
653                                 printf(" | ");
654                                 for (unsigned x = 0; x < 8; ++x) {
655                                         short a = in_cb[y * 8 + x];
656                                         printf(" %4d", a);
657                                 }
658                                 printf("\n");
659                         }
660                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
661                                 in_y[1], in_y[8], in_y[9], 
662                                 in_cb[1], in_cb[8], in_cb[9],
663                                 cb_cfl_fac);
664
665                         y0 = in_cr[1];
666                         y1 = in_cr[8];
667                         y2 = in_cr[9];
668                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
669                         //double cr_cfl_fac = (x1 * y1) / denom;
670                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
671                                 cb_cfl_fac, in_cb[0] - in_y[0],
672                                 cr_cfl_fac, in_cr[0] - in_y[0]);
673
674                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
675
676                         // CHEAT
677                         //last_cb_cfl_fac = cb_cfl_fac;
678                         //last_cr_cfl_fac = cr_cfl_fac;
679
680                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
681                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
682                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
683                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
684                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
685                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
686
687                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
688                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
689                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
690                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
691                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
692                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
693                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
694                         }
695                         //in_cb[0] += 1024;
696                         //in_cr[0] += 1024;
697                         //in_cb[0] -= in_y[0];
698                         //in_cr[0] -= in_y[0];
699 #endif
700
701                         for (unsigned y = 0; y < 8; ++y) {
702                                 for (unsigned x = 0; x < 8; ++x) {
703                                         int coeff_idx = y * 8 + x;
704                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
705                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
706                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
707                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
708
709                                         // Store back for reconstruction / PSNR calculation
710                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
711                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
712                                 }
713                         }
714
715                         idct_int32(in_y);  // DEBUG
716                         idct_int32(in_cb);
717                         idct_int32(in_cr);
718
719                         for (unsigned y = 0; y < 8; ++y) {
720                                 for (unsigned x = 0; x < 8; ++x) {
721                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
722                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
723
724                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
725                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
726                                 }
727                         }
728
729 #if 0
730                         last_cb_cfl_fac = cb_cfl_fac;
731                         last_cr_cfl_fac = cr_cfl_fac;
732 #endif
733                 }
734         }
735
736 #if 0
737         printf("chroma_energy = %f, with_pred = %f\n",
738                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
739 #endif
740
741         // DC coefficient pred from the right to left
742         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
743                 for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
744                         coeff_y[yb * WIDTH + xb] -= coeff_y[yb * WIDTH + (xb + 8)];
745                 }
746         }
747
748         FILE *fp = fopen("reconstructed.pgm", "wb");
749         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
750         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
751         fclose(fp);
752
753         fp = fopen("reconstructed.pnm", "wb");
754         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
755         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
756                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
757                         int y = pix_y[(yb * WIDTH) + xb];
758                         int cb, cr;
759                         int c0 = yb * (WIDTH/2) + xb/2;
760                         if (xb % 2 == 0) {
761                                 cb = pix_cb[c0] - 128.0;
762                                 cr = pix_cr[c0] - 128.0;
763                         } else {
764                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
765                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
766                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
767                         }
768
769                         double r = y + 1.5748 * cr;
770                         double g = y - 0.1873 * cb - 0.4681 * cr;
771                         double b = y + 1.8556 * cb;
772
773                         putc(clamp(lrint(r)), fp);
774                         putc(clamp(lrint(g)), fp);
775                         putc(clamp(lrint(b)), fp);
776                 }
777         }
778         fclose(fp);
779
780         // For each coefficient, make some tables.
781         size_t extra_bits = 0, sign_bits = 0;
782         for (unsigned i = 0; i < 64; ++i) {
783                 stats[i].clear();
784         }
785         for (unsigned y = 0; y < 8; ++y) {
786                 for (unsigned x = 0; x < 8; ++x) {
787                         SymbolStats &s_luma = stats[pick_stats_for(x, y)];
788                         SymbolStats &s_chroma = stats[pick_stats_for(x, y) + 8];  // HACK
789                         //SymbolStats &s_chroma = stats[pick_stats_for(x, y)];
790
791                         // Luma
792                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
793                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
794                                         short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
795                                         if (k >= ESCAPE_LIMIT) {
796                                                 k = ESCAPE_LIMIT;
797                                                 extra_bits += 12;  // escape this one
798                                         }
799                                         if (k != 0) ++sign_bits;
800                                         ++s_luma.freqs[k];
801                                 }
802                         }
803                         // Chroma
804                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
805                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
806                                         short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
807                                         short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
808                                         if (k_cb >= ESCAPE_LIMIT) {
809                                                 k_cb = ESCAPE_LIMIT;
810                                                 extra_bits += 12;  // escape this one
811                                         }
812                                         if (k_cr >= ESCAPE_LIMIT) {
813                                                 k_cr = ESCAPE_LIMIT;
814                                                 extra_bits += 12;  // escape this one
815                                         }
816                                         if (k_cb != 0) ++sign_bits;
817                                         if (k_cr != 0) ++sign_bits;
818                                         ++s_chroma.freqs[k_cb];
819                                         ++s_chroma.freqs[k_cr];
820                                 }
821                         }
822                 }
823         }
824         for (unsigned i = 0; i < 64; ++i) {
825                 stats[i].normalize_freqs(prob_scale);
826         }
827
828         FILE *codedfp = fopen("coded.dat", "wb");
829         if (codedfp == nullptr) {
830                 perror("coded.dat");
831                 exit(1);
832         }
833
834         // TODO: rather gamma-k or something
835         for (unsigned i = 0; i < 64; ++i) {
836                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
837                         continue;
838                 }
839                 printf("writing table %d\n", i);
840                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
841                         write_varint(stats[i].freqs[j], codedfp);
842                 }
843         }
844
845         RansEncoder rans_encoder;
846
847         size_t tot_bytes = 0;
848
849         // Luma
850         for (unsigned y = 0; y < 8; ++y) {
851                 for (unsigned x = 0; x < 8; ++x) {
852                         SymbolStats &s_luma = stats[pick_stats_for(x, y)];
853                         rans_encoder.init_prob(s_luma);
854
855                         // Luma
856
857                         // need to reverse later
858                         rans_encoder.clear();
859                         size_t num_bytes = 0;
860                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
861                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
862                                         int k = coeff_y[(yb + y) * WIDTH + (xb + x)];
863                                         //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
864                                         rans_encoder.encode_coeff(k);
865                                 }
866                                 if (yb % 16 == 8) {
867                                         num_bytes += rans_encoder.save_block(codedfp);
868                                 }
869                         }
870                         if (HEIGHT % 16 != 0) {
871                                 num_bytes += rans_encoder.save_block(codedfp);
872                         }
873                         tot_bytes += num_bytes;
874                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
875                 }
876         }
877
878         // Cb
879         for (unsigned y = 0; y < 8; ++y) {
880                 for (unsigned x = 0; x < 8; ++x) {
881                         SymbolStats &s_chroma = stats[pick_stats_for(x, y) + 8];
882                         //SymbolStats &s_chroma = stats[pick_stats_for(x, y)];
883                         rans_encoder.init_prob(s_chroma);
884
885                         rans_encoder.clear();
886                         size_t num_bytes = 0;
887                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
888                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
889                                         int k = coeff_cb[(yb + y) * WIDTH/2 + (xb + x)];
890                                         rans_encoder.encode_coeff(k);
891                                 }
892                                 if (yb % 16 == 8) {
893                                         num_bytes += rans_encoder.save_block(codedfp);
894                                 }
895                         }
896                         if (HEIGHT % 16 != 0) {
897                                 num_bytes += rans_encoder.save_block(codedfp);
898                         }
899                         tot_bytes += num_bytes;
900                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
901                 }
902         }
903
904         // Cr
905         for (unsigned y = 0; y < 8; ++y) {
906                 for (unsigned x = 0; x < 8; ++x) {
907                         SymbolStats &s_chroma = stats[pick_stats_for(x, y) + 8];
908                         //SymbolStats &s_chroma = stats[pick_stats_for(x, y)];
909                         rans_encoder.init_prob(s_chroma);
910
911                         rans_encoder.clear();
912                         size_t num_bytes = 0;
913                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
914                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
915                                         int k = coeff_cr[(yb + y) * WIDTH/2 + (xb + x)];
916                                         rans_encoder.encode_coeff(k);
917                                 }
918                                 if (yb % 16 == 8) {
919                                         num_bytes += rans_encoder.save_block(codedfp);
920                                 }
921                         }
922                         if (HEIGHT % 16 != 0) {
923                                 num_bytes += rans_encoder.save_block(codedfp);
924                         }
925                         tot_bytes += num_bytes;
926                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
927                 }
928         }
929
930         printf("%ld bytes + %ld sign bits (%ld) + %ld escape bits (%ld) = %ld total bytes\n",
931                 tot_bytes - sign_bits / 8 - extra_bits / 8,
932                 sign_bits,
933                 sign_bits / 8,
934                 extra_bits,
935                 extra_bits / 8,
936                 tot_bytes);
937 }