]> git.sesse.net Git - narabu/blob - qdc.cpp
Stop hardcoding blocks per row in the shader.
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 //#include "ryg_rans/rans64.h"
9 #include "ryg_rans/rans_byte.h"
10 #include "ryg_rans/renormalize.h"
11
12 #include <algorithm>
13 #include <memory>
14 #include <numeric>
15 #include <random>
16 #include <vector>
17 #include <unordered_map>
18
19 #define WIDTH 1280
20 #define HEIGHT 720
21 #define WIDTH_BLOCKS (WIDTH/8)
22 #define WIDTH_BLOCKS_CHROMA (WIDTH/16)
23 #define HEIGHT_BLOCKS (HEIGHT/8)
24 #define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
25 #define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)
26 #define NUM_SYMS 256
27 #define ESCAPE_LIMIT (NUM_SYMS - 1)
28
29 // If you set this to 1, the program will try to optimize the placement
30 // of coefficients to rANS probability distributions. This is randomized,
31 // so you might want to run it a few times.
32 #define FIND_OPTIMAL_STREAM_ASSIGNMENT 0
33 #define NUM_CLUSTERS 4
34
35 static constexpr uint32_t prob_bits = 12;
36 static constexpr uint32_t prob_scale = 1 << prob_bits;
37
38 using namespace std;
39
40 void fdct_int32(short *const In);
41 void idct_int32(short *const In);
42
43 unsigned char rgb[WIDTH * HEIGHT * 3];
44 unsigned char pix_y[WIDTH * HEIGHT];
45 unsigned char pix_cb[(WIDTH/2) * HEIGHT];
46 unsigned char pix_cr[(WIDTH/2) * HEIGHT];
47 unsigned char full_pix_cb[WIDTH * HEIGHT];
48 unsigned char full_pix_cr[WIDTH * HEIGHT];
49 short coeff_y[WIDTH * HEIGHT], coeff_cb[(WIDTH/2) * HEIGHT], coeff_cr[(WIDTH/2) * HEIGHT];
50
// Limits x to the range of an 8-bit sample, [0, 255].
int clamp(int x)
{
        return (x < 0) ? 0 : (x > 255) ? 255 : x;
}
57
// Quantization weights for the 64 coefficients of an 8x8 block, in row-major
// order. Despite the name, the table in use is ff_mpeg1_default_intra_matrix;
// the alternative table is kept in the disabled branch.
static const unsigned char std_luminance_quant_tbl[64] = {
#if 1
        // ff_mpeg1_default_intra_matrix
         8, 16, 19, 22, 26, 27, 29, 34,
        16, 16, 22, 24, 27, 29, 34, 37,
        19, 22, 26, 27, 29, 34, 34, 38,
        22, 22, 26, 27, 29, 34, 37, 40,
        22, 26, 27, 29, 32, 35, 40, 48,
        26, 27, 29, 32, 35, 40, 48, 58,
        26, 27, 29, 34, 38, 46, 56, 69,
        27, 29, 35, 38, 46, 56, 69, 83
#else
        16,  11,  10,  16,  24,  40,  51,  61,
        12,  12,  14,  19,  26,  58,  60,  55,
        14,  13,  16,  24,  40,  57,  69,  56,
        14,  17,  22,  29,  51,  87,  80,  62,
        18,  22,  37,  56,  68, 109, 103,  77,
        24,  35,  55,  64,  81, 104, 113,  92,
        49,  64,  78,  87, 103, 121, 120, 101,
        72,  92,  95,  98, 112, 100, 103,  99
#endif
};
80
81 struct SymbolStats
82 {
83     uint32_t freqs[NUM_SYMS];
84     uint32_t cum_freqs[NUM_SYMS + 1];
85
86     void clear();
87     void calc_cum_freqs();
88     void normalize_freqs(uint32_t target_total);
89 };
90
91 void SymbolStats::clear()
92 {
93     for (int i=0; i < NUM_SYMS; i++)
94         freqs[i] = 0;
95 }
96
97 void SymbolStats::calc_cum_freqs()
98 {
99     cum_freqs[0] = 0;
100     for (int i=0; i < NUM_SYMS; i++)
101         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
102 }
103
104 void SymbolStats::normalize_freqs(uint32_t target_total)
105 {
106     uint64_t real_freq[NUM_SYMS + 1];  // hack
107
108     assert(target_total >= NUM_SYMS);
109
110     calc_cum_freqs();
111     uint32_t cur_total = cum_freqs[NUM_SYMS];
112
113     if (cur_total == 0) return;
114
115     double ideal_cost = 0.0;
116     for (int i = 1; i <= NUM_SYMS; i++)
117     {
118       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
119       if (real_freq[i] > 0)
120         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
121     }
122
123     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
124
125     // calculate updated freqs and make sure we didn't screw anything up
126     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
127     for (int i=0; i < NUM_SYMS; i++) {
128         if (freqs[i] == 0)
129             assert(cum_freqs[i+1] == cum_freqs[i]);
130         else
131             assert(cum_freqs[i+1] > cum_freqs[i]);
132
133         // calc updated freq
134         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
135     }
136
137     double calc_cost = 0.0;
138     for (int i = 1; i <= NUM_SYMS; i++)
139     {
140       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
141       if (real_freq[i] > 0)
142         calc_cost -= real_freq[i] * log2(freq / double(target_total));
143     }
144
145     static double total_loss = 0.0;
146     total_loss += calc_cost - ideal_cost;
147     static double total_loss_with_dp = 0.0;
148         double optimal_cost = 0.0;
149     //total_loss_with_dp += optimal_cost - ideal_cost;
150     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
151                 ideal_cost, optimal_cost,
152                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
153 }
154
155 SymbolStats stats[128];
156
157 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
158 // Distance from one stream to the other, based on a hacked-up K-L divergence.
159 float kl_dist[64][64];
160 #endif
161
// For each luma coefficient position (row-major within the 8x8 block), the
// index 0..3 of the probability distribution used to code it.
const int luma_mapping[64] = {
        0, 0, 1, 1, 2, 2, 3, 3,
        0, 0, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 2, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};
