]> git.sesse.net Git - narabu/blob - qdc.cpp
Make num_blocks a uniform.
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 //#include "ryg_rans/rans64.h"
9 #include "ryg_rans/rans_byte.h"
10 #include "ryg_rans/renormalize.h"
11
12 #include <algorithm>
13 #include <memory>
14 #include <numeric>
15 #include <random>
16 #include <vector>
17 #include <unordered_map>
18
19 #define WIDTH 1280
20 #define HEIGHT 720
21 #define WIDTH_BLOCKS (WIDTH/8)
22 #define WIDTH_BLOCKS_CHROMA (WIDTH/16)
23 #define HEIGHT_BLOCKS (HEIGHT/8)
24 #define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
25 #define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)
26 #define NUM_SYMS 256
27 #define ESCAPE_LIMIT (NUM_SYMS - 1)
28
29 // If you set this to 1, the program will try to optimize the placement
30 // of coefficients to rANS probability distributions. This is randomized,
31 // so you might want to run it a few times.
32 #define FIND_OPTIMAL_STREAM_ASSIGNMENT 0
33 #define NUM_CLUSTERS 4
34
35 static constexpr uint32_t prob_bits = 12;
36 static constexpr uint32_t prob_scale = 1 << prob_bits;
37
38 using namespace std;
39
40 void fdct_int32(short *const In);
41 void idct_int32(short *const In);
42
43 unsigned char rgb[WIDTH * HEIGHT * 3];
44 unsigned char pix_y[WIDTH * HEIGHT];
45 unsigned char pix_cb[(WIDTH/2) * HEIGHT];
46 unsigned char pix_cr[(WIDTH/2) * HEIGHT];
47 unsigned char full_pix_cb[WIDTH * HEIGHT];
48 unsigned char full_pix_cr[WIDTH * HEIGHT];
49 short coeff_y[WIDTH * HEIGHT], coeff_cb[(WIDTH/2) * HEIGHT], coeff_cr[(WIDTH/2) * HEIGHT];
50
// Saturates x to the 8-bit sample range: values below 0 become 0,
// values above 255 become 255, everything else is returned unchanged.
int clamp(int x)
{
        const int floored = (x < 0) ? 0 : x;
        return (floored > 255) ? 255 : floored;
}
57
// Quantization weights for the 64 coefficients of an 8x8 block, stored in
// row-major order (index = y * 8 + x). quantize() and unquantize() look up
// the weight for every coefficient except index 0.
//
// Two candidate tables are kept here; the preprocessor switch selects which
// one is compiled in. The first one is currently disabled.
static const unsigned char std_luminance_quant_tbl[64] = {
#if 0
        // Alternative table (disabled).
        16,  11,  10,  16,  24,  40,  51,  61,
        12,  12,  14,  19,  26,  58,  60,  55,
        14,  13,  16,  24,  40,  57,  69,  56,
        14,  17,  22,  29,  51,  87,  80,  62,
        18,  22,  37,  56,  68, 109, 103,  77,
        24,  35,  55,  64,  81, 104, 113,  92,
        49,  64,  78,  87, 103, 121, 120, 101,
        72,  92,  95,  98, 112, 100, 103,  99
#else
        // ff_mpeg1_default_intra_matrix (the table actually in use).
         8, 16, 19, 22, 26, 27, 29, 34,
        16, 16, 22, 24, 27, 29, 34, 37,
        19, 22, 26, 27, 29, 34, 34, 38,
        22, 22, 26, 27, 29, 34, 37, 40,
        22, 26, 27, 29, 32, 35, 40, 48,
        26, 27, 29, 32, 35, 40, 48, 58,
        26, 27, 29, 34, 38, 46, 56, 69,
        27, 29, 35, 38, 46, 56, 69, 83
#endif
};
80
81 struct SymbolStats
82 {
83     uint32_t freqs[NUM_SYMS];
84     uint32_t cum_freqs[NUM_SYMS + 1];
85
86     void clear();
87     void calc_cum_freqs();
88     void normalize_freqs(uint32_t target_total);
89 };
90
91 void SymbolStats::clear()
92 {
93     for (int i=0; i < NUM_SYMS; i++)
94         freqs[i] = 0;
95 }
96
97 void SymbolStats::calc_cum_freqs()
98 {
99     cum_freqs[0] = 0;
100     for (int i=0; i < NUM_SYMS; i++)
101         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
102 }
103
104 void SymbolStats::normalize_freqs(uint32_t target_total)
105 {
106     uint64_t real_freq[NUM_SYMS + 1];  // hack
107
108     assert(target_total >= NUM_SYMS);
109
110     calc_cum_freqs();
111     uint32_t cur_total = cum_freqs[NUM_SYMS];
112
113     if (cur_total == 0) return;
114
115     double ideal_cost = 0.0;
116     for (int i = 1; i <= NUM_SYMS; i++)
117     {
118       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
119       if (real_freq[i] > 0)
120         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
121     }
122
123     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
124
125     // calculate updated freqs and make sure we didn't screw anything up
126     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
127     for (int i=0; i < NUM_SYMS; i++) {
128         if (freqs[i] == 0)
129             assert(cum_freqs[i+1] == cum_freqs[i]);
130         else
131             assert(cum_freqs[i+1] > cum_freqs[i]);
132
133         // calc updated freq
134         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
135     }
136
137     double calc_cost = 0.0;
138     for (int i = 1; i <= NUM_SYMS; i++)
139     {
140       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
141       if (real_freq[i] > 0)
142         calc_cost -= real_freq[i] * log2(freq / double(target_total));
143     }
144
145     static double total_loss = 0.0;
146     total_loss += calc_cost - ideal_cost;
147     static double total_loss_with_dp = 0.0;
148         double optimal_cost = 0.0;
149     //total_loss_with_dp += optimal_cost - ideal_cost;
150     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
151                 ideal_cost, optimal_cost,
152                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
153 }
154
155 SymbolStats stats[128];
156
157 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
158 // Distance from one stream to the other, based on a hacked-up K-L divergence.
159 float kl_dist[64][64];
160 #endif
161
// Which of the four luma statistics tables (0-3) each coefficient position
// uses. Row-major: the entry for column x, row y is at index y * 8 + x.
const int luma_mapping[64] = {
        0, 0, 1, 1, 2, 2, 3, 3,
        0, 0, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 2, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};

