]> git.sesse.net Git - narabu/blob - qdc.cpp
951915225f8784ff1ede25c8978db48ea6395a2a
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 //#include "ryg_rans/rans64.h"
9 #include "ryg_rans/rans_byte.h"
10 #include "ryg_rans/renormalize.h"
11
12 #include <algorithm>
13 #include <memory>
14 #include <numeric>
15 #include <random>
16 #include <vector>
17 #include <unordered_map>
18
// Dimensions of the single RGB frame this program encodes.
#define WIDTH 1280
#define HEIGHT 720

// Alphabet size of every rANS stream. Coefficient magnitudes at or above
// ESCAPE_LIMIT are escaped and sent separately.
#define NUM_SYMS 256
#define ESCAPE_LIMIT (NUM_SYMS - 1)

// If you set this to 1, the program will try to optimize the placement
// of coefficients to rANS probability distributions. This is randomized,
// so you might want to run it a few times.
#define FIND_OPTIMAL_STREAM_ASSIGNMENT 0
#define NUM_CLUSTERS 4

// Precision of the rANS frequency tables: frequencies sum to 2^prob_bits.
static constexpr uint32_t prob_bits{12};
static constexpr uint32_t prob_scale{1u << prob_bits};

using namespace std;

// In-place 8x8 integer forward / inverse DCT on a block of 64 shorts.
// Defined outside this file (not visible here).
void fdct_int32(short *const In);
void idct_int32(short *const In);

// Input frame, packed 8-bit RGB.
unsigned char rgb[WIDTH * HEIGHT * 3];

// Y'CbCr planes: full-resolution luma, horizontally subsampled chroma,
// and full-resolution chroma.
unsigned char pix_y[WIDTH * HEIGHT];
unsigned char pix_cb[(WIDTH/2) * HEIGHT];
unsigned char pix_cr[(WIDTH/2) * HEIGHT];
unsigned char full_pix_cb[WIDTH * HEIGHT];
unsigned char full_pix_cr[WIDTH * HEIGHT];

// Quantized DCT coefficients, laid out in the same geometry as the planes.
short coeff_y[WIDTH * HEIGHT];
short coeff_cb[(WIDTH/2) * HEIGHT];
short coeff_cr[(WIDTH/2) * HEIGHT];
45
// Saturate an integer sample to the 8-bit range [0, 255].
int clamp(int x)
{
        return (x < 0) ? 0 : ((x > 255) ? 255 : x);
}
52
// Quantization weights for the 64 DCT coefficients of an 8x8 block, in
// row-major order (index = y * 8 + x, index 0 = DC). quantize() and
// unquantize() use this one table for both luma and chroma.
static constexpr unsigned char std_luminance_quant_tbl[64] = {
#if 0
        // Disabled alternative: the JPEG example luminance table.
        16,  11,  10,  16,  24,  40,  51,  61,
        12,  12,  14,  19,  26,  58,  60,  55,
        14,  13,  16,  24,  40,  57,  69,  56,
        14,  17,  22,  29,  51,  87,  80,  62,
        18,  22,  37,  56,  68, 109, 103,  77,
        24,  35,  55,  64,  81, 104, 113,  92,
        49,  64,  78,  87, 103, 121, 120, 101,
        72,  92,  95,  98, 112, 100, 103,  99
#else
        // ff_mpeg1_default_intra_matrix
         8, 16, 19, 22, 26, 27, 29, 34,
        16, 16, 22, 24, 27, 29, 34, 37,
        19, 22, 26, 27, 29, 34, 34, 38,
        22, 22, 26, 27, 29, 34, 37, 40,
        22, 26, 27, 29, 32, 35, 40, 48,
        26, 27, 29, 32, 35, 40, 48, 58,
        26, 27, 29, 34, 38, 46, 56, 69,
        27, 29, 35, 38, 46, 56, 69, 83
#endif
};
75
76 struct SymbolStats
77 {
78     uint32_t freqs[NUM_SYMS];
79     uint32_t cum_freqs[NUM_SYMS + 1];
80
81     void clear();
82     void calc_cum_freqs();
83     void normalize_freqs(uint32_t target_total);
84 };
85
86 void SymbolStats::clear()
87 {
88     for (int i=0; i < NUM_SYMS; i++)
89         freqs[i] = 0;
90 }
91
92 void SymbolStats::calc_cum_freqs()
93 {
94     cum_freqs[0] = 0;
95     for (int i=0; i < NUM_SYMS; i++)
96         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
97 }
98
99 void SymbolStats::normalize_freqs(uint32_t target_total)
100 {
101     uint64_t real_freq[NUM_SYMS + 1];  // hack
102
103     assert(target_total >= NUM_SYMS);
104
105     calc_cum_freqs();
106     uint32_t cur_total = cum_freqs[NUM_SYMS];
107
108     if (cur_total == 0) return;
109
110     double ideal_cost = 0.0;
111     for (int i = 1; i <= NUM_SYMS; i++)
112     {
113       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
114       if (real_freq[i] > 0)
115         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
116     }
117
118     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
119
120     // calculate updated freqs and make sure we didn't screw anything up
121     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
122     for (int i=0; i < NUM_SYMS; i++) {
123         if (freqs[i] == 0)
124             assert(cum_freqs[i+1] == cum_freqs[i]);
125         else
126             assert(cum_freqs[i+1] > cum_freqs[i]);
127
128         // calc updated freq
129         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
130     }
131
132     double calc_cost = 0.0;
133     for (int i = 1; i <= NUM_SYMS; i++)
134     {
135       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
136       if (real_freq[i] > 0)
137         calc_cost -= real_freq[i] * log2(freq / double(target_total));
138     }
139
140     static double total_loss = 0.0;
141     total_loss += calc_cost - ideal_cost;
142     static double total_loss_with_dp = 0.0;
143         double optimal_cost = 0.0;
144     //total_loss_with_dp += optimal_cost - ideal_cost;
145     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
146                 ideal_cost, optimal_cost,
147                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
148 }
149
150 SymbolStats stats[128];
151
152 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
153 // Distance from one stream to the other, based on a hacked-up K-L divergence.
154 float kl_dist[64][64];
155 #endif
156
// rANS stream (cluster) assigned to each of the 64 coefficient positions of
// an 8x8 block, row-major (index = y * 8 + x). Values are 0-3. Chroma
// streams are offset by 4 in pick_stats_for(), giving streams 4-7.
constexpr int luma_mapping[64] = {
        0, 0, 1, 1, 2, 2, 3, 3,
        0, 0, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 2, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};
