]> git.sesse.net Git - narabu/blob - qdc.cpp
k-means instead of k-medoids; doesn't work as well, so just keep it here to be immedi...
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 //#include "ryg_rans/rans64.h"
9 #include "ryg_rans/rans_byte.h"
10 #include "ryg_rans/renormalize.h"
11
12 #include <algorithm>
13 #include <memory>
14 #include <numeric>
15 #include <random>
16 #include <vector>
17 #include <unordered_map>
18
// Frame geometry: the input is a single WIDTH x HEIGHT RGB image.
#define WIDTH 1280
#define HEIGHT 720
// Alphabet size of each rANS model. A coefficient magnitude k is coded as
// symbol (k - 1) & 255, so k = 0 ends up as symbol 255.
#define NUM_SYMS 256
// Magnitudes >= ESCAPE_LIMIT are coded as an escape symbol plus the raw
// magnitude (see RansEncoder::encode_coeff()).
#define ESCAPE_LIMIT (NUM_SYMS - 1)

// If you set this to 1, the program will try to optimize the placement
// of coefficients to rANS probability distributions. This is randomized,
// so you might want to run it a few times.
#define FIND_OPTIMAL_STREAM_ASSIGNMENT 0
// Number of k-means clusters that find_optimal_stream_assignment() groups
// the 64 coefficient streams of one plane type into.
#define NUM_CLUSTERS 8

// Probabilities are quantized so each model sums to prob_scale = 2^prob_bits.
// RansEncoder::init_prob() sets its symbols up with prob_bits + 1 bits; the
// extra bit makes room for the sign (encode_coeff() adds sign_bias for
// negative coefficients).
static constexpr uint32_t prob_bits = 12;
static constexpr uint32_t prob_scale = 1 << prob_bits;

using namespace std;

// 8x8 integer forward/inverse DCT, defined elsewhere. Callers pass an array
// of 64 shorts indexed as y * 8 + x and read the result back from the same
// array, i.e. the transform is done in place.
void fdct_int32(short *const In);
void idct_int32(short *const In);
37
// Input frame as packed 8-bit R, G, B triplets (filled by readpix()).
unsigned char rgb[WIDTH * HEIGHT * 3];
// Y'CbCr planes produced by convert_ycbcr(); pix_cb/pix_cr are 4:2:2
// (half width). main() later overwrites pix_y/pix_cb/pix_cr in place with
// the reconstructed (quantized + inverse-transformed) pixels.
unsigned char pix_y[WIDTH * HEIGHT];
unsigned char pix_cb[(WIDTH/2) * HEIGHT];
unsigned char pix_cr[(WIDTH/2) * HEIGHT];
// Full-resolution (not subsampled) chroma, also from convert_ycbcr().
unsigned char full_pix_cb[WIDTH * HEIGHT];
unsigned char full_pix_cr[WIDTH * HEIGHT];
// Quantized DCT coefficients, each stored at the position of the pixel it
// replaces inside its 8x8 block. main() later turns the luma DC entries into
// differences against the block to their right.
short coeff_y[WIDTH * HEIGHT], coeff_cb[(WIDTH/2) * HEIGHT], coeff_cr[(WIDTH/2) * HEIGHT];
45
// Saturates x to the 0..255 range of an 8-bit sample.
int clamp(int x)
{
        return x < 0 ? 0 : (x > 255 ? 255 : x);
}
52
// Base AC quantization weights, indexed by y * 8 + x within an 8x8 block.
// Entry 0 (the DC) is never read: quantize()/unquantize() handle DC
// separately. Despite the name, main() quantizes chroma with this table too.
static const unsigned char std_luminance_quant_tbl[64] = {
#if 0
        // Standard JPEG luminance table, kept for reference.
        16,  11,  10,  16,  24,  40,  51,  61,
        12,  12,  14,  19,  26,  58,  60,  55,
        14,  13,  16,  24,  40,  57,  69,  56,
        14,  17,  22,  29,  51,  87,  80,  62,
        18,  22,  37,  56,  68, 109, 103,  77,
        24,  35,  55,  64,  81, 104, 113,  92,
        49,  64,  78,  87, 103, 121, 120, 101,
        72,  92,  95,  98, 112, 100, 103,  99
#else
        // ff_mpeg1_default_intra_matrix
         8, 16, 19, 22, 26, 27, 29, 34,
        16, 16, 22, 24, 27, 29, 34, 37,
        19, 22, 26, 27, 29, 34, 34, 38,
        22, 22, 26, 27, 29, 34, 37, 40,
        22, 26, 27, 29, 32, 35, 40, 48,
        26, 27, 29, 32, 35, 40, 48, 58,
        26, 27, 29, 34, 38, 46, 56, 69,
        27, 29, 35, 38, 46, 56, 69, 83
#endif
};
75
// Symbol statistics for one rANS probability model.
struct SymbolStats
{
    uint32_t freqs[NUM_SYMS];          // per-symbol counts; normalized frequencies after normalize_freqs()
    uint32_t cum_freqs[NUM_SYMS + 1];  // prefix sums of freqs; cum_freqs[NUM_SYMS] is the total