// Same layout as luma_mapping, but for chroma. pick_stats_for() adds 4 to
// these values so chroma uses tables 4-7.
const int chroma_mapping[64] = {
        0, 1, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 2, 3, 3, 3, 3, 3,
        2, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};

// Returns the index into the global stats array for the coefficient at
// column x, row y (both 0-7) of an 8x8 block.
//
// With FIND_OPTIMAL_STREAM_ASSIGNMENT enabled, every coefficient position
// gets its own table: 0-63 for luma, 64-127 for chroma. Otherwise the
// positions are grouped through the mapping tables above: 0-3 for luma,
// 4-7 for chroma.
int pick_stats_for(int x, int y, bool is_chroma)
{
        const int coeff_idx = y * 8 + x;
#if FIND_OPTIMAL_STREAM_ASSIGNMENT
        return coeff_idx + (is_chroma ? 64 : 0);
#else
        return is_chroma ? chroma_mapping[coeff_idx] + 4 : luma_mapping[coeff_idx];
#endif
}
195                 
196
// Writes x to fp as a variable-length integer: seven bits per byte, least
// significant group first, with the top bit of each byte set when more
// bytes follow.
//
// x is encoded through its unsigned value, so a negative argument (for
// instance a large unsigned count that was converted to int by the caller)
// comes out as a well-formed multi-byte sequence. Previously a negative x
// skipped the loop and was written as a single truncated byte.
void write_varint(int x, FILE *fp)
{
        unsigned int remaining = static_cast<unsigned int>(x);
        while (remaining >= 0x80) {
                putc((remaining & 0x7f) | 0x80, fp);
                remaining >>= 7;
        }
        putc(remaining, fp);
}
205
206 class RansEncoder {
207 public:
208         RansEncoder()
209         {
210                 out_buf.reset(new uint8_t[out_max_size]);
211                 clear();
212         }
213
214         void init_prob(SymbolStats &s)
215         {
216                 for (int i = 0; i < NUM_SYMS; i++) {
217                         printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
218                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
219                 }
220                 sign_bias = s.cum_freqs[NUM_SYMS];
221         }
222
223         void clear()
224         {
225                 out_end = out_buf.get() + out_max_size;
226                 ptr = out_end; // *end* of output buffer
227                 RansEncInit(&rans);
228         }
229
230         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
231         {
232                 RansEncFlush(&rans, &ptr);
233                 //printf("post-flush = %08x\n", rans);
234
235                 uint32_t num_rans_bytes = out_end - ptr;
236 #if 0
237                 if (num_rans_bytes == 4) {
238                         uint32_t block;
239                         memcpy(&block, ptr, 4);
240
241                         if (block == last_block) {
242                                 write_varint(0, codedfp);
243                                 clear();
244                                 return 1;
245                         }
246
247                         last_block = block;
248                 } else {
249                         last_block = 0;
250                 }
251 #endif
252
253                 write_varint(num_rans_bytes, codedfp);
254                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
255                 fwrite(ptr, 1, num_rans_bytes, codedfp);
256
257                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
258
259
260                 clear();
261
262                 printf("Saving block: %d rANS bytes\n", num_rans_bytes);
263                 return num_rans_bytes;
264                 //return num_rans_bytes;
265         }
266
267         void encode_coeff(short signed_k)
268         {
269                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
270                 unsigned short k = abs(signed_k);
271                 if (k >= ESCAPE_LIMIT) {
272                         // Put the coefficient as a 1/(2^12) symbol _before_
273                         // the 255 coefficient, since the decoder will read the
274                         // 255 coefficient first.
275                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
276                         k = ESCAPE_LIMIT;
277                 }
278                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & (NUM_SYMS - 1)]);
279                 if (signed_k < 0) {
280                         rans += sign_bias;
281                 }
282         }
283
284 private:
285         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
286         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
287
288         unique_ptr<uint8_t[]> out_buf;
289         uint8_t *out_end;
290         uint8_t *ptr;
291         RansState rans;
292         RansEncSymbol esyms[NUM_SYMS];
293         uint32_t sign_bias;
294
295         uint32_t last_block = 0;  // Not a valid 4-byte rANS block (?)
296 };
297
298 static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
299 static constexpr double quant_scalefac = 4.0;  // whatever?
300
301 static inline int quantize(int f, int coeff_idx)
302 {
303         if (coeff_idx == 0) {
304                 return f / dc_scalefac;
305         }
306         if (f == 0) {
307                 return 0;
308         }
309
310         const int w = std_luminance_quant_tbl[coeff_idx];
311         const int s = quant_scalefac;
312         int sign_f = (f > 0) ? 1 : -1;
313         return (32 * f + sign_f * w * s) / (2 * w * s);
314 }
315
316 static inline int unquantize(int qf, int coeff_idx)
317 {
318         if (coeff_idx == 0) {
319                 return qf * dc_scalefac;
320         }
321         if (qf == 0) {
322                 return 0;
323         }
324
325         const int w = std_luminance_quant_tbl[coeff_idx];
326         const int s = quant_scalefac;
327         return (2 * qf * w * s) / 32;
328 }
329
330 void readpix(unsigned char *ptr, const char *filename)
331 {
332         FILE *fp = fopen(filename, "rb");
333         if (fp == nullptr) {
334                 perror(filename);
335                 exit(1);
336         }
337
338         fseek(fp, 0, SEEK_END);
339         long len = ftell(fp);
340         assert(len >= WIDTH * HEIGHT * 3);
341         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
342
343         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
344         fclose(fp);
345 }
346
347 void convert_ycbcr()
348 {
349         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
350         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
351         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
352
353         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
354         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
355         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
356                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
357                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
358                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
359                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
360                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
361                         double cb = (b - y) * cb_fac + 128.0;
362                         double cr = (r - y) * cr_fac + 128.0;
363                         pix_y[(yb * WIDTH) + xb] = lrint(y);
364                         temp_cb[(yb * WIDTH) + xb] = cb;
365                         temp_cr[(yb * WIDTH) + xb] = cr;
366                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
367                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
368                 }
369         }
370
371         // Simple 4:2:2 subsampling with left convention.
372         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
373                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
374                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
375                         int c1 = yb * WIDTH + xb * 2;
376                         int c2 = yb * WIDTH + xb * 2 + 1;
377                         
378                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
379                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
380                         cb = std::min(std::max(cb, 0.0), 255.0);
381                         cr = std::min(std::max(cr, 0.0), 255.0);
382                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
383                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
384                 }
385         }
386 }
387
388 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
389 double find_best_assignment(const int *medoids, int *assignment)
390 {
391         double current_score = 0.0;
392         for (int i = 0; i < 64; ++i) {
393                 int best_medoid = medoids[0];
394                 float best_medoid_score = kl_dist[i][medoids[0]];
395                 for (int j = 1; j < NUM_CLUSTERS; ++j) {
396                         if (kl_dist[i][medoids[j]] < best_medoid_score) {
397                                 best_medoid = medoids[j];
398                                 best_medoid_score = kl_dist[i][medoids[j]];
399                         }
400                 }
401                 assignment[i] = best_medoid;
402                 current_score += best_medoid_score;
403         }
404         return current_score;
405 }
406
407 void find_optimal_stream_assignment(int base)
408 {
409         double inv_sum[64];
410         for (unsigned i = 0; i < 64; ++i) {
411                 double s = 0.0;
412                 for (unsigned k = 0; k < NUM_SYMS; ++k) {
413                         s += stats[i + base].freqs[k] + 0.5;
414                 }
415                 inv_sum[i] = 1.0 / s;
416         }
417
418         for (unsigned i = 0; i < 64; ++i) {
419                 for (unsigned j = 0; j < 64; ++j) {
420                         double d = 0.0;
421                         for (unsigned k = 0; k < NUM_SYMS; ++k) {
422                                 double p1 = (stats[i + base].freqs[k] + 0.5) * inv_sum[i];
423                                 double p2 = (stats[j + base].freqs[k] + 0.5) * inv_sum[j];
424
425                                 // K-L divergence is asymmetric; this is a hack.
426                                 d += p1 * log(p1 / p2);
427                                 d += p2 * log(p2 / p1);
428                         }
429                         kl_dist[i][j] = d;
430                         //printf("%.3f ", d);
431                 }
432                 //printf("\n");
433         }
434
435         // k-medoids init
436         int medoids[64];  // only the first NUM_CLUSTERS matter
437         bool is_medoid[64] = { false };
438         std::iota(medoids, medoids + 64, 0);
439         std::random_device rd;
440         std::mt19937 g(rd());
441         std::shuffle(medoids, medoids + 64, g);
442         for (int i = 0; i < NUM_CLUSTERS; ++i) {
443                 printf("%d ", medoids[i]);
444                 is_medoid[medoids[i]] = true;
445         }
446         printf("\n");
447
448         // assign each data point to the closest medoid
449         int assignment[64];
450         double current_score = find_best_assignment(medoids, assignment);
451
452         for (int i = 0; i < 1000; ++i) {
453                 printf("iter %d\n", i);
454                 bool any_changed = false;
455                 for (int m = 0; m < NUM_CLUSTERS; ++m) {
456                         for (int o = 0; o < 64; ++o) {
457                                 if (is_medoid[o]) continue;
458                                 int old_medoid = medoids[m];
459                                 medoids[m] = o;
460
461                                 int new_assignment[64];
462                                 double candidate_score = find_best_assignment(medoids, new_assignment);
463
464                                 if (candidate_score < current_score) {
465                                         current_score = candidate_score;
466                                         memcpy(assignment, new_assignment, sizeof(assignment));
467
468                                         is_medoid[old_medoid] = false;
469                                         is_medoid[medoids[m]] = true;
470                                         printf("%f: ", current_score);
471                                         for (int i = 0; i < 64; ++i) {
472                                                 printf("%d ", assignment[i]);
473                                         }
474                                         printf("\n");
475                                         any_changed = true;
476                                 } else {
477                                         medoids[m] = old_medoid;
478                                 }
479                         }
480                 }
481                 if (!any_changed) break;
482         }
483         printf("\n");
484         std::unordered_map<int, int> rmap;
485         for (int i = 0; i < 64; ++i) {
486                 if (i % 8 == 0) printf("\n");
487                 if (!rmap.count(assignment[i])) {
488                         rmap.emplace(assignment[i], rmap.size());
489                 }
490                 printf("%d, ", rmap[assignment[i]]);
491         }
492         printf("\n");
493 }
494 #endif
495
496 int main(int argc, char **argv)
497 {
498         if (argc >= 2)
499                 readpix(rgb, argv[1]);
500         else
501                 readpix(rgb, "color.pnm");
502         convert_ycbcr();
503
504         double sum_sq_err = 0.0;
505         //double last_cb_cfl_fac = 0.0;
506         //double last_cr_cfl_fac = 0.0;
507
508         // DCT and quantize luma
509         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
510                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
511                         // Read one block
512                         short in_y[64];
513                         for (unsigned y = 0; y < 8; ++y) {
514                                 for (unsigned x = 0; x < 8; ++x) {
515                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
516                                 }
517                         }
518
519                         // FDCT it
520                         fdct_int32(in_y);
521
522                         for (unsigned y = 0; y < 8; ++y) {
523                                 for (unsigned x = 0; x < 8; ++x) {
524                                         int coeff_idx = y * 8 + x;
525                                         int k = quantize(in_y[coeff_idx], coeff_idx);
526                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
527
528                                         // Store back for reconstruction / PSNR calculation
529                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
530                                 }
531                         }
532
533                         idct_int32(in_y);
534
535                         for (unsigned y = 0; y < 8; ++y) {
536                                 for (unsigned x = 0; x < 8; ++x) {
537                                         int k = clamp(in_y[y * 8 + x]);
538                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
539                                         sum_sq_err += (*ptr - k) * (*ptr - k);
540                                         *ptr = k;
541                                 }
542                         }
543                 }
544         }
545         double mse = sum_sq_err / double(WIDTH * HEIGHT);
546         double psnr_db = 20 * log10(255.0 / sqrt(mse));
547         printf("psnr = %.2f dB\n", psnr_db);
548
549         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
550
551         // DCT and quantize chroma
552         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
553         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
554                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
555 #if 0
556                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
557                         printf("in blocks:\n");
558                         for (unsigned y = 0; y < 8; ++y) {
559                                 for (unsigned x = 0; x < 8; ++x) {
560                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
561                                         printf(" %4d", a);
562                                 }
563                                 printf(" | ");
564                                 for (unsigned x = 0; x < 8; ++x) {
565                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
566                                         printf(" %4d", b);
567                                 }
568                                 printf("\n");
569                         }
570
571                         short in_y[64];
572                         for (unsigned y = 0; y < 8; ++y) {
573                                 for (unsigned x = 0; x < 4; ++x) {
574                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
575                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
576                                         b = a - b;
577                                         a = 2 * a - b;
578                                         in_y[y * 8 + x * 2 + 0] = a;
579                                         in_y[y * 8 + x * 2 + 1] = b;
580                                 }
581                         }
582
583                         printf("tf-ed block:\n");
584                         for (unsigned y = 0; y < 8; ++y) {
585                                 for (unsigned x = 0; x < 8; ++x) {
586                                         short a = in_y[y * 8 + x];
587                                         printf(" %4d", a);
588                                 }
589                                 printf("\n");
590                         }
591 #else
592                         // Read Y block with no tf switch (from reconstructed luma)
593                         short in_y[64];
594                         for (unsigned y = 0; y < 8; ++y) {
595                                 for (unsigned x = 0; x < 8; ++x) {
596                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
597                                 }
598                         }
599                         fdct_int32(in_y);
600 #endif
601
602                         // Read one block
603                         short in_cb[64], in_cr[64];
604                         for (unsigned y = 0; y < 8; ++y) {
605                                 for (unsigned x = 0; x < 8; ++x) {
606                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
607                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
608                                 }
609                         }
610
611                         // FDCT it
612                         fdct_int32(in_cb);
613                         fdct_int32(in_cr);
614
615 #if 0
616                         // Chroma from luma
617                         double x0 = in_y[1];
618                         double x1 = in_y[8];
619                         double x2 = in_y[9];
620                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
621                         //double denom = (x1 * x1);
622         
623                         double y0 = in_cb[1];
624                         double y1 = in_cb[8];
625                         double y2 = in_cb[9];
626                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
627                         //double cb_cfl_fac = (x1 * y1) / denom;
628
629                         for (unsigned y = 0; y < 8; ++y) {
630                                 for (unsigned x = 0; x < 8; ++x) {
631                                         short a = in_y[y * 8 + x];
632                                         printf(" %4d", a);
633                                 }
634                                 printf(" | ");
635                                 for (unsigned x = 0; x < 8; ++x) {
636                                         short a = in_cb[y * 8 + x];
637                                         printf(" %4d", a);
638                                 }
639                                 printf("\n");
640                         }
641                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
642                                 in_y[1], in_y[8], in_y[9], 
643                                 in_cb[1], in_cb[8], in_cb[9],
644                                 cb_cfl_fac);
645
646                         y0 = in_cr[1];
647                         y1 = in_cr[8];
648                         y2 = in_cr[9];
649                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
650                         //double cr_cfl_fac = (x1 * y1) / denom;
651                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
652                                 cb_cfl_fac, in_cb[0] - in_y[0],
653                                 cr_cfl_fac, in_cr[0] - in_y[0]);
654
655                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
656
657                         // CHEAT
658                         //last_cb_cfl_fac = cb_cfl_fac;
659                         //last_cr_cfl_fac = cr_cfl_fac;
660
661                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
662                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
663                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
664                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
665                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
666                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
667
668                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
669                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
670                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
671                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
672                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
673                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
674                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
675                         }
676                         //in_cb[0] += 1024;
677                         //in_cr[0] += 1024;
678                         //in_cb[0] -= in_y[0];
679                         //in_cr[0] -= in_y[0];
680 #endif
681
682                         for (unsigned y = 0; y < 8; ++y) {
683                                 for (unsigned x = 0; x < 8; ++x) {
684                                         int coeff_idx = y * 8 + x;
685                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
686                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
687                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
688                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
689
690                                         // Store back for reconstruction / PSNR calculation
691                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
692                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
693                                 }
694                         }
695
696                         idct_int32(in_y);  // DEBUG
697                         idct_int32(in_cb);
698                         idct_int32(in_cr);
699
700                         for (unsigned y = 0; y < 8; ++y) {
701                                 for (unsigned x = 0; x < 8; ++x) {
702                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
703                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
704
705                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
706                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
707                                 }
708                         }
709
710 #if 0
711                         last_cb_cfl_fac = cb_cfl_fac;
712                         last_cr_cfl_fac = cr_cfl_fac;
713 #endif
714                 }
715         }
716
717 #if 0
718         printf("chroma_energy = %f, with_pred = %f\n",
719                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
720 #endif
721
722         // DC coefficient pred from the right to left
723         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
724                 for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
725                         coeff_y[yb * WIDTH + xb] -= coeff_y[yb * WIDTH + (xb + 8)];
726                 }
727         }
728         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
729                 for (unsigned xb = 0; xb < WIDTH/2 - 8; xb += 8) {
730                         coeff_cb[yb * WIDTH/2 + xb] -= coeff_cb[yb * WIDTH/2 + (xb + 8)];
731                         coeff_cr[yb * WIDTH/2 + xb] -= coeff_cr[yb * WIDTH/2 + (xb + 8)];
732                 }
733         }
734
735         FILE *fp = fopen("reconstructed.pgm", "wb");
736         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
737         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
738         fclose(fp);
739
740         fp = fopen("reconstructed.pnm", "wb");
741         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
742         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
743                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
744                         int y = pix_y[(yb * WIDTH) + xb];
745                         int cb, cr;
746                         int c0 = yb * (WIDTH/2) + xb/2;
747                         if (xb % 2 == 0) {
748                                 cb = pix_cb[c0] - 128.0;
749                                 cr = pix_cr[c0] - 128.0;
750                         } else {
751                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
752                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
753                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
754                         }
755
756                         double r = y + 1.5748 * cr;
757                         double g = y - 0.1873 * cb - 0.4681 * cr;
758                         double b = y + 1.8556 * cb;
759
760                         putc(clamp(lrint(r)), fp);
761                         putc(clamp(lrint(g)), fp);
762                         putc(clamp(lrint(b)), fp);
763                 }
764         }
765         fclose(fp);
766
767         // For each coefficient, make some tables.
768         size_t extra_bits = 0;
769         for (unsigned i = 0; i < 64; ++i) {
770                 stats[i].clear();
771         }
772         for (unsigned y = 0; y < 8; ++y) {
773                 for (unsigned x = 0; x < 8; ++x) {
774                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
775                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
776
777                         // Luma
778                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
779                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
780                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
781                                         if (k >= ESCAPE_LIMIT) {
782                                                 k = ESCAPE_LIMIT;
783                                                 extra_bits += 12;  // escape this one
784                                         }
785                                         ++s_luma.freqs[(k - 1) & (NUM_SYMS - 1)];
786                                 }
787                         }
788                         // Chroma
789                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
790                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
791                                         unsigned short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
792                                         unsigned short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
793                                         if (k_cb >= ESCAPE_LIMIT) {
794                                                 k_cb = ESCAPE_LIMIT;
795                                                 extra_bits += 12;  // escape this one
796                                         }
797                                         if (k_cr >= ESCAPE_LIMIT) {
798                                                 k_cr = ESCAPE_LIMIT;
799                                                 extra_bits += 12;  // escape this one
800                                         }
801                                         ++s_chroma.freqs[(k_cb - 1) & (NUM_SYMS - 1)];
802                                         ++s_chroma.freqs[(k_cr - 1) & (NUM_SYMS - 1)];
803                                 }
804                         }
805                 }
806         }
807
808 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
809         printf("Luma:\n");
810         find_optimal_stream_assignment(0);
811         printf("Chroma:\n");
812         find_optimal_stream_assignment(64);
813         exit(0);
814 #endif
815
816         for (unsigned i = 0; i < 64; ++i) {
817                 stats[i].freqs[NUM_SYMS - 1] /= 2;  // zero, has no sign bits (yes, this is trickery)
818                 stats[i].normalize_freqs(prob_scale);
819                 stats[i].cum_freqs[NUM_SYMS] += stats[i].freqs[NUM_SYMS - 1];
820                 stats[i].freqs[NUM_SYMS - 1] *= 2;
821         }
822
823         FILE *codedfp = fopen("coded.dat", "wb");
824         if (codedfp == nullptr) {
825                 perror("coded.dat");
826                 exit(1);
827         }
828
829         // TODO: rather gamma-k or something
830         for (unsigned i = 0; i < 64; ++i) {
831                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
832                         continue;
833                 }
834                 printf("writing table %d\n", i);
835                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
836                         write_varint(stats[i].freqs[j], codedfp);
837                 }
838         }
839
840         RansEncoder rans_encoder;
841
842         size_t tot_bytes = 0;
843
844         // Luma
845         for (unsigned y = 0; y < 8; ++y) {
846                 for (unsigned x = 0; x < 8; ++x) {
847                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
848                         rans_encoder.init_prob(s_luma);
849
850                         // Luma
851                         std::vector<int> lens;
852
853                         // need to reverse later
854                         rans_encoder.clear();
855                         size_t num_bytes = 0;
856                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; ++block_idx) {
857                                 unsigned yb = block_idx / WIDTH_BLOCKS;
858                                 unsigned xb = block_idx % WIDTH_BLOCKS;
859
860                                 int k = coeff_y[(yb * 8 + y) * WIDTH + (xb * 8 + x)];
861                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
862                                 rans_encoder.encode_coeff(k);
863
864                                 if (block_idx % 320 == 319 || block_idx == NUM_BLOCKS - 1) {
865                                         int l = rans_encoder.save_block(codedfp);
866                                         num_bytes += l;
867                                         lens.push_back(l);
868                                 }
869                         }
870                         tot_bytes += num_bytes;
871                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
872
873                         double sum_l = 0.0;
874                         for (int l : lens) {
875                                 sum_l += l;
876                         }
877                         double avg_l = sum_l / lens.size();
878
879                         double sum_sql = 0.0;
880                         for (int l : lens) {
881                                 sum_sql += (l - avg_l) * (l - avg_l);
882                         }
883                         double stddev_l = sqrt(sum_sql / (lens.size() - 1));
884                         printf("coeff %d: avg=%.2f bytes, stddev=%.2f bytes\n", y*8+x, avg_l, stddev_l);
885                 }
886         }
887
888         // Cb
889         for (unsigned y = 0; y < 8; ++y) {
890                 for (unsigned x = 0; x < 8; ++x) {
891                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
892                         rans_encoder.init_prob(s_chroma);
893
894                         rans_encoder.clear();
895                         size_t num_bytes = 0;
896                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
897                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
898                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
899
900                                 int k = coeff_cb[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
901                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
902                                 rans_encoder.encode_coeff(k);
903
904                                 if (block_idx % 320 == 319 || block_idx == NUM_BLOCKS - 1) {
905                                         num_bytes += rans_encoder.save_block(codedfp);
906                                 }
907                         }
908                         tot_bytes += num_bytes;
909                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
910                 }
911         }
912
913         // Cr
914         for (unsigned y = 0; y < 8; ++y) {
915                 for (unsigned x = 0; x < 8; ++x) {
916                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
917                         rans_encoder.init_prob(s_chroma);
918
919                         rans_encoder.clear();
920                         size_t num_bytes = 0;
921                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
922                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
923                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
924
925                                 int k = coeff_cr[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
926                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
927                                 rans_encoder.encode_coeff(k);
928
929                                 if (block_idx % 320 == 319 || block_idx == NUM_BLOCKS - 1) {
930                                         num_bytes += rans_encoder.save_block(codedfp);
931                                 }
932                         }
933                         tot_bytes += num_bytes;
934                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
935                 }
936         }
937
938         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
939                 tot_bytes - extra_bits / 8,
940                 extra_bits,
941                 extra_bits / 8,
942                 tot_bytes);
943 }