]> git.sesse.net Git - narabu/blob - qdc.cpp
More fixes of hard-coded values.
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 //#include "ryg_rans/rans64.h"
9 #include "ryg_rans/rans_byte.h"
10 #include "ryg_rans/renormalize.h"
11
12 #include <algorithm>
13 #include <memory>
14 #include <numeric>
15 #include <random>
16 #include <vector>
17 #include <unordered_map>
18
// Frame size in pixels. The input is one packed 8-bit RGB frame of exactly
// this size (see readpix()).
#define WIDTH 1280
#define HEIGHT 720
// Number of 8x8 DCT blocks per row/column. Chroma is subsampled 2:1
// horizontally (planes are WIDTH/2 wide), hence WIDTH/16 chroma blocks per row.
#define WIDTH_BLOCKS (WIDTH/8)
#define WIDTH_BLOCKS_CHROMA (WIDTH/16)
#define HEIGHT_BLOCKS (HEIGHT/8)
#define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
#define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)

// Alphabet size of every rANS distribution. Coefficient magnitudes of
// ESCAPE_LIMIT or more are coded as the ESCAPE_LIMIT symbol followed by the
// raw value (see RansEncoder::encode_coeff()).
#define NUM_SYMS 256
#define ESCAPE_LIMIT (NUM_SYMS - 1)
// Number of consecutive blocks per slice; DC prediction in main() restarts
// at each slice boundary.
// NOTE(review): NUM_BLOCKS (160 * 90 = 14400) is a multiple of this, but
// NUM_BLOCKS_CHROMA (80 * 90 = 7200) is not (7200 / 320 = 22.5). The chroma
// DC-prediction loop in main() always walks a full slice, so its last slice
// indexes past the end of coeff_cb/coeff_cr. Verify before changing sizes.
#define BLOCKS_PER_STREAM 320

// If you set this to 1, the program will try to optimize the placement
// of coefficients to rANS probability distributions. This is randomized,
// so you might want to run it a few times.
#define FIND_OPTIMAL_STREAM_ASSIGNMENT 0
// Number of distinct distributions per plane. It is the number of medoids in
// the search, and luma_mapping/chroma_mapping hold values 0..NUM_CLUSTERS-1.
#define NUM_CLUSTERS 4

// rANS probability precision: symbol frequencies are renormalized to sum to
// prob_scale = 2^prob_bits.
static constexpr uint32_t prob_bits = 12;
static constexpr uint32_t prob_scale = 1 << prob_bits;
39
40 using namespace std;
41
// Forward/inverse 8x8 DCT operating in place on 64 shorts in raster order
// (index y * 8 + x). Not defined in the visible part of this file; presumably
// the integer DCT whose gain dc_scalefac is said to match.
void fdct_int32(short *const In);
void idct_int32(short *const In);