    void clear();                                 // zero freqs[] (cum_freqs[] is left untouched)
    void calc_cum_freqs();                        // rebuild cum_freqs[] from freqs[]
    void normalize_freqs(uint32_t target_total);  // rescale freqs[] to sum to target_total
};
85
86 void SymbolStats::clear()
87 {
88     for (int i=0; i < NUM_SYMS; i++)
89         freqs[i] = 0;
90 }
91
92 void SymbolStats::calc_cum_freqs()
93 {
94     cum_freqs[0] = 0;
95     for (int i=0; i < NUM_SYMS; i++)
96         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
97 }
98
99 void SymbolStats::normalize_freqs(uint32_t target_total)
100 {
101     uint64_t real_freq[NUM_SYMS + 1];  // hack
102
103     assert(target_total >= NUM_SYMS);
104
105     calc_cum_freqs();
106     uint32_t cur_total = cum_freqs[NUM_SYMS];
107
108     if (cur_total == 0) return;
109
110     double ideal_cost = 0.0;
111     for (int i = 1; i <= NUM_SYMS; i++)
112     {
113       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
114       if (real_freq[i] > 0)
115         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
116     }
117
118     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
119
120     // calculate updated freqs and make sure we didn't screw anything up
121     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
122     for (int i=0; i < NUM_SYMS; i++) {
123         if (freqs[i] == 0)
124             assert(cum_freqs[i+1] == cum_freqs[i]);
125         else
126             assert(cum_freqs[i+1] > cum_freqs[i]);
127
128         // calc updated freq
129         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
130     }
131
132     double calc_cost = 0.0;
133     for (int i = 1; i <= NUM_SYMS; i++)
134     {
135       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
136       if (real_freq[i] > 0)
137         calc_cost -= real_freq[i] * log2(freq / double(target_total));
138     }
139
140     static double total_loss = 0.0;
141     total_loss += calc_cost - ideal_cost;
142     static double total_loss_with_dp = 0.0;
143         double optimal_cost = 0.0;
144     //total_loss_with_dp += optimal_cost - ideal_cost;
145     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
146                 ideal_cost, optimal_cost,
147                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
148 }
149
// One probability model per coefficient stream, indexed by pick_stats_for().
// 128 entries cover the per-position layout (64 luma + 64 chroma) used when
// FIND_OPTIMAL_STREAM_ASSIGNMENT is set. Otherwise pick_stats_for() only
// returns indices 0..15.
SymbolStats stats[128];