// For each chroma coefficient position (row-major within the 8x8 block), the
// index 0..3 of the chroma probability distribution used to code it.
const int chroma_mapping[64] = {
        0, 1, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 2, 3, 3, 3, 3, 3,
        2, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};
182
183 int pick_stats_for(int x, int y, bool is_chroma)
184 {
185 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
186         return y * 8 + x + is_chroma * 64;
187 #else
188         if (is_chroma) {
189                 return chroma_mapping[y * 8 + x] + 4;
190         } else {
191                 return luma_mapping[y * 8 + x];
192         }
193 #endif
194 }
195                 
196
// Writes x to fp as a little-endian base-128 varint: seven payload bits per
// byte, least significant group first, with the high bit set on every byte
// except the last.
//
// x is encoded by its unsigned bit pattern, so an unsigned 32-bit value that
// was passed through the int parameter is written out in full instead of
// being cut down to a single byte.
void write_varint(int x, FILE *fp)
{
        unsigned remaining = static_cast<unsigned>(x);
        while (remaining >= 0x80) {
                putc((remaining & 0x7f) | 0x80, fp);
                remaining >>= 7;
        }
        putc(remaining, fp);
}
205
206 class RansEncoder {
207 public:
208         RansEncoder()
209         {
210                 out_buf.reset(new uint8_t[out_max_size]);
211                 clear();
212         }
213
214         void init_prob(SymbolStats &s)
215         {
216                 for (int i = 0; i < NUM_SYMS; i++) {
217                         printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
218                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
219                 }
220                 sign_bias = s.cum_freqs[NUM_SYMS];
221         }
222
223         void clear()
224         {
225                 out_end = out_buf.get() + out_max_size;
226                 ptr = out_end; // *end* of output buffer
227                 RansEncInit(&rans);
228         }
229
230         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
231         {
232                 RansEncFlush(&rans, &ptr);
233                 //printf("post-flush = %08x\n", rans);
234
235                 uint32_t num_rans_bytes = out_end - ptr;
236 #if 0
237                 if (num_rans_bytes == 4) {
238                         uint32_t block;
239                         memcpy(&block, ptr, 4);
240
241                         if (block == last_block) {
242                                 write_varint(0, codedfp);
243                                 clear();
244                                 return 1;
245                         }
246
247                         last_block = block;
248                 } else {
249                         last_block = 0;
250                 }
251 #endif
252
253                 write_varint(num_rans_bytes, codedfp);
254                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
255                 fwrite(ptr, 1, num_rans_bytes, codedfp);
256
257                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
258
259
260                 clear();
261
262                 printf("Saving block: %d rANS bytes\n", num_rans_bytes);
263                 return num_rans_bytes;
264                 //return num_rans_bytes;
265         }
266
267         void encode_coeff(short signed_k)
268         {
269                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
270                 unsigned short k = abs(signed_k);
271                 if (k >= ESCAPE_LIMIT) {
272                         // Put the coefficient as a 1/(2^12) symbol _before_
273                         // the 255 coefficient, since the decoder will read the
274                         // 255 coefficient first.
275                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
276                         k = ESCAPE_LIMIT;
277                 }
278                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & (NUM_SYMS - 1)]);
279                 if (signed_k < 0) {
280                         rans += sign_bias;
281                 }
282         }
283
284 private:
285         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
286         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
287
288         unique_ptr<uint8_t[]> out_buf;
289         uint8_t *out_end;
290         uint8_t *ptr;
291         RansState rans;
292         RansEncSymbol esyms[NUM_SYMS];
293         uint32_t sign_bias;
294
295         uint32_t last_block = 0;  // Not a valid 4-byte rANS block (?)
296 };
297
// Divisor applied to the DC coefficient when quantizing. Matches the FDCT's gain.
static constexpr int dc_scalefac = 8;