// Input frame: interleaved 8-bit R, G, B per pixel (filled by readpix()).
unsigned char rgb[WIDTH * HEIGHT * 3];
// Luma plane. Filled by convert_ycbcr(); main() later overwrites it with the
// reconstructed (quantized) luma.
unsigned char pix_y[WIDTH * HEIGHT];
// Chroma planes subsampled 2:1 horizontally (WIDTH/2 x HEIGHT). Filled by
// convert_ycbcr(); main() later overwrites them with reconstructed chroma.
unsigned char pix_cb[(WIDTH/2) * HEIGHT];
unsigned char pix_cr[(WIDTH/2) * HEIGHT];
// Full-resolution chroma, rounded and clamped, as computed by convert_ycbcr()
// before subsampling.
unsigned char full_pix_cb[WIDTH * HEIGHT];
unsigned char full_pix_cr[WIDTH * HEIGHT];
// Quantized DCT coefficients, laid out like the pixel planes: each 8x8 block's
// 64 coefficients occupy that block's own pixel positions. main() later
// replaces each block's DC entry with a DC-prediction residual.
short coeff_y[WIDTH * HEIGHT], coeff_cb[(WIDTH/2) * HEIGHT], coeff_cr[(WIDTH/2) * HEIGHT];
52
// Saturates x to the 0..255 range of an 8-bit sample.
int clamp(int x)
{
        return x < 0 ? 0 : (x > 255 ? 255 : x);
}
59
// 8x8 quantization weight matrix, in raster order (index y * 8 + x).
// Despite its name, the active table is the MPEG-1 default intra matrix. The
// disabled #if 0 table appears to be the JPEG luminance table, kept for
// reference. quantize()/unquantize() use this one table for luma and chroma
// alike, and never read entry 0, because the DC coefficient is scaled
// separately by dc_scalefac.
static const unsigned char std_luminance_quant_tbl[64] = {
#if 0
        16,  11,  10,  16,  24,  40,  51,  61,
        12,  12,  14,  19,  26,  58,  60,  55,
        14,  13,  16,  24,  40,  57,  69,  56,
        14,  17,  22,  29,  51,  87,  80,  62,
        18,  22,  37,  56,  68, 109, 103,  77,
        24,  35,  55,  64,  81, 104, 113,  92,
        49,  64,  78,  87, 103, 121, 120, 101,
        72,  92,  95,  98, 112, 100, 103,  99
#else
        // ff_mpeg1_default_intra_matrix
         8, 16, 19, 22, 26, 27, 29, 34,
        16, 16, 22, 24, 27, 29, 34, 37,
        19, 22, 26, 27, 29, 34, 34, 38,
        22, 22, 26, 27, 29, 34, 37, 40,
        22, 26, 27, 29, 32, 35, 40, 48,
        26, 27, 29, 32, 35, 40, 48, 58,
        26, 27, 29, 34, 38, 46, 56, 69,
        27, 29, 35, 38, 46, 56, 69, 83
#endif
};
82
// Symbol histogram for one rANS distribution.
struct SymbolStats
{
    // freqs[s] = number of occurrences of symbol s (after normalize_freqs(),
    // the quantized frequency used by the coder).
    uint32_t freqs[NUM_SYMS];
    // Prefix sums of freqs: cum_freqs[i] = freqs[0] + ... + freqs[i - 1],
    // so cum_freqs[NUM_SYMS] is the total.
    uint32_t cum_freqs[NUM_SYMS + 1];

    void clear();
    void calc_cum_freqs();
    void normalize_freqs(uint32_t target_total);
};
92
93 void SymbolStats::clear()
94 {
95     for (int i=0; i < NUM_SYMS; i++)
96         freqs[i] = 0;
97 }
98
99 void SymbolStats::calc_cum_freqs()
100 {
101     cum_freqs[0] = 0;
102     for (int i=0; i < NUM_SYMS; i++)
103         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
104 }
105
106 void SymbolStats::normalize_freqs(uint32_t target_total)
107 {
108     uint64_t real_freq[NUM_SYMS + 1];  // hack
109
110     assert(target_total >= NUM_SYMS);
111
112     calc_cum_freqs();
113     uint32_t cur_total = cum_freqs[NUM_SYMS];
114
115     if (cur_total == 0) return;
116
117     double ideal_cost = 0.0;
118     for (int i = 1; i <= NUM_SYMS; i++)
119     {
120       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
121       if (real_freq[i] > 0)
122         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
123     }
124
125     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
126
127     // calculate updated freqs and make sure we didn't screw anything up
128     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
129     for (int i=0; i < NUM_SYMS; i++) {
130         if (freqs[i] == 0)
131             assert(cum_freqs[i+1] == cum_freqs[i]);
132         else
133             assert(cum_freqs[i+1] > cum_freqs[i]);
134
135         // calc updated freq
136         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
137     }
138
139     double calc_cost = 0.0;
140     for (int i = 1; i <= NUM_SYMS; i++)
141     {
142       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
143       if (real_freq[i] > 0)
144         calc_cost -= real_freq[i] * log2(freq / double(target_total));
145     }
146
147     static double total_loss = 0.0;
148     total_loss += calc_cost - ideal_cost;
149     static double total_loss_with_dp = 0.0;
150         double optimal_cost = 0.0;
151     //total_loss_with_dp += optimal_cost - ideal_cost;
152     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
153                 ideal_cost, optimal_cost,
154                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
155 }
156
// One histogram per rANS distribution, indexed by pick_stats_for(). In search
// mode (FIND_OPTIMAL_STREAM_ASSIGNMENT) every coefficient position of every
// plane has its own entry (luma 0..63, chroma 64..127). Otherwise only
// 0..2*NUM_CLUSTERS-1 are returned by pick_stats_for().
SymbolStats stats[128];

#if FIND_OPTIMAL_STREAM_ASSIGNMENT
// Distance from one stream to the other, based on a hacked-up K-L divergence.
// Filled by find_optimal_stream_assignment() for the 64 streams of one plane.
float kl_dist[64][64];
#endif
163
// Cluster (rANS distribution) used for each coefficient position, in raster
// order (index y * 8 + x), with values 0..NUM_CLUSTERS-1. Lower-frequency
// positions toward the top-left get the lower cluster numbers. Used by
// pick_stats_for() when FIND_OPTIMAL_STREAM_ASSIGNMENT is 0. Presumably
// produced by find_optimal_stream_assignment(), which prints a table in this
// same 8-per-row layout.
const int luma_mapping[64] = {
        0, 0, 1, 1, 2, 2, 3, 3,
        0, 0, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 2, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};