#if FIND_OPTIMAL_STREAM_ASSIGNMENT
// Distance from one stream to the other, based on a hacked-up K-L divergence.
// Read by find_best_assignment(). NOTE(review): the code that fills this in
// is not visible here; confirm before relying on it.
float kl_dist[64][64];
#endif
156
// Maps the coefficient position (x, y) inside an 8x8 block, plus whether the
// block is chroma, to the index in stats[] of the model used to code it.
//
// With FIND_OPTIMAL_STREAM_ASSIGNMENT, every position gets its own model
// (luma 0..63, chroma 64..127). Otherwise positions are bucketed by the
// anti-diagonal x + y, with diagonals 7 and beyond sharing one bucket
// (luma 0..7, chroma 8..15).
int pick_stats_for(int x, int y, bool is_chroma)
{
#if FIND_OPTIMAL_STREAM_ASSIGNMENT
        const int position = y * 8 + x;
        return is_chroma ? position + 64 : position;
#else
        const int diagonal = (x + y < 7) ? x + y : 7;
        return is_chroma ? diagonal + 8 : diagonal;
#endif
}
165                 
166
// Writes x to fp as a little-endian base-128 varint: 7 bits per byte, least
// significant group first, with the high bit set on every byte except the
// last (e.g. 300 -> 0xAC 0x02).
//
// The value is encoded through its 32-bit unsigned representation. Callers
// passing non-negative values get exactly the bytes they always did. A
// negative value now round-trips as its two's-complement bit pattern (five
// bytes) instead of being truncated to a single corrupt byte. This also
// avoids right-shifting a negative signed int.
void write_varint(int x, FILE *fp)
{
        uint32_t v = static_cast<uint32_t>(x);
        while (v >= 128) {
                putc(int((v & 0x7f) | 0x80), fp);
                v >>= 7;
        }
        putc(int(v), fp);
}
175
176 class RansEncoder {
177 public:
178         RansEncoder()
179         {
180                 out_buf.reset(new uint8_t[out_max_size]);
181                 clear();
182         }
183
184         void init_prob(SymbolStats &s)
185         {
186                 for (int i = 0; i < NUM_SYMS; i++) {
187                         printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
188                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
189                 }
190                 sign_bias = s.cum_freqs[256];
191         }
192
193         void clear()
194         {
195                 out_end = out_buf.get() + out_max_size;
196                 ptr = out_end; // *end* of output buffer
197                 RansEncInit(&rans);
198         }
199
200         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
201         {
202                 RansEncFlush(&rans, &ptr);
203                 //printf("post-flush = %08x\n", rans);
204
205                 uint32_t num_rans_bytes = out_end - ptr;
206 #if 0
207                 if (num_rans_bytes == 4) {
208                         uint32_t block;
209                         memcpy(&block, ptr, 4);
210
211                         if (block == last_block) {
212                                 write_varint(0, codedfp);
213                                 clear();
214                                 return 1;
215                         }
216
217                         last_block = block;
218                 } else {
219                         last_block = 0;
220                 }
221 #endif
222
223                 write_varint(num_rans_bytes, codedfp);
224                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
225                 fwrite(ptr, 1, num_rans_bytes, codedfp);
226
227                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
228
229
230                 clear();
231
232                 printf("Saving block: %d rANS bytes\n", num_rans_bytes);
233                 return num_rans_bytes;
234                 //return num_rans_bytes;
235         }
236
237         void encode_coeff(short signed_k)
238         {
239                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
240                 unsigned short k = abs(signed_k);
241                 if (k >= ESCAPE_LIMIT) {
242                         // Put the coefficient as a 1/(2^12) symbol _before_
243                         // the 255 coefficient, since the decoder will read the
244                         // 255 coefficient first.
245                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
246                         k = ESCAPE_LIMIT;
247                 }
248                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & 255]);
249                 if (signed_k < 0) {
250                         rans += sign_bias;
251                 }
252         }
253
254 private:
255         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
256         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
257
258         unique_ptr<uint8_t[]> out_buf;
259         uint8_t *out_end;
260         uint8_t *ptr;
261         RansState rans;
262         RansEncSymbol esyms[NUM_SYMS];
263         uint32_t sign_bias;
264
265         uint32_t last_block = 0;  // Not a valid 4-byte rANS block (?)
266 };
267
// DC coefficients are quantized by plain integer division by this factor.
static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
// Global multiplier on std_luminance_quant_tbl for AC coefficients.
// quantize()/unquantize() convert it to int, so a fractional value would be
// truncated.
static constexpr double quant_scalefac = 4.0;  // whatever?
270
271 static inline int quantize(int f, int coeff_idx)
272 {
273         if (coeff_idx == 0) {
274                 return f / dc_scalefac;
275         }
276         if (f == 0) {
277                 return 0;
278         }
279
280         const int w = std_luminance_quant_tbl[coeff_idx];
281         const int s = quant_scalefac;
282         int sign_f = (f > 0) ? 1 : -1;
283         return (32 * f + sign_f * w * s) / (2 * w * s);
284 }
285
286 static inline int unquantize(int qf, int coeff_idx)
287 {
288         if (coeff_idx == 0) {
289                 return qf * dc_scalefac;
290         }
291         if (qf == 0) {
292                 return 0;
293         }
294
295         const int w = std_luminance_quant_tbl[coeff_idx];
296         const int s = quant_scalefac;
297         return (2 * qf * w * s) / 32;
298 }
299
300 void readpix(unsigned char *ptr, const char *filename)
301 {
302         FILE *fp = fopen(filename, "rb");
303         if (fp == nullptr) {
304                 perror(filename);
305                 exit(1);
306         }
307
308         fseek(fp, 0, SEEK_END);
309         long len = ftell(fp);
310         assert(len >= WIDTH * HEIGHT * 3);
311         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
312
313         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
314         fclose(fp);
315 }
316
317 void convert_ycbcr()
318 {
319         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
320         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
321         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
322
323         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
324         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
325         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
326                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
327                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
328                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
329                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
330                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
331                         double cb = (b - y) * cb_fac + 128.0;
332                         double cr = (r - y) * cr_fac + 128.0;
333                         pix_y[(yb * WIDTH) + xb] = lrint(y);
334                         temp_cb[(yb * WIDTH) + xb] = cb;
335                         temp_cr[(yb * WIDTH) + xb] = cr;
336                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
337                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
338                 }
339         }
340
341         // Simple 4:2:2 subsampling with left convention.
342         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
343                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
344                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
345                         int c1 = yb * WIDTH + xb * 2;
346                         int c2 = yb * WIDTH + xb * 2 + 1;
347                         
348                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
349                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
350                         cb = std::min(std::max(cb, 0.0), 255.0);
351                         cr = std::min(std::max(cr, 0.0), 255.0);
352                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
353                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
354                 }
355         }
356 }
357
358 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
359 double find_best_assignment(const int *medoids, int *assignment)
360 {
361         double current_score = 0.0;
362         for (int i = 0; i < 64; ++i) {
363                 int best_medoid = medoids[0];
364                 float best_medoid_score = kl_dist[i][medoids[0]];
365                 for (int j = 1; j < NUM_CLUSTERS; ++j) {
366                         if (kl_dist[i][medoids[j]] < best_medoid_score) {
367                                 best_medoid = medoids[j];
368                                 best_medoid_score = kl_dist[i][medoids[j]];
369                         }
370                 }
371                 assignment[i] = best_medoid;
372                 current_score += best_medoid_score;
373         }
374         return current_score;
375 }
376
377 double find_inv_sum(const SymbolStats &stats)
378 {
379         double s = 0.0;
380         for (unsigned j = 0; j < NUM_SYMS; ++j) {
381                 s += stats.freqs[j] + 0.5;
382         }
383         return 1.0 / s;
384 }
385
386 void find_optimal_stream_assignment(int base)
387 {
388         // k-means init; make random assignments
389         std::random_device rd;
390         std::mt19937 g(rd());
391         std::uniform_int_distribution<> u(0, NUM_CLUSTERS - 1);
392         int assignment[64];
393         for (unsigned i = 0; i < 64; ++i) {
394                 assignment[i] = u(g);
395         }       
396         double inv_sum_coeffs[64];
397         for (unsigned i = 0; i < 64; ++i) {
398                 inv_sum_coeffs[i] = find_inv_sum(stats[i + base]);
399         }
400
401         for (unsigned iter = 0; iter < 1000; ++iter) {
402                 // make new clusters based on the current assignments
403                 SymbolStats clusters[NUM_CLUSTERS];
404                 for (unsigned i = 0; i < NUM_CLUSTERS; ++i) {
405                         clusters[i].clear();
406                 }
407                 for (unsigned i = 0; i < 64; ++i) {
408                         for (unsigned j = 0; j < NUM_SYMS; ++j) {
409                                 clusters[assignment[i]].freqs[j] += stats[i + base].freqs[j];
410                         }
411                 }
412
413                 double inv_sum_clusters[NUM_CLUSTERS];
414                 for (unsigned i = 0; i < NUM_CLUSTERS; ++i) {
415                         inv_sum_clusters[i] = find_inv_sum(clusters[i]);
416                 }
417                 
418                 // find new assignments based on distance to the clusters
419                 bool any_changed = false;
420                 double total_d = 0.0;
421                 for (unsigned i = 0; i < 64; ++i) {
422                         int best_assignment = -1;
423                         double best_distance = HUGE_VAL;
424                         for (unsigned j = 0; j < NUM_CLUSTERS; ++j) {
425                                 double d = 0.0;
426                                 for (unsigned k = 0; k < NUM_SYMS; ++k) {
427                                         double p1 = (clusters[j].freqs[k] + 0.5) * inv_sum_clusters[j];
428                                         double p2 = (stats[i + base].freqs[k] + 0.5) * inv_sum_coeffs[i];
429
430                                         // K-L divergence is asymmetric; this is a hack.
431                                         d += p1 * log(p1 / p2);
432                                         d += p2 * log(p2 / p1);
433                                 }
434                                 if (d < best_distance) {
435                                         best_assignment = j;
436                                         best_distance = d;
437                                 }
438                         }
439                         if (assignment[i] != best_assignment) {
440                                 any_changed = true;
441                         }
442                         assignment[i] = best_assignment;
443                         total_d += best_distance;
444                 }
445                 printf("iter %u: %.3f\n", iter, total_d);
446                 if (!any_changed) break;
447         }
448         printf("\n");
449         std::unordered_map<int, int> rmap;
450         for (int i = 0; i < 64; ++i) {
451                 if (i % 8 == 0) printf("\n");
452                 if (!rmap.count(assignment[i])) {
453                         rmap.emplace(assignment[i], rmap.size());
454                 }
455                 printf("%d, ", rmap[assignment[i]]);
456         }
457         printf("\n");
458 }
459 #endif
460
461 int main(int argc, char **argv)
462 {
463         if (argc >= 2)
464                 readpix(rgb, argv[1]);
465         else
466                 readpix(rgb, "color.pnm");
467         convert_ycbcr();
468
469         double sum_sq_err = 0.0;
470         //double last_cb_cfl_fac = 0.0;
471         //double last_cr_cfl_fac = 0.0;
472
473         // DCT and quantize luma
474         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
475                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
476                         // Read one block
477                         short in_y[64];
478                         for (unsigned y = 0; y < 8; ++y) {
479                                 for (unsigned x = 0; x < 8; ++x) {
480                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
481                                 }
482                         }
483
484                         // FDCT it
485                         fdct_int32(in_y);
486
487                         for (unsigned y = 0; y < 8; ++y) {
488                                 for (unsigned x = 0; x < 8; ++x) {
489                                         int coeff_idx = y * 8 + x;
490                                         int k = quantize(in_y[coeff_idx], coeff_idx);
491                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
492
493                                         // Store back for reconstruction / PSNR calculation
494                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
495                                 }
496                         }
497
498                         idct_int32(in_y);
499
500                         for (unsigned y = 0; y < 8; ++y) {
501                                 for (unsigned x = 0; x < 8; ++x) {
502                                         int k = clamp(in_y[y * 8 + x]);
503                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
504                                         sum_sq_err += (*ptr - k) * (*ptr - k);
505                                         *ptr = k;
506                                 }
507                         }
508                 }
509         }
510         double mse = sum_sq_err / double(WIDTH * HEIGHT);
511         double psnr_db = 20 * log10(255.0 / sqrt(mse));
512         printf("psnr = %.2f dB\n", psnr_db);
513
514         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
515
516         // DCT and quantize chroma
517         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
518         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
519                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
520 #if 0
521                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
522                         printf("in blocks:\n");
523                         for (unsigned y = 0; y < 8; ++y) {
524                                 for (unsigned x = 0; x < 8; ++x) {
525                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
526                                         printf(" %4d", a);
527                                 }
528                                 printf(" | ");
529                                 for (unsigned x = 0; x < 8; ++x) {
530                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
531                                         printf(" %4d", b);
532                                 }
533                                 printf("\n");
534                         }
535
536                         short in_y[64];
537                         for (unsigned y = 0; y < 8; ++y) {
538                                 for (unsigned x = 0; x < 4; ++x) {
539                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
540                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
541                                         b = a - b;
542                                         a = 2 * a - b;
543                                         in_y[y * 8 + x * 2 + 0] = a;
544                                         in_y[y * 8 + x * 2 + 1] = b;
545                                 }
546                         }
547
548                         printf("tf-ed block:\n");
549                         for (unsigned y = 0; y < 8; ++y) {
550                                 for (unsigned x = 0; x < 8; ++x) {
551                                         short a = in_y[y * 8 + x];
552                                         printf(" %4d", a);
553                                 }
554                                 printf("\n");
555                         }
556 #else
557                         // Read Y block with no tf switch (from reconstructed luma)
558                         short in_y[64];
559                         for (unsigned y = 0; y < 8; ++y) {
560                                 for (unsigned x = 0; x < 8; ++x) {
561                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
562                                 }
563                         }
564                         fdct_int32(in_y);
565 #endif
566
567                         // Read one block
568                         short in_cb[64], in_cr[64];
569                         for (unsigned y = 0; y < 8; ++y) {
570                                 for (unsigned x = 0; x < 8; ++x) {
571                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
572                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
573                                 }
574                         }
575
576                         // FDCT it
577                         fdct_int32(in_cb);
578                         fdct_int32(in_cr);
579
580 #if 0
581                         // Chroma from luma
582                         double x0 = in_y[1];
583                         double x1 = in_y[8];
584                         double x2 = in_y[9];
585                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
586                         //double denom = (x1 * x1);
587         
588                         double y0 = in_cb[1];
589                         double y1 = in_cb[8];
590                         double y2 = in_cb[9];
591                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
592                         //double cb_cfl_fac = (x1 * y1) / denom;
593
594                         for (unsigned y = 0; y < 8; ++y) {
595                                 for (unsigned x = 0; x < 8; ++x) {
596                                         short a = in_y[y * 8 + x];
597                                         printf(" %4d", a);
598                                 }
599                                 printf(" | ");
600                                 for (unsigned x = 0; x < 8; ++x) {
601                                         short a = in_cb[y * 8 + x];
602                                         printf(" %4d", a);
603                                 }
604                                 printf("\n");
605                         }
606                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
607                                 in_y[1], in_y[8], in_y[9], 
608                                 in_cb[1], in_cb[8], in_cb[9],
609                                 cb_cfl_fac);
610
611                         y0 = in_cr[1];
612                         y1 = in_cr[8];
613                         y2 = in_cr[9];
614                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
615                         //double cr_cfl_fac = (x1 * y1) / denom;
616                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
617                                 cb_cfl_fac, in_cb[0] - in_y[0],
618                                 cr_cfl_fac, in_cr[0] - in_y[0]);
619
620                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
621
622                         // CHEAT
623                         //last_cb_cfl_fac = cb_cfl_fac;
624                         //last_cr_cfl_fac = cr_cfl_fac;
625
626                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
627                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
628                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
629                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
630                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
631                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
632
633                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
634                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
635                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
636                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
637                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
638                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
639                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
640                         }
641                         //in_cb[0] += 1024;
642                         //in_cr[0] += 1024;
643                         //in_cb[0] -= in_y[0];
644                         //in_cr[0] -= in_y[0];
645 #endif
646
647                         for (unsigned y = 0; y < 8; ++y) {
648                                 for (unsigned x = 0; x < 8; ++x) {
649                                         int coeff_idx = y * 8 + x;
650                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
651                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
652                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
653                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
654
655                                         // Store back for reconstruction / PSNR calculation
656                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
657                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
658                                 }
659                         }
660
661                         idct_int32(in_y);  // DEBUG
662                         idct_int32(in_cb);
663                         idct_int32(in_cr);
664
665                         for (unsigned y = 0; y < 8; ++y) {
666                                 for (unsigned x = 0; x < 8; ++x) {
667                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
668                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
669
670                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
671                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
672                                 }
673                         }
674
675 #if 0
676                         last_cb_cfl_fac = cb_cfl_fac;
677                         last_cr_cfl_fac = cr_cfl_fac;
678 #endif
679                 }
680         }
681
682 #if 0
683         printf("chroma_energy = %f, with_pred = %f\n",
684                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
685 #endif
686
687         // DC coefficient pred from the right to left
688         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
689                 for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
690                         coeff_y[yb * WIDTH + xb] -= coeff_y[yb * WIDTH + (xb + 8)];
691                 }
692         }
693
694         FILE *fp = fopen("reconstructed.pgm", "wb");
695         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
696         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
697         fclose(fp);
698
699         fp = fopen("reconstructed.pnm", "wb");
700         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
701         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
702                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
703                         int y = pix_y[(yb * WIDTH) + xb];
704                         int cb, cr;
705                         int c0 = yb * (WIDTH/2) + xb/2;
706                         if (xb % 2 == 0) {
707                                 cb = pix_cb[c0] - 128.0;
708                                 cr = pix_cr[c0] - 128.0;
709                         } else {
710                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
711                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
712                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
713                         }
714
715                         double r = y + 1.5748 * cr;
716                         double g = y - 0.1873 * cb - 0.4681 * cr;
717                         double b = y + 1.8556 * cb;
718
719                         putc(clamp(lrint(r)), fp);
720                         putc(clamp(lrint(g)), fp);
721                         putc(clamp(lrint(b)), fp);
722                 }
723         }
724         fclose(fp);
725
726         // For each coefficient, make some tables.
727         size_t extra_bits = 0;
728         for (unsigned i = 0; i < 64; ++i) {
729                 stats[i].clear();
730         }
731         for (unsigned y = 0; y < 8; ++y) {
732                 for (unsigned x = 0; x < 8; ++x) {
733                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
734                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
735
736                         // Luma
737                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
738                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
739                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
740                                         if (k >= ESCAPE_LIMIT) {
741                                                 k = ESCAPE_LIMIT;
742                                                 extra_bits += 12;  // escape this one
743                                         }
744                                         ++s_luma.freqs[(k - 1) & 255];
745                                 }
746                         }
747                         // Chroma
748                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
749                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
750                                         unsigned short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
751                                         unsigned short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
752                                         if (k_cb >= ESCAPE_LIMIT) {
753                                                 k_cb = ESCAPE_LIMIT;
754                                                 extra_bits += 12;  // escape this one
755                                         }
756                                         if (k_cr >= ESCAPE_LIMIT) {
757                                                 k_cr = ESCAPE_LIMIT;
758                                                 extra_bits += 12;  // escape this one
759                                         }
760                                         ++s_chroma.freqs[(k_cb - 1) & 255];
761                                         ++s_chroma.freqs[(k_cr - 1) & 255];
762                                 }
763                         }
764                 }
765         }
766
767 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
768         printf("Luma:\n");
769         find_optimal_stream_assignment(0);
770         printf("Chroma:\n");
771         find_optimal_stream_assignment(64);
772         exit(0);
773 #endif
774
775         for (unsigned i = 0; i < 64; ++i) {
776                 stats[i].freqs[255] /= 2;  // zero, has no sign bits (yes, this is trickery)
777                 stats[i].normalize_freqs(prob_scale);
778                 stats[i].cum_freqs[256] += stats[i].freqs[255];
779                 stats[i].freqs[255] *= 2;
780         }
781
782         FILE *codedfp = fopen("coded.dat", "wb");
783         if (codedfp == nullptr) {
784                 perror("coded.dat");
785                 exit(1);
786         }
787
788         // TODO: rather gamma-k or something
789         for (unsigned i = 0; i < 64; ++i) {
790                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
791                         continue;
792                 }
793                 printf("writing table %d\n", i);
794                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
795                         write_varint(stats[i].freqs[j], codedfp);
796                 }
797         }
798
799         RansEncoder rans_encoder;
800
801         size_t tot_bytes = 0;
802
803         // Luma
804         for (unsigned y = 0; y < 8; ++y) {
805                 for (unsigned x = 0; x < 8; ++x) {
806                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
807                         rans_encoder.init_prob(s_luma);
808
809                         // Luma
810                         std::vector<int> lens;
811
812                         // need to reverse later
813                         rans_encoder.clear();
814                         size_t num_bytes = 0;
815                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
816                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
817                                         int k = coeff_y[(yb + y) * WIDTH + (xb + x)];
818                                         //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
819                                         rans_encoder.encode_coeff(k);
820                                 }
821                                 if (yb % 16 == 8) {
822                                         int l = rans_encoder.save_block(codedfp);
823                                         num_bytes += l;
824                                         lens.push_back(l);
825                                 }
826                         }
827                         if (HEIGHT % 16 != 0) {
828                                 num_bytes += rans_encoder.save_block(codedfp);
829                         }
830                         tot_bytes += num_bytes;
831                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
832
833                         double sum_l = 0.0;
834                         for (int l : lens) {
835                                 sum_l += l;
836                         }
837                         double avg_l = sum_l / lens.size();
838
839                         double sum_sql = 0.0;
840                         for (int l : lens) {
841                                 sum_sql += (l - avg_l) * (l - avg_l);
842                         }
843                         double stddev_l = sqrt(sum_sql / (lens.size() - 1));
844                         printf("coeff %d: avg=%.2f bytes, stddev=%.2f bytes\n", y*8+x, avg_l, stddev_l);
845                 }
846         }
847
848         // Cb
849         for (unsigned y = 0; y < 8; ++y) {
850                 for (unsigned x = 0; x < 8; ++x) {
851                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
852                         rans_encoder.init_prob(s_chroma);
853
854                         rans_encoder.clear();
855                         size_t num_bytes = 0;
856                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
857                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
858                                         int k = coeff_cb[(yb + y) * WIDTH/2 + (xb + x)];
859                                         rans_encoder.encode_coeff(k);
860                                 }
861                                 if (yb % 16 == 8) {
862                                         num_bytes += rans_encoder.save_block(codedfp);
863                                 }
864                         }
865                         if (HEIGHT % 16 != 0) {
866                                 num_bytes += rans_encoder.save_block(codedfp);
867                         }
868                         tot_bytes += num_bytes;
869                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
870                 }
871         }
872
873         // Cr
874         for (unsigned y = 0; y < 8; ++y) {
875                 for (unsigned x = 0; x < 8; ++x) {
876                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
877                         rans_encoder.init_prob(s_chroma);
878
879                         rans_encoder.clear();
880                         size_t num_bytes = 0;
881                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
882                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
883                                         int k = coeff_cr[(yb + y) * WIDTH/2 + (xb + x)];
884                                         rans_encoder.encode_coeff(k);
885                                 }
886                                 if (yb % 16 == 8) {
887                                         num_bytes += rans_encoder.save_block(codedfp);
888                                 }
889                         }
890                         if (HEIGHT % 16 != 0) {
891                                 num_bytes += rans_encoder.save_block(codedfp);
892                         }
893                         tot_bytes += num_bytes;
894                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
895                 }
896         }
897
898         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
899                 tot_bytes - extra_bits / 8,
900                 extra_bits,
901                 extra_bits / 8,
902                 tot_bytes);
903 }