// Global multiplier on the quantization weights of the AC coefficients.
static constexpr double quant_scalefac = 4.0;  // whatever?
300
301 static inline int quantize(int f, int coeff_idx)
302 {
303         if (coeff_idx == 0) {
304                 return f / dc_scalefac;
305         }
306         if (f == 0) {
307                 return 0;
308         }
309
310         const int w = std_luminance_quant_tbl[coeff_idx];
311         const int s = quant_scalefac;
312         int sign_f = (f > 0) ? 1 : -1;
313         return (32 * f + sign_f * w * s) / (2 * w * s);
314 }
315
316 static inline int unquantize(int qf, int coeff_idx)
317 {
318         if (coeff_idx == 0) {
319                 return qf * dc_scalefac;
320         }
321         if (qf == 0) {
322                 return 0;
323         }
324
325         const int w = std_luminance_quant_tbl[coeff_idx];
326         const int s = quant_scalefac;
327         return (2 * qf * w * s) / 32;
328 }
329
330 void readpix(unsigned char *ptr, const char *filename)
331 {
332         FILE *fp = fopen(filename, "rb");
333         if (fp == nullptr) {
334                 perror(filename);
335                 exit(1);
336         }
337
338         fseek(fp, 0, SEEK_END);
339         long len = ftell(fp);
340         assert(len >= WIDTH * HEIGHT * 3);
341         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
342
343         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
344         fclose(fp);
345 }
346
347 void convert_ycbcr()
348 {
349         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
350         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
351         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
352
353         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
354         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
355         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
356                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
357                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
358                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
359                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
360                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
361                         double cb = (b - y) * cb_fac + 128.0;
362                         double cr = (r - y) * cr_fac + 128.0;
363                         pix_y[(yb * WIDTH) + xb] = lrint(y);
364                         temp_cb[(yb * WIDTH) + xb] = cb;
365                         temp_cr[(yb * WIDTH) + xb] = cr;
366                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
367                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
368                 }
369         }
370
371         // Simple 4:2:2 subsampling with left convention.
372         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
373                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
374                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
375                         int c1 = yb * WIDTH + xb * 2;
376                         int c2 = yb * WIDTH + xb * 2 + 1;
377                         
378                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
379                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
380                         cb = std::min(std::max(cb, 0.0), 255.0);
381                         cr = std::min(std::max(cr, 0.0), 255.0);
382                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
383                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
384                 }
385         }
386 }
387
388 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
389 double find_best_assignment(const int *medoids, int *assignment)
390 {
391         double current_score = 0.0;
392         for (int i = 0; i < 64; ++i) {
393                 int best_medoid = medoids[0];
394                 float best_medoid_score = kl_dist[i][medoids[0]];
395                 for (int j = 1; j < NUM_CLUSTERS; ++j) {
396                         if (kl_dist[i][medoids[j]] < best_medoid_score) {
397                                 best_medoid = medoids[j];
398                                 best_medoid_score = kl_dist[i][medoids[j]];
399                         }
400                 }
401                 assignment[i] = best_medoid;
402                 current_score += best_medoid_score;
403         }
404         return current_score;
405 }
406
407 void find_optimal_stream_assignment(int base)
408 {
409         double inv_sum[64];
410         for (unsigned i = 0; i < 64; ++i) {
411                 double s = 0.0;
412                 for (unsigned k = 0; k < NUM_SYMS; ++k) {
413                         s += stats[i + base].freqs[k] + 0.5;
414                 }
415                 inv_sum[i] = 1.0 / s;
416         }
417
418         for (unsigned i = 0; i < 64; ++i) {
419                 for (unsigned j = 0; j < 64; ++j) {
420                         double d = 0.0;
421                         for (unsigned k = 0; k < NUM_SYMS; ++k) {
422                                 double p1 = (stats[i + base].freqs[k] + 0.5) * inv_sum[i];
423                                 double p2 = (stats[j + base].freqs[k] + 0.5) * inv_sum[j];
424
425                                 // K-L divergence is asymmetric; this is a hack.
426                                 d += p1 * log(p1 / p2);
427                                 d += p2 * log(p2 / p1);
428                         }
429                         kl_dist[i][j] = d;
430                         //printf("%.3f ", d);
431                 }
432                 //printf("\n");
433         }
434
435         // k-medoids init
436         int medoids[64];  // only the first NUM_CLUSTERS matter
437         bool is_medoid[64] = { false };
438         std::iota(medoids, medoids + 64, 0);
439         std::random_device rd;
440         std::mt19937 g(rd());
441         std::shuffle(medoids, medoids + 64, g);
442         for (int i = 0; i < NUM_CLUSTERS; ++i) {
443                 printf("%d ", medoids[i]);
444                 is_medoid[medoids[i]] = true;
445         }
446         printf("\n");
447
448         // assign each data point to the closest medoid
449         int assignment[64];
450         double current_score = find_best_assignment(medoids, assignment);
451
452         for (int i = 0; i < 1000; ++i) {
453                 printf("iter %d\n", i);
454                 bool any_changed = false;
455                 for (int m = 0; m < NUM_CLUSTERS; ++m) {
456                         for (int o = 0; o < 64; ++o) {
457                                 if (is_medoid[o]) continue;
458                                 int old_medoid = medoids[m];
459                                 medoids[m] = o;
460
461                                 int new_assignment[64];
462                                 double candidate_score = find_best_assignment(medoids, new_assignment);
463
464                                 if (candidate_score < current_score) {
465                                         current_score = candidate_score;
466                                         memcpy(assignment, new_assignment, sizeof(assignment));
467
468                                         is_medoid[old_medoid] = false;
469                                         is_medoid[medoids[m]] = true;
470                                         printf("%f: ", current_score);
471                                         for (int i = 0; i < 64; ++i) {
472                                                 printf("%d ", assignment[i]);
473                                         }
474                                         printf("\n");
475                                         any_changed = true;
476                                 } else {
477                                         medoids[m] = old_medoid;
478                                 }
479                         }
480                 }
481                 if (!any_changed) break;
482         }
483         printf("\n");
484         std::unordered_map<int, int> rmap;
485         for (int i = 0; i < 64; ++i) {
486                 if (i % 8 == 0) printf("\n");
487                 if (!rmap.count(assignment[i])) {
488                         rmap.emplace(assignment[i], rmap.size());
489                 }
490                 printf("%d, ", rmap[assignment[i]]);
491         }
492         printf("\n");
493 }
494 #endif
495
496 int main(int argc, char **argv)
497 {
498         if (argc >= 2)
499                 readpix(rgb, argv[1]);
500         else
501                 readpix(rgb, "color.pnm");
502         convert_ycbcr();
503
504         double sum_sq_err = 0.0;
505         //double last_cb_cfl_fac = 0.0;
506         //double last_cr_cfl_fac = 0.0;
507
508         // DCT and quantize luma
509         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
510                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
511                         // Read one block
512                         short in_y[64];
513                         for (unsigned y = 0; y < 8; ++y) {
514                                 for (unsigned x = 0; x < 8; ++x) {
515                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
516                                 }
517                         }
518
519                         // FDCT it
520                         fdct_int32(in_y);
521
522                         for (unsigned y = 0; y < 8; ++y) {
523                                 for (unsigned x = 0; x < 8; ++x) {
524                                         int coeff_idx = y * 8 + x;
525                                         int k = quantize(in_y[coeff_idx], coeff_idx);
526                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
527
528                                         // Store back for reconstruction / PSNR calculation
529                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
530                                 }
531                         }
532
533                         idct_int32(in_y);
534
535                         for (unsigned y = 0; y < 8; ++y) {
536                                 for (unsigned x = 0; x < 8; ++x) {
537                                         int k = clamp(in_y[y * 8 + x]);
538                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
539                                         sum_sq_err += (*ptr - k) * (*ptr - k);
540                                         *ptr = k;
541                                 }
542                         }
543                 }
544         }
545         double mse = sum_sq_err / double(WIDTH * HEIGHT);
546         double psnr_db = 20 * log10(255.0 / sqrt(mse));
547         printf("psnr = %.2f dB\n", psnr_db);
548
549         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
550
551         // DCT and quantize chroma
552         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
553         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
554                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
555 #if 0
556                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
557                         printf("in blocks:\n");
558                         for (unsigned y = 0; y < 8; ++y) {
559                                 for (unsigned x = 0; x < 8; ++x) {
560                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
561                                         printf(" %4d", a);
562                                 }
563                                 printf(" | ");
564                                 for (unsigned x = 0; x < 8; ++x) {
565                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
566                                         printf(" %4d", b);
567                                 }
568                                 printf("\n");
569                         }
570
571                         short in_y[64];
572                         for (unsigned y = 0; y < 8; ++y) {
573                                 for (unsigned x = 0; x < 4; ++x) {
574                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
575                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
576                                         b = a - b;
577                                         a = 2 * a - b;
578                                         in_y[y * 8 + x * 2 + 0] = a;
579                                         in_y[y * 8 + x * 2 + 1] = b;
580                                 }
581                         }
582
583                         printf("tf-ed block:\n");
584                         for (unsigned y = 0; y < 8; ++y) {
585                                 for (unsigned x = 0; x < 8; ++x) {
586                                         short a = in_y[y * 8 + x];
587                                         printf(" %4d", a);
588                                 }
589                                 printf("\n");
590                         }
591 #else
592                         // Read Y block with no tf switch (from reconstructed luma)
593                         short in_y[64];
594                         for (unsigned y = 0; y < 8; ++y) {
595                                 for (unsigned x = 0; x < 8; ++x) {
596                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
597                                 }
598                         }
599                         fdct_int32(in_y);
600 #endif
601
602                         // Read one block
603                         short in_cb[64], in_cr[64];
604                         for (unsigned y = 0; y < 8; ++y) {
605                                 for (unsigned x = 0; x < 8; ++x) {
606                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
607                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
608                                 }
609                         }
610
611                         // FDCT it
612                         fdct_int32(in_cb);
613                         fdct_int32(in_cr);
614
615 #if 0
616                         // Chroma from luma
617                         double x0 = in_y[1];
618                         double x1 = in_y[8];
619                         double x2 = in_y[9];
620                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
621                         //double denom = (x1 * x1);
622         
623                         double y0 = in_cb[1];
624                         double y1 = in_cb[8];
625                         double y2 = in_cb[9];
626                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
627                         //double cb_cfl_fac = (x1 * y1) / denom;
628
629                         for (unsigned y = 0; y < 8; ++y) {
630                                 for (unsigned x = 0; x < 8; ++x) {
631                                         short a = in_y[y * 8 + x];
632                                         printf(" %4d", a);
633                                 }
634                                 printf(" | ");
635                                 for (unsigned x = 0; x < 8; ++x) {
636                                         short a = in_cb[y * 8 + x];
637                                         printf(" %4d", a);
638                                 }
639                                 printf("\n");
640                         }
641                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
642                                 in_y[1], in_y[8], in_y[9], 
643                                 in_cb[1], in_cb[8], in_cb[9],
644                                 cb_cfl_fac);
645
646                         y0 = in_cr[1];
647                         y1 = in_cr[8];
648                         y2 = in_cr[9];
649                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
650                         //double cr_cfl_fac = (x1 * y1) / denom;
651                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
652                                 cb_cfl_fac, in_cb[0] - in_y[0],
653                                 cr_cfl_fac, in_cr[0] - in_y[0]);
654
655                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
656
657                         // CHEAT
658                         //last_cb_cfl_fac = cb_cfl_fac;
659                         //last_cr_cfl_fac = cr_cfl_fac;
660
661                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
662                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
663                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
664                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
665                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
666                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
667
668                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
669                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
670                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
671                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
672                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
673                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
674                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
675                         }
676                         //in_cb[0] += 1024;
677                         //in_cr[0] += 1024;
678                         //in_cb[0] -= in_y[0];
679                         //in_cr[0] -= in_y[0];
680 #endif
681
682                         for (unsigned y = 0; y < 8; ++y) {
683                                 for (unsigned x = 0; x < 8; ++x) {
684                                         int coeff_idx = y * 8 + x;
685                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
686                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
687                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
688                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
689
690                                         // Store back for reconstruction / PSNR calculation
691                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
692                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
693                                 }
694                         }
695
696                         idct_int32(in_y);  // DEBUG
697                         idct_int32(in_cb);
698                         idct_int32(in_cr);
699
700                         for (unsigned y = 0; y < 8; ++y) {
701                                 for (unsigned x = 0; x < 8; ++x) {
702                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
703                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
704
705                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
706                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
707                                 }
708                         }
709
710 #if 0
711                         last_cb_cfl_fac = cb_cfl_fac;
712                         last_cr_cfl_fac = cr_cfl_fac;
713 #endif
714                 }
715         }
716
717 #if 0
718         printf("chroma_energy = %f, with_pred = %f\n",
719                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
720 #endif
721
722         // DC coefficient pred from the right to left (within each slice)
723         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; block_idx += 320) {
724                 int prev_k = 128;
725
726                 for (unsigned subblock_idx = 320; subblock_idx --> 0; ) {
727                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS;
728                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS;
729                         int k = coeff_y[(yb * 8) * WIDTH + (xb * 8)];
730
731                         coeff_y[(yb * 8) * WIDTH + (xb * 8)] = k - prev_k;
732
733                         prev_k = k;
734                 }
735         }
736         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; block_idx += 320) {
737                 int prev_k_cb = 0;
738                 int prev_k_cr = 0;
739
740                 for (unsigned subblock_idx = 320; subblock_idx --> 0; ) {
741                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS_CHROMA;
742                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS_CHROMA;
743                         int k_cb = coeff_cb[(yb * 8) * WIDTH/2 + (xb * 8)];
744                         int k_cr = coeff_cr[(yb * 8) * WIDTH/2 + (xb * 8)];
745
746                         coeff_cb[(yb * 8) * WIDTH/2 + (xb * 8)] = k_cb - prev_k_cb;
747                         coeff_cr[(yb * 8) * WIDTH/2 + (xb * 8)] = k_cr - prev_k_cr;
748
749                         prev_k_cb = k_cb;
750                         prev_k_cr = k_cr;
751                 }
752         }
753
754         FILE *fp = fopen("reconstructed.pgm", "wb");
755         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
756         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
757         fclose(fp);
758
759         fp = fopen("reconstructed.pnm", "wb");
760         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
761         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
762                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
763                         int y = pix_y[(yb * WIDTH) + xb];
764                         int cb, cr;
765                         int c0 = yb * (WIDTH/2) + xb/2;
766                         if (xb % 2 == 0) {
767                                 cb = pix_cb[c0] - 128.0;
768                                 cr = pix_cr[c0] - 128.0;
769                         } else {
770                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
771                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
772                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
773                         }
774
775                         double r = y + 1.5748 * cr;
776                         double g = y - 0.1873 * cb - 0.4681 * cr;
777                         double b = y + 1.8556 * cb;
778
779                         putc(clamp(lrint(r)), fp);
780                         putc(clamp(lrint(g)), fp);
781                         putc(clamp(lrint(b)), fp);
782                 }
783         }
784         fclose(fp);
785
786         // For each coefficient, make some tables.
787         size_t extra_bits = 0;
788         for (unsigned i = 0; i < 64; ++i) {
789                 stats[i].clear();
790         }
791         for (unsigned y = 0; y < 8; ++y) {
792                 for (unsigned x = 0; x < 8; ++x) {
793                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
794                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
795
796                         // Luma
797                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
798                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
799                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
800                                         if (k >= ESCAPE_LIMIT) {
801                                                 k = ESCAPE_LIMIT;
802                                                 extra_bits += 12;  // escape this one
803                                         }
804                                         ++s_luma.freqs[(k - 1) & (NUM_SYMS - 1)];
805                                 }
806                         }
807                         // Chroma
808                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
809                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
810                                         unsigned short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
811                                         unsigned short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
812                                         if (k_cb >= ESCAPE_LIMIT) {
813                                                 k_cb = ESCAPE_LIMIT;
814                                                 extra_bits += 12;  // escape this one
815                                         }
816                                         if (k_cr >= ESCAPE_LIMIT) {
817                                                 k_cr = ESCAPE_LIMIT;
818                                                 extra_bits += 12;  // escape this one
819                                         }
820                                         ++s_chroma.freqs[(k_cb - 1) & (NUM_SYMS - 1)];
821                                         ++s_chroma.freqs[(k_cr - 1) & (NUM_SYMS - 1)];
822                                 }
823                         }
824                 }
825         }
826
827 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
828         printf("Luma:\n");
829         find_optimal_stream_assignment(0);
830         printf("Chroma:\n");
831         find_optimal_stream_assignment(64);
832         exit(0);
833 #endif
834
835         for (unsigned i = 0; i < 64; ++i) {
836                 stats[i].freqs[NUM_SYMS - 1] /= 2;  // zero, has no sign bits (yes, this is trickery)
837                 stats[i].normalize_freqs(prob_scale);
838                 stats[i].cum_freqs[NUM_SYMS] += stats[i].freqs[NUM_SYMS - 1];
839                 stats[i].freqs[NUM_SYMS - 1] *= 2;
840         }
841
842         FILE *codedfp = fopen("coded.dat", "wb");
843         if (codedfp == nullptr) {
844                 perror("coded.dat");
845                 exit(1);
846         }
847
848         // TODO: rather gamma-k or something
849         for (unsigned i = 0; i < 64; ++i) {
850                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
851                         continue;
852                 }
853                 printf("writing table %d\n", i);
854                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
855                         write_varint(stats[i].freqs[j], codedfp);
856                 }
857         }
858
859         RansEncoder rans_encoder;
860
861         size_t tot_bytes = 0;
862
863         // Luma
864         for (unsigned y = 0; y < 8; ++y) {
865                 for (unsigned x = 0; x < 8; ++x) {
866                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
867                         rans_encoder.init_prob(s_luma);
868
869                         // Luma
870                         std::vector<int> lens;
871
872                         // need to reverse later
873                         rans_encoder.clear();
874                         size_t num_bytes = 0;
875                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; ++block_idx) {
876                                 unsigned yb = block_idx / WIDTH_BLOCKS;
877                                 unsigned xb = block_idx % WIDTH_BLOCKS;
878
879                                 int k = coeff_y[(yb * 8 + y) * WIDTH + (xb * 8 + x)];
880                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
881                                 rans_encoder.encode_coeff(k);
882
883                                 if (block_idx % 320 == 319 || block_idx == NUM_BLOCKS - 1) {
884                                         int l = rans_encoder.save_block(codedfp);
885                                         num_bytes += l;
886                                         lens.push_back(l);
887                                 }
888                         }
889                         tot_bytes += num_bytes;
890                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
891
892                         double sum_l = 0.0;
893                         for (int l : lens) {
894                                 sum_l += l;
895                         }
896                         double avg_l = sum_l / lens.size();
897
898                         double sum_sql = 0.0;
899                         for (int l : lens) {
900                                 sum_sql += (l - avg_l) * (l - avg_l);
901                         }
902                         double stddev_l = sqrt(sum_sql / (lens.size() - 1));
903                         printf("coeff %d: avg=%.2f bytes, stddev=%.2f bytes\n", y*8+x, avg_l, stddev_l);
904                 }
905         }
906
907         // Cb
908         for (unsigned y = 0; y < 8; ++y) {
909                 for (unsigned x = 0; x < 8; ++x) {
910                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
911                         rans_encoder.init_prob(s_chroma);
912
913                         rans_encoder.clear();
914                         size_t num_bytes = 0;
915                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
916                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
917                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
918
919                                 int k = coeff_cb[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
920                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
921                                 rans_encoder.encode_coeff(k);
922
923                                 if (block_idx % 320 == 319 || block_idx == NUM_BLOCKS - 1) {
924                                         num_bytes += rans_encoder.save_block(codedfp);
925                                 }
926                         }
927                         tot_bytes += num_bytes;
928                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
929                 }
930         }
931
932         // Cr
933         for (unsigned y = 0; y < 8; ++y) {
934                 for (unsigned x = 0; x < 8; ++x) {
935                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
936                         rans_encoder.init_prob(s_chroma);
937
938                         rans_encoder.clear();
939                         size_t num_bytes = 0;
940                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
941                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
942                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
943
944                                 int k = coeff_cr[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
945                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
946                                 rans_encoder.encode_coeff(k);
947
948                                 if (block_idx % 320 == 319 || block_idx == NUM_BLOCKS - 1) {
949                                         num_bytes += rans_encoder.save_block(codedfp);
950                                 }
951                         }
952                         tot_bytes += num_bytes;
953                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
954                 }
955         }
956
957         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
958                 tot_bytes - extra_bits / 8,
959                 extra_bits,
960                 extra_bits / 8,
961                 tot_bytes);
962 }