]> git.sesse.net Git - narabu/blob - qdc.cpp
Revert "k-means instead of k-medoids; doesn't work as well, so just keep it here...
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 //#include "ryg_rans/rans64.h"
9 #include "ryg_rans/rans_byte.h"
10 #include "ryg_rans/renormalize.h"
11
12 #include <algorithm>
13 #include <memory>
14 #include <numeric>
15 #include <random>
16 #include <vector>
17 #include <unordered_map>
18
19 #define WIDTH 1280
20 #define HEIGHT 720
21 #define NUM_SYMS 256
22 #define ESCAPE_LIMIT (NUM_SYMS - 1)
23
24 // If you set this to 1, the program will try to optimize the placement
25 // of coefficients to rANS probability distributions. This is randomized,
26 // so you might want to run it a few times.
27 #define FIND_OPTIMAL_STREAM_ASSIGNMENT 0
28 #define NUM_CLUSTERS 8
29
30 static constexpr uint32_t prob_bits = 12;
31 static constexpr uint32_t prob_scale = 1 << prob_bits;
32
33 using namespace std;
34
35 void fdct_int32(short *const In);
36 void idct_int32(short *const In);
37
38 unsigned char rgb[WIDTH * HEIGHT * 3];
39 unsigned char pix_y[WIDTH * HEIGHT];
40 unsigned char pix_cb[(WIDTH/2) * HEIGHT];
41 unsigned char pix_cr[(WIDTH/2) * HEIGHT];
42 unsigned char full_pix_cb[WIDTH * HEIGHT];
43 unsigned char full_pix_cr[WIDTH * HEIGHT];
44 short coeff_y[WIDTH * HEIGHT], coeff_cb[(WIDTH/2) * HEIGHT], coeff_cr[(WIDTH/2) * HEIGHT];
45
// Saturates x to the 8-bit sample range: values below 0 become 0, values
// above 255 become 255, everything else is returned unchanged.
int clamp(int x)
{
        return (x < 0) ? 0 : ((x > 255) ? 255 : x);
}
52
// Quantization weights for the 64 coefficients of an 8x8 block, indexed as
// y * 8 + x. quantize()/unquantize() look up their weight here for every
// coefficient except index 0, which they scale by dc_scalefac instead.
//
// Two alternative tables are kept; the preprocessor selects the second one.
static const unsigned char std_luminance_quant_tbl[64] = {
#if 0
        // Disabled alternative (looks like the standard JPEG luminance table).
        16,  11,  10,  16,  24,  40,  51,  61,
        12,  12,  14,  19,  26,  58,  60,  55,
        14,  13,  16,  24,  40,  57,  69,  56,
        14,  17,  22,  29,  51,  87,  80,  62,
        18,  22,  37,  56,  68, 109, 103,  77,
        24,  35,  55,  64,  81, 104, 113,  92,
        49,  64,  78,  87, 103, 121, 120, 101,
        72,  92,  95,  98, 112, 100, 103,  99
#else
        // Active table: ff_mpeg1_default_intra_matrix
         8, 16, 19, 22, 26, 27, 29, 34,
        16, 16, 22, 24, 27, 29, 34, 37,
        19, 22, 26, 27, 29, 34, 34, 38,
        22, 22, 26, 27, 29, 34, 37, 40,
        22, 26, 27, 29, 32, 35, 40, 48,
        26, 27, 29, 32, 35, 40, 48, 58,
        26, 27, 29, 34, 38, 46, 56, 69,
        27, 29, 35, 38, 46, 56, 69, 83
#endif
};
75
76 struct SymbolStats
77 {
78     uint32_t freqs[NUM_SYMS];
79     uint32_t cum_freqs[NUM_SYMS + 1];
80
81     void clear();
82     void calc_cum_freqs();
83     void normalize_freqs(uint32_t target_total);
84 };
85
86 void SymbolStats::clear()
87 {
88     for (int i=0; i < NUM_SYMS; i++)
89         freqs[i] = 0;
90 }
91
92 void SymbolStats::calc_cum_freqs()
93 {
94     cum_freqs[0] = 0;
95     for (int i=0; i < NUM_SYMS; i++)
96         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
97 }
98
99 void SymbolStats::normalize_freqs(uint32_t target_total)
100 {
101     uint64_t real_freq[NUM_SYMS + 1];  // hack
102
103     assert(target_total >= NUM_SYMS);
104
105     calc_cum_freqs();
106     uint32_t cur_total = cum_freqs[NUM_SYMS];
107
108     if (cur_total == 0) return;
109
110     double ideal_cost = 0.0;
111     for (int i = 1; i <= NUM_SYMS; i++)
112     {
113       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
114       if (real_freq[i] > 0)
115         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
116     }
117
118     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
119
120     // calculate updated freqs and make sure we didn't screw anything up
121     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
122     for (int i=0; i < NUM_SYMS; i++) {
123         if (freqs[i] == 0)
124             assert(cum_freqs[i+1] == cum_freqs[i]);
125         else
126             assert(cum_freqs[i+1] > cum_freqs[i]);
127
128         // calc updated freq
129         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
130     }
131
132     double calc_cost = 0.0;
133     for (int i = 1; i <= NUM_SYMS; i++)
134     {
135       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
136       if (real_freq[i] > 0)
137         calc_cost -= real_freq[i] * log2(freq / double(target_total));
138     }
139
140     static double total_loss = 0.0;
141     total_loss += calc_cost - ideal_cost;
142     static double total_loss_with_dp = 0.0;
143         double optimal_cost = 0.0;
144     //total_loss_with_dp += optimal_cost - ideal_cost;
145     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
146                 ideal_cost, optimal_cost,
147                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
148 }
149
150 SymbolStats stats[128];
151
152 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
153 // Distance from one stream to the other, based on a hacked-up K-L divergence.
154 float kl_dist[64][64];
155 #endif
156
// Returns the index into stats[] of the probability table used for the
// coefficient at position (x, y) within an 8x8 block.
//
// When searching for an optimal assignment, every coefficient position gets
// its own table (64 for luma, followed by 64 for chroma). Otherwise positions
// are grouped by x + y, capped at 7, giving 8 luma tables followed by 8
// chroma tables.
int pick_stats_for(int x, int y, bool is_chroma)
{
#if FIND_OPTIMAL_STREAM_ASSIGNMENT
        const int chroma_base = is_chroma ? 64 : 0;
        return chroma_base + y * 8 + x;
#else
        const int diagonal = x + y;
        const int chroma_base = is_chroma ? 8 : 0;
        return chroma_base + (diagonal < 7 ? diagonal : 7);
#endif
}
165                 
166
// Writes x to fp as a little-endian base-128 varint: seven payload bits per
// byte, least significant group first, with the top bit set on every byte
// except the last.
//
// x is encoded as its unsigned bit pattern. That way a value that arrives
// negative (e.g. a uint32_t above INT_MAX converted to int by the caller) is
// still written in full instead of being truncated to a single byte.
void write_varint(int x, FILE *fp)
{
        unsigned int remaining = static_cast<unsigned int>(x);
        while (remaining >= 128) {
                putc((remaining & 0x7f) | 0x80, fp);
                remaining >>= 7;
        }
        putc(remaining, fp);
}
175
176 class RansEncoder {
177 public:
178         RansEncoder()
179         {
180                 out_buf.reset(new uint8_t[out_max_size]);
181                 clear();
182         }
183
184         void init_prob(SymbolStats &s)
185         {
186                 for (int i = 0; i < NUM_SYMS; i++) {
187                         printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
188                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
189                 }
190                 sign_bias = s.cum_freqs[256];
191         }
192
193         void clear()
194         {
195                 out_end = out_buf.get() + out_max_size;
196                 ptr = out_end; // *end* of output buffer
197                 RansEncInit(&rans);
198         }
199
200         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
201         {
202                 RansEncFlush(&rans, &ptr);
203                 //printf("post-flush = %08x\n", rans);
204
205                 uint32_t num_rans_bytes = out_end - ptr;
206 #if 0
207                 if (num_rans_bytes == 4) {
208                         uint32_t block;
209                         memcpy(&block, ptr, 4);
210
211                         if (block == last_block) {
212                                 write_varint(0, codedfp);
213                                 clear();
214                                 return 1;
215                         }
216
217                         last_block = block;
218                 } else {
219                         last_block = 0;
220                 }
221 #endif
222
223                 write_varint(num_rans_bytes, codedfp);
224                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
225                 fwrite(ptr, 1, num_rans_bytes, codedfp);
226
227                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
228
229
230                 clear();
231
232                 printf("Saving block: %d rANS bytes\n", num_rans_bytes);
233                 return num_rans_bytes;
234                 //return num_rans_bytes;
235         }
236
237         void encode_coeff(short signed_k)
238         {
239                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
240                 unsigned short k = abs(signed_k);
241                 if (k >= ESCAPE_LIMIT) {
242                         // Put the coefficient as a 1/(2^12) symbol _before_
243                         // the 255 coefficient, since the decoder will read the
244                         // 255 coefficient first.
245                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
246                         k = ESCAPE_LIMIT;
247                 }
248                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & 255]);
249                 if (signed_k < 0) {
250                         rans += sign_bias;
251                 }
252         }
253
254 private:
255         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
256         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
257
258         unique_ptr<uint8_t[]> out_buf;
259         uint8_t *out_end;
260         uint8_t *ptr;
261         RansState rans;
262         RansEncSymbol esyms[NUM_SYMS];
263         uint32_t sign_bias;
264
265         uint32_t last_block = 0;  // Not a valid 4-byte rANS block (?)
266 };
267
268 static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
269 static constexpr double quant_scalefac = 4.0;  // whatever?
270
271 static inline int quantize(int f, int coeff_idx)
272 {
273         if (coeff_idx == 0) {
274                 return f / dc_scalefac;
275         }
276         if (f == 0) {
277                 return 0;
278         }
279
280         const int w = std_luminance_quant_tbl[coeff_idx];
281         const int s = quant_scalefac;
282         int sign_f = (f > 0) ? 1 : -1;
283         return (32 * f + sign_f * w * s) / (2 * w * s);
284 }
285
286 static inline int unquantize(int qf, int coeff_idx)
287 {
288         if (coeff_idx == 0) {
289                 return qf * dc_scalefac;
290         }
291         if (qf == 0) {
292                 return 0;
293         }
294
295         const int w = std_luminance_quant_tbl[coeff_idx];
296         const int s = quant_scalefac;
297         return (2 * qf * w * s) / 32;
298 }
299
300 void readpix(unsigned char *ptr, const char *filename)
301 {
302         FILE *fp = fopen(filename, "rb");
303         if (fp == nullptr) {
304                 perror(filename);
305                 exit(1);
306         }
307
308         fseek(fp, 0, SEEK_END);
309         long len = ftell(fp);
310         assert(len >= WIDTH * HEIGHT * 3);
311         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
312
313         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
314         fclose(fp);
315 }
316
317 void convert_ycbcr()
318 {
319         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
320         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
321         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
322
323         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
324         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
325         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
326                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
327                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
328                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
329                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
330                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
331                         double cb = (b - y) * cb_fac + 128.0;
332                         double cr = (r - y) * cr_fac + 128.0;
333                         pix_y[(yb * WIDTH) + xb] = lrint(y);
334                         temp_cb[(yb * WIDTH) + xb] = cb;
335                         temp_cr[(yb * WIDTH) + xb] = cr;
336                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
337                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
338                 }
339         }
340
341         // Simple 4:2:2 subsampling with left convention.
342         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
343                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
344                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
345                         int c1 = yb * WIDTH + xb * 2;
346                         int c2 = yb * WIDTH + xb * 2 + 1;
347                         
348                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
349                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
350                         cb = std::min(std::max(cb, 0.0), 255.0);
351                         cr = std::min(std::max(cr, 0.0), 255.0);
352                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
353                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
354                 }
355         }
356 }
357
358 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
359 double find_best_assignment(const int *medoids, int *assignment)
360 {
361         double current_score = 0.0;
362         for (int i = 0; i < 64; ++i) {
363                 int best_medoid = medoids[0];
364                 float best_medoid_score = kl_dist[i][medoids[0]];
365                 for (int j = 1; j < NUM_CLUSTERS; ++j) {
366                         if (kl_dist[i][medoids[j]] < best_medoid_score) {
367                                 best_medoid = medoids[j];
368                                 best_medoid_score = kl_dist[i][medoids[j]];
369                         }
370                 }
371                 assignment[i] = best_medoid;
372                 current_score += best_medoid_score;
373         }
374         return current_score;
375 }
376
377 void find_optimal_stream_assignment(int base)
378 {
379         double inv_sum[64];
380         for (unsigned i = 0; i < 64; ++i) {
381                 double s = 0.0;
382                 for (unsigned k = 0; k < 256; ++k) {
383                         s += stats[i + base].freqs[k] + 0.5;
384                 }
385                 inv_sum[i] = 1.0 / s;
386         }
387
388         for (unsigned i = 0; i < 64; ++i) {
389                 for (unsigned j = 0; j < 64; ++j) {
390                         double d = 0.0;
391                         for (unsigned k = 0; k < 256; ++k) {
392                                 double p1 = (stats[i + base].freqs[k] + 0.5) * inv_sum[i];
393                                 double p2 = (stats[j + base].freqs[k] + 0.5) * inv_sum[j];
394
395                                 // K-L divergence is asymmetric; this is a hack.
396                                 d += p1 * log(p1 / p2);
397                                 d += p2 * log(p2 / p1);
398                         }
399                         kl_dist[i][j] = d;
400                         //printf("%.3f ", d);
401                 }
402                 //printf("\n");
403         }
404
405         // k-medoids init
406         int medoids[64];  // only the first NUM_CLUSTERS matter
407         bool is_medoid[64] = { false };
408         std::iota(medoids, medoids + 64, 0);
409         std::random_device rd;
410         std::mt19937 g(rd());
411         std::shuffle(medoids, medoids + 64, g);
412         for (int i = 0; i < NUM_CLUSTERS; ++i) {
413                 printf("%d ", medoids[i]);
414                 is_medoid[medoids[i]] = true;
415         }
416         printf("\n");
417
418         // assign each data point to the closest medoid
419         int assignment[64];
420         double current_score = find_best_assignment(medoids, assignment);
421
422         for (int i = 0; i < 1000; ++i) {
423                 printf("iter %d\n", i);
424                 bool any_changed = false;
425                 for (int m = 0; m < NUM_CLUSTERS; ++m) {
426                         for (int o = 0; o < 64; ++o) {
427                                 if (is_medoid[o]) continue;
428                                 int old_medoid = medoids[m];
429                                 medoids[m] = o;
430
431                                 int new_assignment[64];
432                                 double candidate_score = find_best_assignment(medoids, new_assignment);
433
434                                 if (candidate_score < current_score) {
435                                         current_score = candidate_score;
436                                         memcpy(assignment, new_assignment, sizeof(assignment));
437
438                                         is_medoid[old_medoid] = false;
439                                         is_medoid[medoids[m]] = true;
440                                         printf("%f: ", current_score);
441                                         for (int i = 0; i < 64; ++i) {
442                                                 printf("%d ", assignment[i]);
443                                         }
444                                         printf("\n");
445                                         any_changed = true;
446                                 } else {
447                                         medoids[m] = old_medoid;
448                                 }
449                         }
450                 }
451                 if (!any_changed) break;
452         }
453         printf("\n");
454         std::unordered_map<int, int> rmap;
455         for (int i = 0; i < 64; ++i) {
456                 if (i % 8 == 0) printf("\n");
457                 if (!rmap.count(assignment[i])) {
458                         rmap.emplace(assignment[i], rmap.size());
459                 }
460                 printf("%d, ", rmap[assignment[i]]);
461         }
462         printf("\n");
463 }
464 #endif
465
466 int main(int argc, char **argv)
467 {
468         if (argc >= 2)
469                 readpix(rgb, argv[1]);
470         else
471                 readpix(rgb, "color.pnm");
472         convert_ycbcr();
473
474         double sum_sq_err = 0.0;
475         //double last_cb_cfl_fac = 0.0;
476         //double last_cr_cfl_fac = 0.0;
477
478         // DCT and quantize luma
479         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
480                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
481                         // Read one block
482                         short in_y[64];
483                         for (unsigned y = 0; y < 8; ++y) {
484                                 for (unsigned x = 0; x < 8; ++x) {
485                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
486                                 }
487                         }
488
489                         // FDCT it
490                         fdct_int32(in_y);
491
492                         for (unsigned y = 0; y < 8; ++y) {
493                                 for (unsigned x = 0; x < 8; ++x) {
494                                         int coeff_idx = y * 8 + x;
495                                         int k = quantize(in_y[coeff_idx], coeff_idx);
496                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
497
498                                         // Store back for reconstruction / PSNR calculation
499                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
500                                 }
501                         }
502
503                         idct_int32(in_y);
504
505                         for (unsigned y = 0; y < 8; ++y) {
506                                 for (unsigned x = 0; x < 8; ++x) {
507                                         int k = clamp(in_y[y * 8 + x]);
508                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
509                                         sum_sq_err += (*ptr - k) * (*ptr - k);
510                                         *ptr = k;
511                                 }
512                         }
513                 }
514         }
515         double mse = sum_sq_err / double(WIDTH * HEIGHT);
516         double psnr_db = 20 * log10(255.0 / sqrt(mse));
517         printf("psnr = %.2f dB\n", psnr_db);
518
519         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
520
521         // DCT and quantize chroma
522         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
523         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
524                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
525 #if 0
526                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
527                         printf("in blocks:\n");
528                         for (unsigned y = 0; y < 8; ++y) {
529                                 for (unsigned x = 0; x < 8; ++x) {
530                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
531                                         printf(" %4d", a);
532                                 }
533                                 printf(" | ");
534                                 for (unsigned x = 0; x < 8; ++x) {
535                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
536                                         printf(" %4d", b);
537                                 }
538                                 printf("\n");
539                         }
540
541                         short in_y[64];
542                         for (unsigned y = 0; y < 8; ++y) {
543                                 for (unsigned x = 0; x < 4; ++x) {
544                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
545                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
546                                         b = a - b;
547                                         a = 2 * a - b;
548                                         in_y[y * 8 + x * 2 + 0] = a;
549                                         in_y[y * 8 + x * 2 + 1] = b;
550                                 }
551                         }
552
553                         printf("tf-ed block:\n");
554                         for (unsigned y = 0; y < 8; ++y) {
555                                 for (unsigned x = 0; x < 8; ++x) {
556                                         short a = in_y[y * 8 + x];
557                                         printf(" %4d", a);
558                                 }
559                                 printf("\n");
560                         }
561 #else
562                         // Read Y block with no tf switch (from reconstructed luma)
563                         short in_y[64];
564                         for (unsigned y = 0; y < 8; ++y) {
565                                 for (unsigned x = 0; x < 8; ++x) {
566                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
567                                 }
568                         }
569                         fdct_int32(in_y);
570 #endif
571
572                         // Read one block
573                         short in_cb[64], in_cr[64];
574                         for (unsigned y = 0; y < 8; ++y) {
575                                 for (unsigned x = 0; x < 8; ++x) {
576                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
577                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
578                                 }
579                         }
580
581                         // FDCT it
582                         fdct_int32(in_cb);
583                         fdct_int32(in_cr);
584
585 #if 0
586                         // Chroma from luma
587                         double x0 = in_y[1];
588                         double x1 = in_y[8];
589                         double x2 = in_y[9];
590                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
591                         //double denom = (x1 * x1);
592         
593                         double y0 = in_cb[1];
594                         double y1 = in_cb[8];
595                         double y2 = in_cb[9];
596                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
597                         //double cb_cfl_fac = (x1 * y1) / denom;
598
599                         for (unsigned y = 0; y < 8; ++y) {
600                                 for (unsigned x = 0; x < 8; ++x) {
601                                         short a = in_y[y * 8 + x];
602                                         printf(" %4d", a);
603                                 }
604                                 printf(" | ");
605                                 for (unsigned x = 0; x < 8; ++x) {
606                                         short a = in_cb[y * 8 + x];
607                                         printf(" %4d", a);
608                                 }
609                                 printf("\n");
610                         }
611                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
612                                 in_y[1], in_y[8], in_y[9], 
613                                 in_cb[1], in_cb[8], in_cb[9],
614                                 cb_cfl_fac);
615
616                         y0 = in_cr[1];
617                         y1 = in_cr[8];
618                         y2 = in_cr[9];
619                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
620                         //double cr_cfl_fac = (x1 * y1) / denom;
621                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
622                                 cb_cfl_fac, in_cb[0] - in_y[0],
623                                 cr_cfl_fac, in_cr[0] - in_y[0]);
624
625                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
626
627                         // CHEAT
628                         //last_cb_cfl_fac = cb_cfl_fac;
629                         //last_cr_cfl_fac = cr_cfl_fac;
630
631                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
632                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
633                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
634                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
635                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
636                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
637
638                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
639                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
640                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
641                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
642                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
643                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
644                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
645                         }
646                         //in_cb[0] += 1024;
647                         //in_cr[0] += 1024;
648                         //in_cb[0] -= in_y[0];
649                         //in_cr[0] -= in_y[0];
650 #endif
651
652                         for (unsigned y = 0; y < 8; ++y) {
653                                 for (unsigned x = 0; x < 8; ++x) {
654                                         int coeff_idx = y * 8 + x;
655                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
656                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
657                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
658                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
659
660                                         // Store back for reconstruction / PSNR calculation
661                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
662                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
663                                 }
664                         }
665
666                         idct_int32(in_y);  // DEBUG
667                         idct_int32(in_cb);
668                         idct_int32(in_cr);
669
670                         for (unsigned y = 0; y < 8; ++y) {
671                                 for (unsigned x = 0; x < 8; ++x) {
672                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
673                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
674
675                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
676                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
677                                 }
678                         }
679
680 #if 0
681                         last_cb_cfl_fac = cb_cfl_fac;
682                         last_cr_cfl_fac = cr_cfl_fac;
683 #endif
684                 }
685         }
686
687 #if 0
688         printf("chroma_energy = %f, with_pred = %f\n",
689                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
690 #endif
691
692         // DC coefficient pred from the right to left
693         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
694                 for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
695                         coeff_y[yb * WIDTH + xb] -= coeff_y[yb * WIDTH + (xb + 8)];
696                 }
697         }
698
699         FILE *fp = fopen("reconstructed.pgm", "wb");
700         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
701         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
702         fclose(fp);
703
704         fp = fopen("reconstructed.pnm", "wb");
705         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
706         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
707                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
708                         int y = pix_y[(yb * WIDTH) + xb];
709                         int cb, cr;
710                         int c0 = yb * (WIDTH/2) + xb/2;
711                         if (xb % 2 == 0) {
712                                 cb = pix_cb[c0] - 128.0;
713                                 cr = pix_cr[c0] - 128.0;
714                         } else {
715                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
716                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
717                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
718                         }
719
720                         double r = y + 1.5748 * cr;
721                         double g = y - 0.1873 * cb - 0.4681 * cr;
722                         double b = y + 1.8556 * cb;
723
724                         putc(clamp(lrint(r)), fp);
725                         putc(clamp(lrint(g)), fp);
726                         putc(clamp(lrint(b)), fp);
727                 }
728         }
729         fclose(fp);
730
731         // For each coefficient, make some tables.
732         size_t extra_bits = 0;
733         for (unsigned i = 0; i < 64; ++i) {
734                 stats[i].clear();
735         }
736         for (unsigned y = 0; y < 8; ++y) {
737                 for (unsigned x = 0; x < 8; ++x) {
738                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
739                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
740
741                         // Luma
742                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
743                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
744                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
745                                         if (k >= ESCAPE_LIMIT) {
746                                                 k = ESCAPE_LIMIT;
747                                                 extra_bits += 12;  // escape this one
748                                         }
749                                         ++s_luma.freqs[(k - 1) & 255];
750                                 }
751                         }
752                         // Chroma
753                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
754                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
755                                         unsigned short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
756                                         unsigned short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
757                                         if (k_cb >= ESCAPE_LIMIT) {
758                                                 k_cb = ESCAPE_LIMIT;
759                                                 extra_bits += 12;  // escape this one
760                                         }
761                                         if (k_cr >= ESCAPE_LIMIT) {
762                                                 k_cr = ESCAPE_LIMIT;
763                                                 extra_bits += 12;  // escape this one
764                                         }
765                                         ++s_chroma.freqs[(k_cb - 1) & 255];
766                                         ++s_chroma.freqs[(k_cr - 1) & 255];
767                                 }
768                         }
769                 }
770         }
771
772 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
773         printf("Luma:\n");
774         find_optimal_stream_assignment(0);
775         printf("Chroma:\n");
776         find_optimal_stream_assignment(64);
777         exit(0);
778 #endif
779
780         for (unsigned i = 0; i < 64; ++i) {
781                 stats[i].freqs[255] /= 2;  // zero, has no sign bits (yes, this is trickery)
782                 stats[i].normalize_freqs(prob_scale);
783                 stats[i].cum_freqs[256] += stats[i].freqs[255];
784                 stats[i].freqs[255] *= 2;
785         }
786
787         FILE *codedfp = fopen("coded.dat", "wb");
788         if (codedfp == nullptr) {
789                 perror("coded.dat");
790                 exit(1);
791         }
792
793         // TODO: rather gamma-k or something
794         for (unsigned i = 0; i < 64; ++i) {
795                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
796                         continue;
797                 }
798                 printf("writing table %d\n", i);
799                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
800                         write_varint(stats[i].freqs[j], codedfp);
801                 }
802         }
803
804         RansEncoder rans_encoder;
805
806         size_t tot_bytes = 0;
807
808         // Luma
809         for (unsigned y = 0; y < 8; ++y) {
810                 for (unsigned x = 0; x < 8; ++x) {
811                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
812                         rans_encoder.init_prob(s_luma);
813
814                         // Luma
815                         std::vector<int> lens;
816
817                         // need to reverse later
818                         rans_encoder.clear();
819                         size_t num_bytes = 0;
820                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
821                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
822                                         int k = coeff_y[(yb + y) * WIDTH + (xb + x)];
823                                         //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
824                                         rans_encoder.encode_coeff(k);
825                                 }
826                                 if (yb % 16 == 8) {
827                                         int l = rans_encoder.save_block(codedfp);
828                                         num_bytes += l;
829                                         lens.push_back(l);
830                                 }
831                         }
832                         if (HEIGHT % 16 != 0) {
833                                 num_bytes += rans_encoder.save_block(codedfp);
834                         }
835                         tot_bytes += num_bytes;
836                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
837
838                         double sum_l = 0.0;
839                         for (int l : lens) {
840                                 sum_l += l;
841                         }
842                         double avg_l = sum_l / lens.size();
843
844                         double sum_sql = 0.0;
845                         for (int l : lens) {
846                                 sum_sql += (l - avg_l) * (l - avg_l);
847                         }
848                         double stddev_l = sqrt(sum_sql / (lens.size() - 1));
849                         printf("coeff %d: avg=%.2f bytes, stddev=%.2f bytes\n", y*8+x, avg_l, stddev_l);
850                 }
851         }
852
853         // Cb
854         for (unsigned y = 0; y < 8; ++y) {
855                 for (unsigned x = 0; x < 8; ++x) {
856                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
857                         rans_encoder.init_prob(s_chroma);
858
859                         rans_encoder.clear();
860                         size_t num_bytes = 0;
861                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
862                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
863                                         int k = coeff_cb[(yb + y) * WIDTH/2 + (xb + x)];
864                                         rans_encoder.encode_coeff(k);
865                                 }
866                                 if (yb % 16 == 8) {
867                                         num_bytes += rans_encoder.save_block(codedfp);
868                                 }
869                         }
870                         if (HEIGHT % 16 != 0) {
871                                 num_bytes += rans_encoder.save_block(codedfp);
872                         }
873                         tot_bytes += num_bytes;
874                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
875                 }
876         }
877
878         // Cr
879         for (unsigned y = 0; y < 8; ++y) {
880                 for (unsigned x = 0; x < 8; ++x) {
881                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
882                         rans_encoder.init_prob(s_chroma);
883
884                         rans_encoder.clear();
885                         size_t num_bytes = 0;
886                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
887                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
888                                         int k = coeff_cr[(yb + y) * WIDTH/2 + (xb + x)];
889                                         rans_encoder.encode_coeff(k);
890                                 }
891                                 if (yb % 16 == 8) {
892                                         num_bytes += rans_encoder.save_block(codedfp);
893                                 }
894                         }
895                         if (HEIGHT % 16 != 0) {
896                                 num_bytes += rans_encoder.save_block(codedfp);
897                         }
898                         tot_bytes += num_bytes;
899                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
900                 }
901         }
902
903         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
904                 tot_bytes - extra_bits / 8,
905                 extra_bits,
906                 extra_bits / 8,
907                 tot_bytes);
908 }