// Same as luma_mapping, but for the chroma planes; pick_stats_for() offsets
// these cluster numbers past the luma ones.
const int chroma_mapping[64] = {
        0, 1, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 2, 3, 3, 3, 3, 3,
        2, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};
184
185 int pick_stats_for(int x, int y, bool is_chroma)
186 {
187 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
188         return y * 8 + x + is_chroma * 64;
189 #else
190         if (is_chroma) {
191                 return chroma_mapping[y * 8 + x] + 4;
192         } else {
193                 return luma_mapping[y * 8 + x];
194         }
195 #endif
196 }
197                 
198
// Writes x to fp as a variable-length integer: 7 payload bits per byte,
// least-significant group first, with the top bit set on every byte except
// the last one.
void write_varint(int x, FILE *fp)
{
        for (; x >= 128; x >>= 7) {
                const int low_bits = x & 0x7f;
                putc(low_bits | 0x80, fp);
        }
        putc(x, fp);
}
207
208 class RansEncoder {
209 public:
210         RansEncoder()
211         {
212                 out_buf.reset(new uint8_t[out_max_size]);
213                 clear();
214         }
215
216         void init_prob(SymbolStats &s)
217         {
218                 for (int i = 0; i < NUM_SYMS; i++) {
219                         //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
220                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
221                 }
222                 sign_bias = s.cum_freqs[NUM_SYMS];
223         }
224
225         void clear()
226         {
227                 out_end = out_buf.get() + out_max_size;
228                 ptr = out_end; // *end* of output buffer
229                 RansEncInit(&rans);
230         }
231
232         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
233         {
234                 RansEncFlush(&rans, &ptr);
235                 //printf("post-flush = %08x\n", rans);
236
237                 uint32_t num_rans_bytes = out_end - ptr;
238                 if (num_rans_bytes == last_block.size() &&
239                     memcmp(last_block.data(), ptr, last_block.size()) == 0) {
240                         write_varint(0, codedfp);
241                         clear();
242                         return 1;
243                 } else {
244                         last_block = string((const char *)ptr, num_rans_bytes);
245                 }
246
247                 write_varint(num_rans_bytes, codedfp);
248                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
249                 fwrite(ptr, 1, num_rans_bytes, codedfp);
250
251                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
252
253
254                 clear();
255
256                 //printf("Saving block: %d rANS bytes\n", num_rans_bytes);
257                 return num_rans_bytes;
258                 //return num_rans_bytes;
259         }
260
261         void encode_coeff(short signed_k)
262         {
263                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
264                 unsigned short k = abs(signed_k);
265                 if (k >= ESCAPE_LIMIT) {
266                         // Put the coefficient as a 1/(2^12) symbol _before_
267                         // the 255 coefficient, since the decoder will read the
268                         // 255 coefficient first.
269                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
270                         k = ESCAPE_LIMIT;
271                 }
272                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & (NUM_SYMS - 1)]);
273                 if (signed_k < 0) {
274                         rans += sign_bias;
275                 }
276         }
277
278 private:
279         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
280         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
281
282         unique_ptr<uint8_t[]> out_buf;
283         uint8_t *out_end;
284         uint8_t *ptr;
285         RansState rans;
286         RansEncSymbol esyms[NUM_SYMS];
287         uint32_t sign_bias;
288
289         std::string last_block;
290 };
291
// Divisor applied to the DC coefficient by quantize() (and multiplier in
// unquantize()). Per the note below it matches the FDCT's gain; presumably
// fdct_int32 scales its output by 8. Verify if the DCT changes.
static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
// Global multiplier on std_luminance_quant_tbl for AC coefficients; larger
// values give coarser quantization. quantize()/unquantize() convert it to
// int, so any fractional part is truncated.
static constexpr double quant_scalefac = 4.0;  // whatever?
294
295 static inline int quantize(int f, int coeff_idx)
296 {
297         if (coeff_idx == 0) {
298                 return f / dc_scalefac;
299         }
300         if (f == 0) {
301                 return 0;
302         }
303
304         const int w = std_luminance_quant_tbl[coeff_idx];
305         const int s = quant_scalefac;
306         int sign_f = (f > 0) ? 1 : -1;
307         return (32 * f + sign_f * w * s) / (2 * w * s);
308 }
309
310 static inline int unquantize(int qf, int coeff_idx)
311 {
312         if (coeff_idx == 0) {
313                 return qf * dc_scalefac;
314         }
315         if (qf == 0) {
316                 return 0;
317         }
318
319         const int w = std_luminance_quant_tbl[coeff_idx];
320         const int s = quant_scalefac;
321         return (2 * qf * w * s) / 32;
322 }
323
324 void readpix(unsigned char *ptr, const char *filename)
325 {
326         FILE *fp = fopen(filename, "rb");
327         if (fp == nullptr) {
328                 perror(filename);
329                 exit(1);
330         }
331
332         fseek(fp, 0, SEEK_END);
333         long len = ftell(fp);
334         assert(len >= WIDTH * HEIGHT * 3);
335         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
336
337         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
338         fclose(fp);
339 }
340
341 void convert_ycbcr()
342 {
343         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
344         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
345         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
346
347         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
348         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
349         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
350                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
351                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
352                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
353                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
354                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
355                         double cb = (b - y) * cb_fac + 128.0;
356                         double cr = (r - y) * cr_fac + 128.0;
357                         pix_y[(yb * WIDTH) + xb] = lrint(y);
358                         temp_cb[(yb * WIDTH) + xb] = cb;
359                         temp_cr[(yb * WIDTH) + xb] = cr;
360                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
361                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
362                 }
363         }
364
365         // Simple 4:2:2 subsampling with left convention.
366         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
367                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
368                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
369                         int c1 = yb * WIDTH + xb * 2;
370                         int c2 = yb * WIDTH + xb * 2 + 1;
371                         
372                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
373                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
374                         cb = std::min(std::max(cb, 0.0), 255.0);
375                         cr = std::min(std::max(cr, 0.0), 255.0);
376                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
377                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
378                 }
379         }
380 }
381
382 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
383 double find_best_assignment(const int *medoids, int *assignment)
384 {
385         double current_score = 0.0;
386         for (int i = 0; i < 64; ++i) {
387                 int best_medoid = medoids[0];
388                 float best_medoid_score = kl_dist[i][medoids[0]];
389                 for (int j = 1; j < NUM_CLUSTERS; ++j) {
390                         if (kl_dist[i][medoids[j]] < best_medoid_score) {
391                                 best_medoid = medoids[j];
392                                 best_medoid_score = kl_dist[i][medoids[j]];
393                         }
394                 }
395                 assignment[i] = best_medoid;
396                 current_score += best_medoid_score;
397         }
398         return current_score;
399 }
400
// Groups the 64 per-coefficient distributions stats[base] .. stats[base + 63]
// into NUM_CLUSTERS clusters.
// - Distance: a symmetrized K-L divergence, with 0.5 added to every count.
// - Clustering: randomized greedy k-medoids. Start from random medoids, then
//   accept any medoid/non-medoid swap that lowers the total distance. Stop
//   after a pass with no improvement, or after 1000 passes.
// Progress goes to stdout. The final 8x8 table is printed with clusters
// renumbered in order of first appearance; presumably it is meant for pasting
// into luma_mapping/chroma_mapping. base is presumably 0 for luma and 64 for
// chroma, matching pick_stats_for() in search mode.
void find_optimal_stream_assignment(int base)
{
        // inv_sum[i] = 1 / (smoothed total count of stream i).
        double inv_sum[64];
        for (unsigned i = 0; i < 64; ++i) {
                double s = 0.0;
                for (unsigned k = 0; k < NUM_SYMS; ++k) {
                        s += stats[i + base].freqs[k] + 0.5;
                }
                inv_sum[i] = 1.0 / s;
        }

        // Pairwise distances between the smoothed distributions.
        for (unsigned i = 0; i < 64; ++i) {
                for (unsigned j = 0; j < 64; ++j) {
                        double d = 0.0;
                        for (unsigned k = 0; k < NUM_SYMS; ++k) {
                                double p1 = (stats[i + base].freqs[k] + 0.5) * inv_sum[i];
                                double p2 = (stats[j + base].freqs[k] + 0.5) * inv_sum[j];

                                // K-L divergence is asymmetric; this is a hack.
                                d += p1 * log(p1 / p2);
                                d += p2 * log(p2 / p1);
                        }
                        kl_dist[i][j] = d;
                        //printf("%.3f ", d);
                }
                //printf("\n");
        }

        // k-medoids init
        int medoids[64];  // only the first NUM_CLUSTERS matter
        bool is_medoid[64] = { false };
        std::iota(medoids, medoids + 64, 0);
        // Nondeterministically seeded, so results differ from run to run.
        std::random_device rd;
        std::mt19937 g(rd());
        std::shuffle(medoids, medoids + 64, g);
        for (int i = 0; i < NUM_CLUSTERS; ++i) {
                printf("%d ", medoids[i]);
                is_medoid[medoids[i]] = true;
        }
        printf("\n");

        // assign each data point to the closest medoid
        int assignment[64];
        double current_score = find_best_assignment(medoids, assignment);

        for (int i = 0; i < 1000; ++i) {
                printf("iter %d\n", i);
                bool any_changed = false;
                // Try replacing each medoid with each non-medoid; keep the swap
                // only if it strictly lowers the total distance.
                for (int m = 0; m < NUM_CLUSTERS; ++m) {
                        for (int o = 0; o < 64; ++o) {
                                if (is_medoid[o]) continue;
                                int old_medoid = medoids[m];
                                medoids[m] = o;

                                int new_assignment[64];
                                double candidate_score = find_best_assignment(medoids, new_assignment);

                                if (candidate_score < current_score) {
                                        current_score = candidate_score;
                                        memcpy(assignment, new_assignment, sizeof(assignment));

                                        is_medoid[old_medoid] = false;
                                        is_medoid[medoids[m]] = true;
                                        printf("%f: ", current_score);
                                        // (This i shadows the outer pass counter; harmless
                                        // since it is scoped to this loop.)
                                        for (int i = 0; i < 64; ++i) {
                                                printf("%d ", assignment[i]);
                                        }
                                        printf("\n");
                                        any_changed = true;
                                } else {
                                        medoids[m] = old_medoid;
                                }
                        }
                }
                if (!any_changed) break;
        }
        printf("\n");
        // Renumber medoid ids to 0..NUM_CLUSTERS-1 in order of first appearance
        // and print the result as an 8x8 table.
        std::unordered_map<int, int> rmap;
        for (int i = 0; i < 64; ++i) {
                if (i % 8 == 0) printf("\n");
                if (!rmap.count(assignment[i])) {
                        rmap.emplace(assignment[i], rmap.size());
                }
                printf("%d, ", rmap[assignment[i]]);
        }
        printf("\n");
}
488 #endif
489
490 int main(int argc, char **argv)
491 {
492         if (argc >= 2)
493                 readpix(rgb, argv[1]);
494         else
495                 readpix(rgb, "color.pnm");
496         convert_ycbcr();
497
498         double sum_sq_err = 0.0;
499         //double last_cb_cfl_fac = 0.0;
500         //double last_cr_cfl_fac = 0.0;
501
502         int max_val_x[8] = {0}, min_val_x[8] = {0};
503         int max_val_y[8] = {0}, min_val_y[8] = {0};
504
505         // DCT and quantize luma
506         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
507                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
508                         // Read one block
509                         short in_y[64];
510                         for (unsigned y = 0; y < 8; ++y) {
511                                 for (unsigned x = 0; x < 8; ++x) {
512                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
513                                 }
514                         }
515
516                         // FDCT it
517                         fdct_int32(in_y);
518
519                         for (unsigned y = 0; y < 8; ++y) {
520                                 for (unsigned x = 0; x < 8; ++x) {
521                                         int coeff_idx = y * 8 + x;
522                                         int k = quantize(in_y[coeff_idx], coeff_idx);
523                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
524
525                                         max_val_x[x] = std::max(max_val_x[x], k);
526                                         min_val_x[x] = std::min(min_val_x[x], k);
527                                         max_val_y[y] = std::max(max_val_y[y], k);
528                                         min_val_y[y] = std::min(min_val_y[y], k);
529
530                                         // Store back for reconstruction / PSNR calculation
531                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
532                                 }
533                         }
534
535                         idct_int32(in_y);
536
537                         for (unsigned y = 0; y < 8; ++y) {
538                                 for (unsigned x = 0; x < 8; ++x) {
539                                         int k = clamp(in_y[y * 8 + x]);
540                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
541                                         sum_sq_err += (*ptr - k) * (*ptr - k);
542                                         *ptr = k;
543                                 }
544                         }
545                 }
546         }
547         double mse = sum_sq_err / double(WIDTH * HEIGHT);
548         double psnr_db = 20 * log10(255.0 / sqrt(mse));
549         printf("psnr = %.2f dB\n", psnr_db);
550
551         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
552
553         // DCT and quantize chroma
554         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
555         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
556                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
557 #if 0
558                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
559                         printf("in blocks:\n");
560                         for (unsigned y = 0; y < 8; ++y) {
561                                 for (unsigned x = 0; x < 8; ++x) {
562                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
563                                         printf(" %4d", a);
564                                 }
565                                 printf(" | ");
566                                 for (unsigned x = 0; x < 8; ++x) {
567                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
568                                         printf(" %4d", b);
569                                 }
570                                 printf("\n");
571                         }
572
573                         short in_y[64];
574                         for (unsigned y = 0; y < 8; ++y) {
575                                 for (unsigned x = 0; x < 4; ++x) {
576                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
577                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
578                                         b = a - b;
579                                         a = 2 * a - b;
580                                         in_y[y * 8 + x * 2 + 0] = a;
581                                         in_y[y * 8 + x * 2 + 1] = b;
582                                 }
583                         }
584
585                         printf("tf-ed block:\n");
586                         for (unsigned y = 0; y < 8; ++y) {
587                                 for (unsigned x = 0; x < 8; ++x) {
588                                         short a = in_y[y * 8 + x];
589                                         printf(" %4d", a);
590                                 }
591                                 printf("\n");
592                         }
593 #else
594                         // Read Y block with no tf switch (from reconstructed luma)
595                         short in_y[64];
596                         for (unsigned y = 0; y < 8; ++y) {
597                                 for (unsigned x = 0; x < 8; ++x) {
598                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
599                                 }
600                         }
601                         fdct_int32(in_y);
602 #endif
603
604                         // Read one block
605                         short in_cb[64], in_cr[64];
606                         for (unsigned y = 0; y < 8; ++y) {
607                                 for (unsigned x = 0; x < 8; ++x) {
608                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
609                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
610                                 }
611                         }
612
613                         // FDCT it
614                         fdct_int32(in_cb);
615                         fdct_int32(in_cr);
616
617 #if 0
618                         // Chroma from luma
619                         double x0 = in_y[1];
620                         double x1 = in_y[8];
621                         double x2 = in_y[9];
622                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
623                         //double denom = (x1 * x1);
624         
625                         double y0 = in_cb[1];
626                         double y1 = in_cb[8];
627                         double y2 = in_cb[9];
628                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
629                         //double cb_cfl_fac = (x1 * y1) / denom;
630
631                         for (unsigned y = 0; y < 8; ++y) {
632                                 for (unsigned x = 0; x < 8; ++x) {
633                                         short a = in_y[y * 8 + x];
634                                         printf(" %4d", a);
635                                 }
636                                 printf(" | ");
637                                 for (unsigned x = 0; x < 8; ++x) {
638                                         short a = in_cb[y * 8 + x];
639                                         printf(" %4d", a);
640                                 }
641                                 printf("\n");
642                         }
643                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
644                                 in_y[1], in_y[8], in_y[9], 
645                                 in_cb[1], in_cb[8], in_cb[9],
646                                 cb_cfl_fac);
647
648                         y0 = in_cr[1];
649                         y1 = in_cr[8];
650                         y2 = in_cr[9];
651                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
652                         //double cr_cfl_fac = (x1 * y1) / denom;
653                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
654                                 cb_cfl_fac, in_cb[0] - in_y[0],
655                                 cr_cfl_fac, in_cr[0] - in_y[0]);
656
657                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
658
659                         // CHEAT
660                         //last_cb_cfl_fac = cb_cfl_fac;
661                         //last_cr_cfl_fac = cr_cfl_fac;
662
663                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
664                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
665                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
666                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
667                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
668                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
669
670                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
671                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
672                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
673                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
674                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
675                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
676                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
677                         }
678                         //in_cb[0] += 1024;
679                         //in_cr[0] += 1024;
680                         //in_cb[0] -= in_y[0];
681                         //in_cr[0] -= in_y[0];
682 #endif
683
684                         for (unsigned y = 0; y < 8; ++y) {
685                                 for (unsigned x = 0; x < 8; ++x) {
686                                         int coeff_idx = y * 8 + x;
687                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
688                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
689                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
690                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
691
692                                         // Store back for reconstruction / PSNR calculation
693                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
694                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
695                                 }
696                         }
697
698                         idct_int32(in_y);  // DEBUG
699                         idct_int32(in_cb);
700                         idct_int32(in_cr);
701
702                         for (unsigned y = 0; y < 8; ++y) {
703                                 for (unsigned x = 0; x < 8; ++x) {
704                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
705                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
706
707                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
708                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
709                                 }
710                         }
711
712 #if 0
713                         last_cb_cfl_fac = cb_cfl_fac;
714                         last_cr_cfl_fac = cr_cfl_fac;
715 #endif
716                 }
717         }
718
719 #if 0
720         printf("chroma_energy = %f, with_pred = %f\n",
721                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
722 #endif
723
724         // DC coefficient pred from the right to left (within each slice)
725         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; block_idx += BLOCKS_PER_STREAM) {
726                 int prev_k = 128;
727
728                 for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
729                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS;
730                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS;
731                         int k = coeff_y[(yb * 8) * WIDTH + (xb * 8)];
732
733                         coeff_y[(yb * 8) * WIDTH + (xb * 8)] = k - prev_k;
734
735                         prev_k = k;
736                 }
737         }
738         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; block_idx += BLOCKS_PER_STREAM) {
739                 int prev_k_cb = 0;
740                 int prev_k_cr = 0;
741
742                 for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
743                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS_CHROMA;
744                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS_CHROMA;
745                         int k_cb = coeff_cb[(yb * 8) * WIDTH/2 + (xb * 8)];
746                         int k_cr = coeff_cr[(yb * 8) * WIDTH/2 + (xb * 8)];
747
748                         coeff_cb[(yb * 8) * WIDTH/2 + (xb * 8)] = k_cb - prev_k_cb;
749                         coeff_cr[(yb * 8) * WIDTH/2 + (xb * 8)] = k_cr - prev_k_cr;
750
751                         prev_k_cb = k_cb;
752                         prev_k_cr = k_cr;
753                 }
754         }
755
756         FILE *fp = fopen("reconstructed.pgm", "wb");
757         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
758         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
759         fclose(fp);
760
761         fp = fopen("reconstructed.pnm", "wb");
762         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
763         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
764                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
765                         int y = pix_y[(yb * WIDTH) + xb];
766                         int cb, cr;
767                         int c0 = yb * (WIDTH/2) + xb/2;
768                         if (xb % 2 == 0) {
769                                 cb = pix_cb[c0] - 128.0;
770                                 cr = pix_cr[c0] - 128.0;
771                         } else {
772                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
773                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
774                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
775                         }
776
777                         double r = y + 1.5748 * cr;
778                         double g = y - 0.1873 * cb - 0.4681 * cr;
779                         double b = y + 1.8556 * cb;
780
781                         putc(clamp(lrint(r)), fp);
782                         putc(clamp(lrint(g)), fp);
783                         putc(clamp(lrint(b)), fp);
784                 }
785         }
786         fclose(fp);
787
788         // For each coefficient, make some tables.
789         size_t extra_bits = 0;
790         for (unsigned i = 0; i < 64; ++i) {
791                 stats[i].clear();
792         }
793         for (unsigned y = 0; y < 8; ++y) {
794                 for (unsigned x = 0; x < 8; ++x) {
795                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
796                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
797
798                         // Luma
799                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
800                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
801                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
802                                         if (k >= ESCAPE_LIMIT) {
803                                                 k = ESCAPE_LIMIT;
804                                                 extra_bits += 12;  // escape this one
805                                         }
806                                         ++s_luma.freqs[(k - 1) & (NUM_SYMS - 1)];
807                                 }
808                         }
809                         // Chroma
810                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
811                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
812                                         unsigned short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
813                                         unsigned short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
814                                         if (k_cb >= ESCAPE_LIMIT) {
815                                                 k_cb = ESCAPE_LIMIT;
816                                                 extra_bits += 12;  // escape this one
817                                         }
818                                         if (k_cr >= ESCAPE_LIMIT) {
819                                                 k_cr = ESCAPE_LIMIT;
820                                                 extra_bits += 12;  // escape this one
821                                         }
822                                         ++s_chroma.freqs[(k_cb - 1) & (NUM_SYMS - 1)];
823                                         ++s_chroma.freqs[(k_cr - 1) & (NUM_SYMS - 1)];
824                                 }
825                         }
826                 }
827         }
828
829 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
830         printf("Luma:\n");
831         find_optimal_stream_assignment(0);
832         printf("Chroma:\n");
833         find_optimal_stream_assignment(64);
834         exit(0);
835 #endif
836
837         for (unsigned i = 0; i < 64; ++i) {
838                 stats[i].freqs[NUM_SYMS - 1] /= 2;  // zero, has no sign bits (yes, this is trickery)
839                 stats[i].normalize_freqs(prob_scale);
840                 stats[i].cum_freqs[NUM_SYMS] += stats[i].freqs[NUM_SYMS - 1];
841                 stats[i].freqs[NUM_SYMS - 1] *= 2;
842         }
843
844         FILE *codedfp = fopen("coded.dat", "wb");
845         if (codedfp == nullptr) {
846                 perror("coded.dat");
847                 exit(1);
848         }
849
850         // TODO: rather gamma-k or something
851         for (unsigned i = 0; i < 64; ++i) {
852                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
853                         continue;
854                 }
855                 printf("writing table %d\n", i);
856                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
857                         write_varint(stats[i].freqs[j], codedfp);
858                 }
859         }
860
861         RansEncoder rans_encoder;
862
863         size_t tot_bytes = 0;
864
865         // Luma
866         for (unsigned y = 0; y < 8; ++y) {
867                 for (unsigned x = 0; x < 8; ++x) {
868                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
869                         rans_encoder.init_prob(s_luma);
870
871                         // Luma
872                         std::vector<int> lens;
873
874                         // need to reverse later
875                         rans_encoder.clear();
876                         size_t num_bytes = 0;
877                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; ++block_idx) {
878                                 unsigned yb = block_idx / WIDTH_BLOCKS;
879                                 unsigned xb = block_idx % WIDTH_BLOCKS;
880
881                                 int k = coeff_y[(yb * 8 + y) * WIDTH + (xb * 8 + x)];
882                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
883                                 rans_encoder.encode_coeff(k);
884
885                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
886                                         int l = rans_encoder.save_block(codedfp);
887                                         num_bytes += l;
888                                         lens.push_back(l);
889                                 }
890                         }
891                         tot_bytes += num_bytes;
892                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
893
894                         double sum_l = 0.0;
895                         for (int l : lens) {
896                                 sum_l += l;
897                         }
898                         double avg_l = sum_l / lens.size();
899
900                         double sum_sql = 0.0;
901                         for (int l : lens) {
902                                 sum_sql += (l - avg_l) * (l - avg_l);
903                         }
904                         double stddev_l = sqrt(sum_sql / (lens.size() - 1));
905                         printf("coeff %d: avg=%.2f bytes, stddev=%.2f bytes\n", y*8+x, avg_l, stddev_l);
906                 }
907         }
908
909         // Cb
910         for (unsigned y = 0; y < 8; ++y) {
911                 for (unsigned x = 0; x < 8; ++x) {
912                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
913                         rans_encoder.init_prob(s_chroma);
914
915                         rans_encoder.clear();
916                         size_t num_bytes = 0;
917                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
918                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
919                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
920
921                                 int k = coeff_cb[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
922                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
923                                 rans_encoder.encode_coeff(k);
924
925                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
926                                         num_bytes += rans_encoder.save_block(codedfp);
927                                 }
928                         }
929                         tot_bytes += num_bytes;
930                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
931                 }
932         }
933
934         // Cr
935         for (unsigned y = 0; y < 8; ++y) {
936                 for (unsigned x = 0; x < 8; ++x) {
937                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
938                         rans_encoder.init_prob(s_chroma);
939
940                         rans_encoder.clear();
941                         size_t num_bytes = 0;
942                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
943                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
944                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
945
946                                 int k = coeff_cr[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
947                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
948                                 rans_encoder.encode_coeff(k);
949
950                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
951                                         num_bytes += rans_encoder.save_block(codedfp);
952                                 }
953                         }
954                         tot_bytes += num_bytes;
955                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
956                 }
957         }
958
959         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
960                 tot_bytes - extra_bits / 8,
961                 extra_bits,
962                 extra_bits / 8,
963                 tot_bytes);
964
965 #if 0
966         printf("Max coefficient ranges (as a function of x):\n\n");
967         for (unsigned x = 0; x < 8; ++x) {
968                 int range = std::max(max_val_x[x], -min_val_x[x]);
969                 printf("  [%4d, %4d] (%.2f bits)\n", min_val_x[x], max_val_x[x], log2(range * 2 + 1));
970         }
971
972         printf("Max coefficient ranges (as a function of y):\n\n");
973         for (unsigned y = 0; y < 8; ++y) {
974                 int range = std::max(max_val_y[y], -min_val_y[y]);
975                 printf("  [%4d, %4d] (%.2f bits)\n", min_val_y[y], max_val_y[y], log2(range * 2 + 1));
976         }
977 #endif
978 }