constexpr int chroma_mapping[64] = {
        0, 1, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 2, 3, 3, 3, 3, 3,
        2, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};

// Return the index into stats[] of the stream that codes coefficient (x, y)
// of an 8x8 block in the luma or chroma plane.
int pick_stats_for(int x, int y, bool is_chroma)
{
        const int pos = y * 8 + x;
#if FIND_OPTIMAL_STREAM_ASSIGNMENT
        // One separate stream per coefficient position and plane.
        return is_chroma ? pos + 64 : pos;
#else
        return is_chroma ? chroma_mapping[pos] + 4 : luma_mapping[pos];
#endif
}
190                 
191
// Write x to fp as an unsigned LEB128-style varint. The value is split into
// 7-bit groups, least-significant group first, and every byte except the
// last has its high bit (0x80) set as a continuation flag.
//
// x is encoded via its 32-bit unsigned bit pattern. Negative input therefore
// produces a well-formed 5-byte varint. Without this, a negative value would
// be emitted as a single byte whose high bit a reader would mistake for a
// continuation flag. Non-negative values are encoded exactly as before.
void write_varint(int x, FILE *fp)
{
        uint32_t v = static_cast<uint32_t>(x);
        while (v >= 128) {
                putc(static_cast<int>((v & 0x7f) | 0x80), fp);
                v >>= 7;
        }
        putc(static_cast<int>(v), fp);
}
200
201 class RansEncoder {
202 public:
203         RansEncoder()
204         {
205                 out_buf.reset(new uint8_t[out_max_size]);
206                 clear();
207         }
208
209         void init_prob(SymbolStats &s)
210         {
211                 for (int i = 0; i < NUM_SYMS; i++) {
212                         printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
213                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
214                 }
215                 sign_bias = s.cum_freqs[256];
216         }
217
218         void clear()
219         {
220                 out_end = out_buf.get() + out_max_size;
221                 ptr = out_end; // *end* of output buffer
222                 RansEncInit(&rans);
223         }
224
225         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
226         {
227                 RansEncFlush(&rans, &ptr);
228                 //printf("post-flush = %08x\n", rans);
229
230                 uint32_t num_rans_bytes = out_end - ptr;
231 #if 0
232                 if (num_rans_bytes == 4) {
233                         uint32_t block;
234                         memcpy(&block, ptr, 4);
235
236                         if (block == last_block) {
237                                 write_varint(0, codedfp);
238                                 clear();
239                                 return 1;
240                         }
241
242                         last_block = block;
243                 } else {
244                         last_block = 0;
245                 }
246 #endif
247
248                 write_varint(num_rans_bytes, codedfp);
249                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
250                 fwrite(ptr, 1, num_rans_bytes, codedfp);
251
252                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
253
254
255                 clear();
256
257                 printf("Saving block: %d rANS bytes\n", num_rans_bytes);
258                 return num_rans_bytes;
259                 //return num_rans_bytes;
260         }
261
262         void encode_coeff(short signed_k)
263         {
264                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
265                 unsigned short k = abs(signed_k);
266                 if (k >= ESCAPE_LIMIT) {
267                         // Put the coefficient as a 1/(2^12) symbol _before_
268                         // the 255 coefficient, since the decoder will read the
269                         // 255 coefficient first.
270                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
271                         k = ESCAPE_LIMIT;
272                 }
273                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & 255]);
274                 if (signed_k < 0) {
275                         rans += sign_bias;
276                 }
277         }
278
279 private:
280         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
281         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
282
283         unique_ptr<uint8_t[]> out_buf;
284         uint8_t *out_end;
285         uint8_t *ptr;
286         RansState rans;
287         RansEncSymbol esyms[NUM_SYMS];
288         uint32_t sign_bias;
289
290         uint32_t last_block = 0;  // Not a valid 4-byte rANS block (?)
291 };
292
// Divisor applied to the DC coefficient by quantize() (and multiplier in
// unquantize()).
static constexpr int dc_scalefac{8};  // Matches the FDCT's gain.
// Global multiplier on the AC quantization weights; larger means coarser.
static constexpr double quant_scalefac{4.0};  // whatever?
295
296 static inline int quantize(int f, int coeff_idx)
297 {
298         if (coeff_idx == 0) {
299                 return f / dc_scalefac;
300         }
301         if (f == 0) {
302                 return 0;
303         }
304
305         const int w = std_luminance_quant_tbl[coeff_idx];
306         const int s = quant_scalefac;
307         int sign_f = (f > 0) ? 1 : -1;
308         return (32 * f + sign_f * w * s) / (2 * w * s);
309 }
310
311 static inline int unquantize(int qf, int coeff_idx)
312 {
313         if (coeff_idx == 0) {
314                 return qf * dc_scalefac;
315         }
316         if (qf == 0) {
317                 return 0;
318         }
319
320         const int w = std_luminance_quant_tbl[coeff_idx];
321         const int s = quant_scalefac;
322         return (2 * qf * w * s) / 32;
323 }
324
325 void readpix(unsigned char *ptr, const char *filename)
326 {
327         FILE *fp = fopen(filename, "rb");
328         if (fp == nullptr) {
329                 perror(filename);
330                 exit(1);
331         }
332
333         fseek(fp, 0, SEEK_END);
334         long len = ftell(fp);
335         assert(len >= WIDTH * HEIGHT * 3);
336         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
337
338         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
339         fclose(fp);
340 }
341
342 void convert_ycbcr()
343 {
344         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
345         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
346         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
347
348         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
349         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
350         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
351                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
352                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
353                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
354                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
355                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
356                         double cb = (b - y) * cb_fac + 128.0;
357                         double cr = (r - y) * cr_fac + 128.0;
358                         pix_y[(yb * WIDTH) + xb] = lrint(y);
359                         temp_cb[(yb * WIDTH) + xb] = cb;
360                         temp_cr[(yb * WIDTH) + xb] = cr;
361                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
362                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
363                 }
364         }
365
366         // Simple 4:2:2 subsampling with left convention.
367         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
368                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
369                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
370                         int c1 = yb * WIDTH + xb * 2;
371                         int c2 = yb * WIDTH + xb * 2 + 1;
372                         
373                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
374                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
375                         cb = std::min(std::max(cb, 0.0), 255.0);
376                         cr = std::min(std::max(cr, 0.0), 255.0);
377                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
378                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
379                 }
380         }
381 }
382
383 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
384 double find_best_assignment(const int *medoids, int *assignment)
385 {
386         double current_score = 0.0;
387         for (int i = 0; i < 64; ++i) {
388                 int best_medoid = medoids[0];
389                 float best_medoid_score = kl_dist[i][medoids[0]];
390                 for (int j = 1; j < NUM_CLUSTERS; ++j) {
391                         if (kl_dist[i][medoids[j]] < best_medoid_score) {
392                                 best_medoid = medoids[j];
393                                 best_medoid_score = kl_dist[i][medoids[j]];
394                         }
395                 }
396                 assignment[i] = best_medoid;
397                 current_score += best_medoid_score;
398         }
399         return current_score;
400 }
401
402 void find_optimal_stream_assignment(int base)
403 {
404         double inv_sum[64];
405         for (unsigned i = 0; i < 64; ++i) {
406                 double s = 0.0;
407                 for (unsigned k = 0; k < 256; ++k) {
408                         s += stats[i + base].freqs[k] + 0.5;
409                 }
410                 inv_sum[i] = 1.0 / s;
411         }
412
413         for (unsigned i = 0; i < 64; ++i) {
414                 for (unsigned j = 0; j < 64; ++j) {
415                         double d = 0.0;
416                         for (unsigned k = 0; k < 256; ++k) {
417                                 double p1 = (stats[i + base].freqs[k] + 0.5) * inv_sum[i];
418                                 double p2 = (stats[j + base].freqs[k] + 0.5) * inv_sum[j];
419
420                                 // K-L divergence is asymmetric; this is a hack.
421                                 d += p1 * log(p1 / p2);
422                                 d += p2 * log(p2 / p1);
423                         }
424                         kl_dist[i][j] = d;
425                         //printf("%.3f ", d);
426                 }
427                 //printf("\n");
428         }
429
430         // k-medoids init
431         int medoids[64];  // only the first NUM_CLUSTERS matter
432         bool is_medoid[64] = { false };
433         std::iota(medoids, medoids + 64, 0);
434         std::random_device rd;
435         std::mt19937 g(rd());
436         std::shuffle(medoids, medoids + 64, g);
437         for (int i = 0; i < NUM_CLUSTERS; ++i) {
438                 printf("%d ", medoids[i]);
439                 is_medoid[medoids[i]] = true;
440         }
441         printf("\n");
442
443         // assign each data point to the closest medoid
444         int assignment[64];
445         double current_score = find_best_assignment(medoids, assignment);
446
447         for (int i = 0; i < 1000; ++i) {
448                 printf("iter %d\n", i);
449                 bool any_changed = false;
450                 for (int m = 0; m < NUM_CLUSTERS; ++m) {
451                         for (int o = 0; o < 64; ++o) {
452                                 if (is_medoid[o]) continue;
453                                 int old_medoid = medoids[m];
454                                 medoids[m] = o;
455
456                                 int new_assignment[64];
457                                 double candidate_score = find_best_assignment(medoids, new_assignment);
458
459                                 if (candidate_score < current_score) {
460                                         current_score = candidate_score;
461                                         memcpy(assignment, new_assignment, sizeof(assignment));
462
463                                         is_medoid[old_medoid] = false;
464                                         is_medoid[medoids[m]] = true;
465                                         printf("%f: ", current_score);
466                                         for (int i = 0; i < 64; ++i) {
467                                                 printf("%d ", assignment[i]);
468                                         }
469                                         printf("\n");
470                                         any_changed = true;
471                                 } else {
472                                         medoids[m] = old_medoid;
473                                 }
474                         }
475                 }
476                 if (!any_changed) break;
477         }
478         printf("\n");
479         std::unordered_map<int, int> rmap;
480         for (int i = 0; i < 64; ++i) {
481                 if (i % 8 == 0) printf("\n");
482                 if (!rmap.count(assignment[i])) {
483                         rmap.emplace(assignment[i], rmap.size());
484                 }
485                 printf("%d, ", rmap[assignment[i]]);
486         }
487         printf("\n");
488 }
489 #endif
490
491 int main(int argc, char **argv)
492 {
493         if (argc >= 2)
494                 readpix(rgb, argv[1]);
495         else
496                 readpix(rgb, "color.pnm");
497         convert_ycbcr();
498
499         double sum_sq_err = 0.0;
500         //double last_cb_cfl_fac = 0.0;
501         //double last_cr_cfl_fac = 0.0;
502
503         // DCT and quantize luma
504         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
505                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
506                         // Read one block
507                         short in_y[64];
508                         for (unsigned y = 0; y < 8; ++y) {
509                                 for (unsigned x = 0; x < 8; ++x) {
510                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
511                                 }
512                         }
513
514                         // FDCT it
515                         fdct_int32(in_y);
516
517                         for (unsigned y = 0; y < 8; ++y) {
518                                 for (unsigned x = 0; x < 8; ++x) {
519                                         int coeff_idx = y * 8 + x;
520                                         int k = quantize(in_y[coeff_idx], coeff_idx);
521                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
522
523                                         // Store back for reconstruction / PSNR calculation
524                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
525                                 }
526                         }
527
528                         idct_int32(in_y);
529
530                         for (unsigned y = 0; y < 8; ++y) {
531                                 for (unsigned x = 0; x < 8; ++x) {
532                                         int k = clamp(in_y[y * 8 + x]);
533                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
534                                         sum_sq_err += (*ptr - k) * (*ptr - k);
535                                         *ptr = k;
536                                 }
537                         }
538                 }
539         }
540         double mse = sum_sq_err / double(WIDTH * HEIGHT);
541         double psnr_db = 20 * log10(255.0 / sqrt(mse));
542         printf("psnr = %.2f dB\n", psnr_db);
543
544         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
545
546         // DCT and quantize chroma
547         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
548         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
549                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
550 #if 0
551                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
552                         printf("in blocks:\n");
553                         for (unsigned y = 0; y < 8; ++y) {
554                                 for (unsigned x = 0; x < 8; ++x) {
555                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
556                                         printf(" %4d", a);
557                                 }
558                                 printf(" | ");
559                                 for (unsigned x = 0; x < 8; ++x) {
560                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
561                                         printf(" %4d", b);
562                                 }
563                                 printf("\n");
564                         }
565
566                         short in_y[64];
567                         for (unsigned y = 0; y < 8; ++y) {
568                                 for (unsigned x = 0; x < 4; ++x) {
569                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
570                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
571                                         b = a - b;
572                                         a = 2 * a - b;
573                                         in_y[y * 8 + x * 2 + 0] = a;
574                                         in_y[y * 8 + x * 2 + 1] = b;
575                                 }
576                         }
577
578                         printf("tf-ed block:\n");
579                         for (unsigned y = 0; y < 8; ++y) {
580                                 for (unsigned x = 0; x < 8; ++x) {
581                                         short a = in_y[y * 8 + x];
582                                         printf(" %4d", a);
583                                 }
584                                 printf("\n");
585                         }
586 #else
587                         // Read Y block with no tf switch (from reconstructed luma)
588                         short in_y[64];
589                         for (unsigned y = 0; y < 8; ++y) {
590                                 for (unsigned x = 0; x < 8; ++x) {
591                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
592                                 }
593                         }
594                         fdct_int32(in_y);
595 #endif
596
597                         // Read one block
598                         short in_cb[64], in_cr[64];
599                         for (unsigned y = 0; y < 8; ++y) {
600                                 for (unsigned x = 0; x < 8; ++x) {
601                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
602                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
603                                 }
604                         }
605
606                         // FDCT it
607                         fdct_int32(in_cb);
608                         fdct_int32(in_cr);
609
610 #if 0
611                         // Chroma from luma
612                         double x0 = in_y[1];
613                         double x1 = in_y[8];
614                         double x2 = in_y[9];
615                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
616                         //double denom = (x1 * x1);
617         
618                         double y0 = in_cb[1];
619                         double y1 = in_cb[8];
620                         double y2 = in_cb[9];
621                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
622                         //double cb_cfl_fac = (x1 * y1) / denom;
623
624                         for (unsigned y = 0; y < 8; ++y) {
625                                 for (unsigned x = 0; x < 8; ++x) {
626                                         short a = in_y[y * 8 + x];
627                                         printf(" %4d", a);
628                                 }
629                                 printf(" | ");
630                                 for (unsigned x = 0; x < 8; ++x) {
631                                         short a = in_cb[y * 8 + x];
632                                         printf(" %4d", a);
633                                 }
634                                 printf("\n");
635                         }
636                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
637                                 in_y[1], in_y[8], in_y[9], 
638                                 in_cb[1], in_cb[8], in_cb[9],
639                                 cb_cfl_fac);
640
641                         y0 = in_cr[1];
642                         y1 = in_cr[8];
643                         y2 = in_cr[9];
644                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
645                         //double cr_cfl_fac = (x1 * y1) / denom;
646                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
647                                 cb_cfl_fac, in_cb[0] - in_y[0],
648                                 cr_cfl_fac, in_cr[0] - in_y[0]);
649
650                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
651
652                         // CHEAT
653                         //last_cb_cfl_fac = cb_cfl_fac;
654                         //last_cr_cfl_fac = cr_cfl_fac;
655
656                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
657                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
658                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
659                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
660                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
661                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
662
663                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
664                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
665                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
666                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
667                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
668                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
669                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
670                         }
671                         //in_cb[0] += 1024;
672                         //in_cr[0] += 1024;
673                         //in_cb[0] -= in_y[0];
674                         //in_cr[0] -= in_y[0];
675 #endif
676
677                         for (unsigned y = 0; y < 8; ++y) {
678                                 for (unsigned x = 0; x < 8; ++x) {
679                                         int coeff_idx = y * 8 + x;
680                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
681                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
682                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
683                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
684
685                                         // Store back for reconstruction / PSNR calculation
686                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
687                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
688                                 }
689                         }
690
691                         idct_int32(in_y);  // DEBUG
692                         idct_int32(in_cb);
693                         idct_int32(in_cr);
694
695                         for (unsigned y = 0; y < 8; ++y) {
696                                 for (unsigned x = 0; x < 8; ++x) {
697                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
698                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
699
700                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
701                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
702                                 }
703                         }
704
705 #if 0
706                         last_cb_cfl_fac = cb_cfl_fac;
707                         last_cr_cfl_fac = cr_cfl_fac;
708 #endif
709                 }
710         }
711
712 #if 0
713         printf("chroma_energy = %f, with_pred = %f\n",
714                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
715 #endif
716
717         // DC coefficient pred from the right to left
718         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
719                 for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
720                         coeff_y[yb * WIDTH + xb] -= coeff_y[yb * WIDTH + (xb + 8)];
721                 }
722         }
723
724         FILE *fp = fopen("reconstructed.pgm", "wb");
725         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
726         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
727         fclose(fp);
728
729         fp = fopen("reconstructed.pnm", "wb");
730         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
731         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
732                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
733                         int y = pix_y[(yb * WIDTH) + xb];
734                         int cb, cr;
735                         int c0 = yb * (WIDTH/2) + xb/2;
736                         if (xb % 2 == 0) {
737                                 cb = pix_cb[c0] - 128.0;
738                                 cr = pix_cr[c0] - 128.0;
739                         } else {
740                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
741                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
742                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
743                         }
744
745                         double r = y + 1.5748 * cr;
746                         double g = y - 0.1873 * cb - 0.4681 * cr;
747                         double b = y + 1.8556 * cb;
748
749                         putc(clamp(lrint(r)), fp);
750                         putc(clamp(lrint(g)), fp);
751                         putc(clamp(lrint(b)), fp);
752                 }
753         }
754         fclose(fp);
755
756         // For each coefficient, make some tables.
757         size_t extra_bits = 0;
758         for (unsigned i = 0; i < 64; ++i) {
759                 stats[i].clear();
760         }
761         for (unsigned y = 0; y < 8; ++y) {
762                 for (unsigned x = 0; x < 8; ++x) {
763                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
764                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
765
766                         // Luma
767                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
768                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
769                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
770                                         if (k >= ESCAPE_LIMIT) {
771                                                 k = ESCAPE_LIMIT;
772                                                 extra_bits += 12;  // escape this one
773                                         }
774                                         ++s_luma.freqs[(k - 1) & 255];
775                                 }
776                         }
777                         // Chroma
778                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
779                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
780                                         unsigned short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
781                                         unsigned short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
782                                         if (k_cb >= ESCAPE_LIMIT) {
783                                                 k_cb = ESCAPE_LIMIT;
784                                                 extra_bits += 12;  // escape this one
785                                         }
786                                         if (k_cr >= ESCAPE_LIMIT) {
787                                                 k_cr = ESCAPE_LIMIT;
788                                                 extra_bits += 12;  // escape this one
789                                         }
790                                         ++s_chroma.freqs[(k_cb - 1) & 255];
791                                         ++s_chroma.freqs[(k_cr - 1) & 255];
792                                 }
793                         }
794                 }
795         }
796
797 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
798         printf("Luma:\n");
799         find_optimal_stream_assignment(0);
800         printf("Chroma:\n");
801         find_optimal_stream_assignment(64);
802         exit(0);
803 #endif
804
805         for (unsigned i = 0; i < 64; ++i) {
806                 stats[i].freqs[255] /= 2;  // zero, has no sign bits (yes, this is trickery)
807                 stats[i].normalize_freqs(prob_scale);
808                 stats[i].cum_freqs[256] += stats[i].freqs[255];
809                 stats[i].freqs[255] *= 2;
810         }
811
812         FILE *codedfp = fopen("coded.dat", "wb");
813         if (codedfp == nullptr) {
814                 perror("coded.dat");
815                 exit(1);
816         }
817
818         // TODO: rather gamma-k or something
819         for (unsigned i = 0; i < 64; ++i) {
820                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
821                         continue;
822                 }
823                 printf("writing table %d\n", i);
824                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
825                         write_varint(stats[i].freqs[j], codedfp);
826                 }
827         }
828
829         RansEncoder rans_encoder;
830
831         size_t tot_bytes = 0;
832
833         // Luma
834         for (unsigned y = 0; y < 8; ++y) {
835                 for (unsigned x = 0; x < 8; ++x) {
836                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
837                         rans_encoder.init_prob(s_luma);
838
839                         // Luma
840                         std::vector<int> lens;
841
842                         // need to reverse later
843                         rans_encoder.clear();
844                         size_t num_bytes = 0;
845                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
846                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
847                                         int k = coeff_y[(yb + y) * WIDTH + (xb + x)];
848                                         //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
849                                         rans_encoder.encode_coeff(k);
850                                 }
851                                 if (yb % 16 == 8) {
852                                         int l = rans_encoder.save_block(codedfp);
853                                         num_bytes += l;
854                                         lens.push_back(l);
855                                 }
856                         }
857                         if (HEIGHT % 16 != 0) {
858                                 num_bytes += rans_encoder.save_block(codedfp);
859                         }
860                         tot_bytes += num_bytes;
861                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
862
863                         double sum_l = 0.0;
864                         for (int l : lens) {
865                                 sum_l += l;
866                         }
867                         double avg_l = sum_l / lens.size();
868
869                         double sum_sql = 0.0;
870                         for (int l : lens) {
871                                 sum_sql += (l - avg_l) * (l - avg_l);
872                         }
873                         double stddev_l = sqrt(sum_sql / (lens.size() - 1));
874                         printf("coeff %d: avg=%.2f bytes, stddev=%.2f bytes\n", y*8+x, avg_l, stddev_l);
875                 }
876         }
877
878         // Cb
879         for (unsigned y = 0; y < 8; ++y) {
880                 for (unsigned x = 0; x < 8; ++x) {
881                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
882                         rans_encoder.init_prob(s_chroma);
883
884                         rans_encoder.clear();
885                         size_t num_bytes = 0;
886                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
887                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
888                                         int k = coeff_cb[(yb + y) * WIDTH/2 + (xb + x)];
889                                         rans_encoder.encode_coeff(k);
890                                 }
891                                 if (yb % 16 == 8) {
892                                         num_bytes += rans_encoder.save_block(codedfp);
893                                 }
894                         }
895                         if (HEIGHT % 16 != 0) {
896                                 num_bytes += rans_encoder.save_block(codedfp);
897                         }
898                         tot_bytes += num_bytes;
899                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
900                 }
901         }
902
903         // Cr
904         for (unsigned y = 0; y < 8; ++y) {
905                 for (unsigned x = 0; x < 8; ++x) {
906                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
907                         rans_encoder.init_prob(s_chroma);
908
909                         rans_encoder.clear();
910                         size_t num_bytes = 0;
911                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
912                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
913                                         int k = coeff_cr[(yb + y) * WIDTH/2 + (xb + x)];
914                                         rans_encoder.encode_coeff(k);
915                                 }
916                                 if (yb % 16 == 8) {
917                                         num_bytes += rans_encoder.save_block(codedfp);
918                                 }
919                         }
920                         if (HEIGHT % 16 != 0) {
921                                 num_bytes += rans_encoder.save_block(codedfp);
922                         }
923                         tot_bytes += num_bytes;
924                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
925                 }
926         }
927
928         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
929                 tot_bytes - extra_bits / 8,
930                 extra_bits,
931                 extra_bits / 8,
932                 tot_bytes);
933 }