]> git.sesse.net Git - narabu/blob - qdc.cpp
41ef04df00f95f3714e1617326fb4d4c4f5ddefa
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 //#include "ryg_rans/rans64.h"
9 #include "ryg_rans/rans_byte.h"
10 #include "ryg_rans/renormalize.h"
11
12 #include <algorithm>
13 #include <memory>
14 #include <numeric>
15 #include <random>
16 #include <vector>
17 #include <unordered_map>
18
// Frame geometry. The picture is handled as 8x8 blocks; the chroma planes
// are kept at half the horizontal resolution, so they have half as many
// blocks per row as luma.
#define WIDTH 1280
#define HEIGHT 720
#define WIDTH_BLOCKS (WIDTH/8)
#define WIDTH_BLOCKS_CHROMA (WIDTH/16)
#define HEIGHT_BLOCKS (HEIGHT/8)
#define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
#define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)

// Entropy-coder alphabet size, the magnitude from which coefficients are
// sent with an escape, and the number of consecutive blocks that form one
// stream (DC prediction restarts for every such group).
#define NUM_SYMS 256
#define ESCAPE_LIMIT (NUM_SYMS - 1)
#define BLOCKS_PER_STREAM 320

// If you set this to 1, the program will try to optimize the placement
// of coefficients to rANS probability distributions. This is randomized,
// so you might want to run it a few times.
#define FIND_OPTIMAL_STREAM_ASSIGNMENT 0
#define NUM_CLUSTERS 4

// Precision of the normalized symbol frequencies: they sum to prob_scale.
static constexpr uint32_t prob_bits = 12;
static constexpr uint32_t prob_scale = 1 << prob_bits;

using namespace std;

// 8x8 forward and inverse DCT, operating in place on 64 coefficients.
// Defined outside this part of the file.
void fdct_int32(short *const In);
void idct_int32(short *const In);

// Input picture, interleaved 8-bit RGB.
unsigned char rgb[WIDTH * HEIGHT * 3];

// Luma plane, plus chroma planes subsampled horizontally by two.
unsigned char pix_y[WIDTH * HEIGHT];
unsigned char pix_cb[(WIDTH/2) * HEIGHT];
unsigned char pix_cr[(WIDTH/2) * HEIGHT];

// Chroma planes at full resolution (before subsampling).
unsigned char full_pix_cb[WIDTH * HEIGHT];
unsigned char full_pix_cr[WIDTH * HEIGHT];

// Quantized DCT coefficients, laid out like the corresponding pixel planes.
short coeff_y[WIDTH * HEIGHT];
short coeff_cb[(WIDTH/2) * HEIGHT];
short coeff_cr[(WIDTH/2) * HEIGHT];
52
// Saturates x to the 8-bit pixel range [0, 255].
int clamp(int x)
{
        return (x < 0) ? 0 : ((x > 255) ? 255 : x);
}
59
60 static const unsigned char std_luminance_quant_tbl[64] = {
61 #if 0
62         16,  11,  10,  16,  24,  40,  51,  61,
63         12,  12,  14,  19,  26,  58,  60,  55,
64         14,  13,  16,  24,  40,  57,  69,  56,
65         14,  17,  22,  29,  51,  87,  80,  62,
66         18,  22,  37,  56,  68, 109, 103,  77,
67         24,  35,  55,  64,  81, 104, 113,  92,
68         49,  64,  78,  87, 103, 121, 120, 101,
69         72,  92,  95,  98, 112, 100, 103,  99
70 #else
71         // ff_mpeg1_default_intra_matrix
72          8, 16, 19, 22, 26, 27, 29, 34,
73         16, 16, 22, 24, 27, 29, 34, 37,                                                 
74         19, 22, 26, 27, 29, 34, 34, 38,                                                 
75         22, 22, 26, 27, 29, 34, 37, 40,
76         22, 26, 27, 29, 32, 35, 40, 48,
77         26, 27, 29, 32, 35, 40, 48, 58,
78         26, 27, 29, 34, 38, 46, 56, 69,
79         27, 29, 35, 38, 46, 56, 69, 83
80 #endif
81 };
82
83 struct SymbolStats
84 {
85     uint32_t freqs[NUM_SYMS];
86     uint32_t cum_freqs[NUM_SYMS + 1];
87
88     void clear();
89     void calc_cum_freqs();
90     void normalize_freqs(uint32_t target_total);
91 };
92
93 void SymbolStats::clear()
94 {
95     for (int i=0; i < NUM_SYMS; i++)
96         freqs[i] = 0;
97 }
98
99 void SymbolStats::calc_cum_freqs()
100 {
101     cum_freqs[0] = 0;
102     for (int i=0; i < NUM_SYMS; i++)
103         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
104 }
105
106 void SymbolStats::normalize_freqs(uint32_t target_total)
107 {
108     uint64_t real_freq[NUM_SYMS + 1];  // hack
109
110     assert(target_total >= NUM_SYMS);
111
112     calc_cum_freqs();
113     uint32_t cur_total = cum_freqs[NUM_SYMS];
114
115     if (cur_total == 0) return;
116
117     double ideal_cost = 0.0;
118     for (int i = 1; i <= NUM_SYMS; i++)
119     {
120       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
121       if (real_freq[i] > 0)
122         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
123     }
124
125     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
126
127     // calculate updated freqs and make sure we didn't screw anything up
128     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
129     for (int i=0; i < NUM_SYMS; i++) {
130         if (freqs[i] == 0)
131             assert(cum_freqs[i+1] == cum_freqs[i]);
132         else
133             assert(cum_freqs[i+1] > cum_freqs[i]);
134
135         // calc updated freq
136         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
137     }
138
139     double calc_cost = 0.0;
140     for (int i = 1; i <= NUM_SYMS; i++)
141     {
142       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
143       if (real_freq[i] > 0)
144         calc_cost -= real_freq[i] * log2(freq / double(target_total));
145     }
146
147     static double total_loss = 0.0;
148     total_loss += calc_cost - ideal_cost;
149     static double total_loss_with_dp = 0.0;
150         double optimal_cost = 0.0;
151     //total_loss_with_dp += optimal_cost - ideal_cost;
152     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
153                 ideal_cost, optimal_cost,
154                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
155 }
156
157 SymbolStats stats[128];
158
159 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
160 // Distance from one stream to the other, based on a hacked-up K-L divergence.
161 float kl_dist[64][64];
162 #endif
163
164 const int luma_mapping[64] = {
165         0, 0, 1, 1, 2, 2, 3, 3,
166         0, 0, 1, 2, 2, 2, 3, 3,
167         1, 1, 2, 2, 2, 3, 3, 3,
168         1, 1, 2, 2, 2, 3, 3, 3,
169         1, 2, 2, 2, 2, 3, 3, 3,
170         2, 2, 2, 2, 3, 3, 3, 3,
171         2, 2, 3, 3, 3, 3, 3, 3,
172         3, 3, 3, 3, 3, 3, 3, 3,
173 };
174 const int chroma_mapping[64] = {
175         0, 1, 1, 2, 2, 2, 3, 3,
176         1, 1, 2, 2, 2, 3, 3, 3,
177         2, 2, 2, 2, 3, 3, 3, 3,
178         2, 2, 2, 3, 3, 3, 3, 3,
179         2, 3, 3, 3, 3, 3, 3, 3,
180         3, 3, 3, 3, 3, 3, 3, 3,
181         3, 3, 3, 3, 3, 3, 3, 3,
182         3, 3, 3, 3, 3, 3, 3, 3,
183 };
184
185 int pick_stats_for(int x, int y, bool is_chroma)
186 {
187 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
188         return y * 8 + x + is_chroma * 64;
189 #else
190         if (is_chroma) {
191                 return chroma_mapping[y * 8 + x] + 4;
192         } else {
193                 return luma_mapping[y * 8 + x];
194         }
195 #endif
196 }
197                 
198
// Writes x to fp as a variable-length integer: seven payload bits per byte,
// least significant group first, with the high bit set on every byte except
// the last one.
//
// x is encoded by its unsigned 32-bit pattern, so a negative argument (for
// instance an unsigned value of 2^31 or more passed through the int
// parameter) is written in full as five bytes instead of being truncated to
// its low byte. Non-negative values are encoded exactly as before.
void write_varint(int x, FILE *fp)
{
        unsigned int remaining = static_cast<unsigned int>(x);
        while (remaining >= 128) {
                putc((remaining & 0x7f) | 0x80, fp);
                remaining >>= 7;
        }
        putc(remaining, fp);
}
207
208 class RansEncoder {
209 public:
210         RansEncoder()
211         {
212                 out_buf.reset(new uint8_t[out_max_size]);
213                 clear();
214         }
215
216         void init_prob(SymbolStats &s)
217         {
218                 for (int i = 0; i < NUM_SYMS; i++) {
219                         //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
220                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
221                 }
222                 sign_bias = s.cum_freqs[NUM_SYMS];
223         }
224
225         void clear()
226         {
227                 out_end = out_buf.get() + out_max_size;
228                 ptr = out_end; // *end* of output buffer
229                 RansEncInit(&rans);
230         }
231
232         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
233         {
234                 RansEncFlush(&rans, &ptr);
235                 //printf("post-flush = %08x\n", rans);
236
237                 uint32_t num_rans_bytes = out_end - ptr;
238 #if 0
239                 if (num_rans_bytes == 4) {
240                         uint32_t block;
241                         memcpy(&block, ptr, 4);
242
243                         if (block == last_block) {
244                                 write_varint(0, codedfp);
245                                 clear();
246                                 return 1;
247                         }
248
249                         last_block = block;
250                 } else {
251                         last_block = 0;
252                 }
253 #endif
254
255                 write_varint(num_rans_bytes, codedfp);
256                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
257                 fwrite(ptr, 1, num_rans_bytes, codedfp);
258
259                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
260
261
262                 clear();
263
264                 //printf("Saving block: %d rANS bytes\n", num_rans_bytes);
265                 return num_rans_bytes;
266                 //return num_rans_bytes;
267         }
268
269         void encode_coeff(short signed_k)
270         {
271                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
272                 unsigned short k = abs(signed_k);
273                 if (k >= ESCAPE_LIMIT) {
274                         // Put the coefficient as a 1/(2^12) symbol _before_
275                         // the 255 coefficient, since the decoder will read the
276                         // 255 coefficient first.
277                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
278                         k = ESCAPE_LIMIT;
279                 }
280                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & (NUM_SYMS - 1)]);
281                 if (signed_k < 0) {
282                         rans += sign_bias;
283                 }
284         }
285
286 private:
287         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
288         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
289
290         unique_ptr<uint8_t[]> out_buf;
291         uint8_t *out_end;
292         uint8_t *ptr;
293         RansState rans;
294         RansEncSymbol esyms[NUM_SYMS];
295         uint32_t sign_bias;
296
297         uint32_t last_block = 0;  // Not a valid 4-byte rANS block (?)
298 };
299
// Divisor applied to the DC coefficient when quantizing.
static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.

// Global multiplier on the quantization weights of all other coefficients.
static constexpr double quant_scalefac = 4.0;  // whatever?
302
303 static inline int quantize(int f, int coeff_idx)
304 {
305         if (coeff_idx == 0) {
306                 return f / dc_scalefac;
307         }
308         if (f == 0) {
309                 return 0;
310         }
311
312         const int w = std_luminance_quant_tbl[coeff_idx];
313         const int s = quant_scalefac;
314         int sign_f = (f > 0) ? 1 : -1;
315         return (32 * f + sign_f * w * s) / (2 * w * s);
316 }
317
318 static inline int unquantize(int qf, int coeff_idx)
319 {
320         if (coeff_idx == 0) {
321                 return qf * dc_scalefac;
322         }
323         if (qf == 0) {
324                 return 0;
325         }
326
327         const int w = std_luminance_quant_tbl[coeff_idx];
328         const int s = quant_scalefac;
329         return (2 * qf * w * s) / 32;
330 }
331
332 void readpix(unsigned char *ptr, const char *filename)
333 {
334         FILE *fp = fopen(filename, "rb");
335         if (fp == nullptr) {
336                 perror(filename);
337                 exit(1);
338         }
339
340         fseek(fp, 0, SEEK_END);
341         long len = ftell(fp);
342         assert(len >= WIDTH * HEIGHT * 3);
343         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
344
345         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
346         fclose(fp);
347 }
348
349 void convert_ycbcr()
350 {
351         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
352         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
353         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
354
355         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
356         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
357         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
358                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
359                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
360                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
361                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
362                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
363                         double cb = (b - y) * cb_fac + 128.0;
364                         double cr = (r - y) * cr_fac + 128.0;
365                         pix_y[(yb * WIDTH) + xb] = lrint(y);
366                         temp_cb[(yb * WIDTH) + xb] = cb;
367                         temp_cr[(yb * WIDTH) + xb] = cr;
368                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
369                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
370                 }
371         }
372
373         // Simple 4:2:2 subsampling with left convention.
374         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
375                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
376                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
377                         int c1 = yb * WIDTH + xb * 2;
378                         int c2 = yb * WIDTH + xb * 2 + 1;
379                         
380                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
381                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
382                         cb = std::min(std::max(cb, 0.0), 255.0);
383                         cr = std::min(std::max(cr, 0.0), 255.0);
384                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
385                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
386                 }
387         }
388 }
389
390 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
391 double find_best_assignment(const int *medoids, int *assignment)
392 {
393         double current_score = 0.0;
394         for (int i = 0; i < 64; ++i) {
395                 int best_medoid = medoids[0];
396                 float best_medoid_score = kl_dist[i][medoids[0]];
397                 for (int j = 1; j < NUM_CLUSTERS; ++j) {
398                         if (kl_dist[i][medoids[j]] < best_medoid_score) {
399                                 best_medoid = medoids[j];
400                                 best_medoid_score = kl_dist[i][medoids[j]];
401                         }
402                 }
403                 assignment[i] = best_medoid;
404                 current_score += best_medoid_score;
405         }
406         return current_score;
407 }
408
409 void find_optimal_stream_assignment(int base)
410 {
411         double inv_sum[64];
412         for (unsigned i = 0; i < 64; ++i) {
413                 double s = 0.0;
414                 for (unsigned k = 0; k < NUM_SYMS; ++k) {
415                         s += stats[i + base].freqs[k] + 0.5;
416                 }
417                 inv_sum[i] = 1.0 / s;
418         }
419
420         for (unsigned i = 0; i < 64; ++i) {
421                 for (unsigned j = 0; j < 64; ++j) {
422                         double d = 0.0;
423                         for (unsigned k = 0; k < NUM_SYMS; ++k) {
424                                 double p1 = (stats[i + base].freqs[k] + 0.5) * inv_sum[i];
425                                 double p2 = (stats[j + base].freqs[k] + 0.5) * inv_sum[j];
426
427                                 // K-L divergence is asymmetric; this is a hack.
428                                 d += p1 * log(p1 / p2);
429                                 d += p2 * log(p2 / p1);
430                         }
431                         kl_dist[i][j] = d;
432                         //printf("%.3f ", d);
433                 }
434                 //printf("\n");
435         }
436
437         // k-medoids init
438         int medoids[64];  // only the first NUM_CLUSTERS matter
439         bool is_medoid[64] = { false };
440         std::iota(medoids, medoids + 64, 0);
441         std::random_device rd;
442         std::mt19937 g(rd());
443         std::shuffle(medoids, medoids + 64, g);
444         for (int i = 0; i < NUM_CLUSTERS; ++i) {
445                 printf("%d ", medoids[i]);
446                 is_medoid[medoids[i]] = true;
447         }
448         printf("\n");
449
450         // assign each data point to the closest medoid
451         int assignment[64];
452         double current_score = find_best_assignment(medoids, assignment);
453
454         for (int i = 0; i < 1000; ++i) {
455                 printf("iter %d\n", i);
456                 bool any_changed = false;
457                 for (int m = 0; m < NUM_CLUSTERS; ++m) {
458                         for (int o = 0; o < 64; ++o) {
459                                 if (is_medoid[o]) continue;
460                                 int old_medoid = medoids[m];
461                                 medoids[m] = o;
462
463                                 int new_assignment[64];
464                                 double candidate_score = find_best_assignment(medoids, new_assignment);
465
466                                 if (candidate_score < current_score) {
467                                         current_score = candidate_score;
468                                         memcpy(assignment, new_assignment, sizeof(assignment));
469
470                                         is_medoid[old_medoid] = false;
471                                         is_medoid[medoids[m]] = true;
472                                         printf("%f: ", current_score);
473                                         for (int i = 0; i < 64; ++i) {
474                                                 printf("%d ", assignment[i]);
475                                         }
476                                         printf("\n");
477                                         any_changed = true;
478                                 } else {
479                                         medoids[m] = old_medoid;
480                                 }
481                         }
482                 }
483                 if (!any_changed) break;
484         }
485         printf("\n");
486         std::unordered_map<int, int> rmap;
487         for (int i = 0; i < 64; ++i) {
488                 if (i % 8 == 0) printf("\n");
489                 if (!rmap.count(assignment[i])) {
490                         rmap.emplace(assignment[i], rmap.size());
491                 }
492                 printf("%d, ", rmap[assignment[i]]);
493         }
494         printf("\n");
495 }
496 #endif
497
498 int main(int argc, char **argv)
499 {
500         if (argc >= 2)
501                 readpix(rgb, argv[1]);
502         else
503                 readpix(rgb, "color.pnm");
504         convert_ycbcr();
505
506         double sum_sq_err = 0.0;
507         //double last_cb_cfl_fac = 0.0;
508         //double last_cr_cfl_fac = 0.0;
509
510         int max_val_x[8] = {0}, min_val_x[8] = {0};
511         int max_val_y[8] = {0}, min_val_y[8] = {0};
512
513         // DCT and quantize luma
514         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
515                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
516                         // Read one block
517                         short in_y[64];
518                         for (unsigned y = 0; y < 8; ++y) {
519                                 for (unsigned x = 0; x < 8; ++x) {
520                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
521                                 }
522                         }
523
524                         // FDCT it
525                         fdct_int32(in_y);
526
527                         for (unsigned y = 0; y < 8; ++y) {
528                                 for (unsigned x = 0; x < 8; ++x) {
529                                         int coeff_idx = y * 8 + x;
530                                         int k = quantize(in_y[coeff_idx], coeff_idx);
531                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
532
533                                         max_val_x[x] = std::max(max_val_x[x], k);
534                                         min_val_x[x] = std::min(min_val_x[x], k);
535                                         max_val_y[y] = std::max(max_val_y[y], k);
536                                         min_val_y[y] = std::min(min_val_y[y], k);
537
538                                         // Store back for reconstruction / PSNR calculation
539                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
540                                 }
541                         }
542
543                         idct_int32(in_y);
544
545                         for (unsigned y = 0; y < 8; ++y) {
546                                 for (unsigned x = 0; x < 8; ++x) {
547                                         int k = clamp(in_y[y * 8 + x]);
548                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
549                                         sum_sq_err += (*ptr - k) * (*ptr - k);
550                                         *ptr = k;
551                                 }
552                         }
553                 }
554         }
555         double mse = sum_sq_err / double(WIDTH * HEIGHT);
556         double psnr_db = 20 * log10(255.0 / sqrt(mse));
557         printf("psnr = %.2f dB\n", psnr_db);
558
559         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
560
561         // DCT and quantize chroma
562         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
563         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
564                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
565 #if 0
566                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
567                         printf("in blocks:\n");
568                         for (unsigned y = 0; y < 8; ++y) {
569                                 for (unsigned x = 0; x < 8; ++x) {
570                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
571                                         printf(" %4d", a);
572                                 }
573                                 printf(" | ");
574                                 for (unsigned x = 0; x < 8; ++x) {
575                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
576                                         printf(" %4d", b);
577                                 }
578                                 printf("\n");
579                         }
580
581                         short in_y[64];
582                         for (unsigned y = 0; y < 8; ++y) {
583                                 for (unsigned x = 0; x < 4; ++x) {
584                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
585                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
586                                         b = a - b;
587                                         a = 2 * a - b;
588                                         in_y[y * 8 + x * 2 + 0] = a;
589                                         in_y[y * 8 + x * 2 + 1] = b;
590                                 }
591                         }
592
593                         printf("tf-ed block:\n");
594                         for (unsigned y = 0; y < 8; ++y) {
595                                 for (unsigned x = 0; x < 8; ++x) {
596                                         short a = in_y[y * 8 + x];
597                                         printf(" %4d", a);
598                                 }
599                                 printf("\n");
600                         }
601 #else
602                         // Read Y block with no tf switch (from reconstructed luma)
603                         short in_y[64];
604                         for (unsigned y = 0; y < 8; ++y) {
605                                 for (unsigned x = 0; x < 8; ++x) {
606                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
607                                 }
608                         }
609                         fdct_int32(in_y);
610 #endif
611
612                         // Read one block
613                         short in_cb[64], in_cr[64];
614                         for (unsigned y = 0; y < 8; ++y) {
615                                 for (unsigned x = 0; x < 8; ++x) {
616                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
617                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
618                                 }
619                         }
620
621                         // FDCT it
622                         fdct_int32(in_cb);
623                         fdct_int32(in_cr);
624
625 #if 0
626                         // Chroma from luma
627                         double x0 = in_y[1];
628                         double x1 = in_y[8];
629                         double x2 = in_y[9];
630                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
631                         //double denom = (x1 * x1);
632         
633                         double y0 = in_cb[1];
634                         double y1 = in_cb[8];
635                         double y2 = in_cb[9];
636                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
637                         //double cb_cfl_fac = (x1 * y1) / denom;
638
639                         for (unsigned y = 0; y < 8; ++y) {
640                                 for (unsigned x = 0; x < 8; ++x) {
641                                         short a = in_y[y * 8 + x];
642                                         printf(" %4d", a);
643                                 }
644                                 printf(" | ");
645                                 for (unsigned x = 0; x < 8; ++x) {
646                                         short a = in_cb[y * 8 + x];
647                                         printf(" %4d", a);
648                                 }
649                                 printf("\n");
650                         }
651                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
652                                 in_y[1], in_y[8], in_y[9], 
653                                 in_cb[1], in_cb[8], in_cb[9],
654                                 cb_cfl_fac);
655
656                         y0 = in_cr[1];
657                         y1 = in_cr[8];
658                         y2 = in_cr[9];
659                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
660                         //double cr_cfl_fac = (x1 * y1) / denom;
661                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
662                                 cb_cfl_fac, in_cb[0] - in_y[0],
663                                 cr_cfl_fac, in_cr[0] - in_y[0]);
664
665                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
666
667                         // CHEAT
668                         //last_cb_cfl_fac = cb_cfl_fac;
669                         //last_cr_cfl_fac = cr_cfl_fac;
670
671                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
672                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
673                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
674                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
675                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
676                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
677
678                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
679                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
680                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
681                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
682                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
683                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
684                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
685                         }
686                         //in_cb[0] += 1024;
687                         //in_cr[0] += 1024;
688                         //in_cb[0] -= in_y[0];
689                         //in_cr[0] -= in_y[0];
690 #endif
691
692                         for (unsigned y = 0; y < 8; ++y) {
693                                 for (unsigned x = 0; x < 8; ++x) {
694                                         int coeff_idx = y * 8 + x;
695                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
696                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
697                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
698                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
699
700                                         // Store back for reconstruction / PSNR calculation
701                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
702                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
703                                 }
704                         }
705
706                         idct_int32(in_y);  // DEBUG
707                         idct_int32(in_cb);
708                         idct_int32(in_cr);
709
710                         for (unsigned y = 0; y < 8; ++y) {
711                                 for (unsigned x = 0; x < 8; ++x) {
712                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
713                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
714
715                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
716                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
717                                 }
718                         }
719
720 #if 0
721                         last_cb_cfl_fac = cb_cfl_fac;
722                         last_cr_cfl_fac = cr_cfl_fac;
723 #endif
724                 }
725         }
726
727 #if 0
728         printf("chroma_energy = %f, with_pred = %f\n",
729                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
730 #endif
731
732         // DC coefficient pred from the right to left (within each slice)
733         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; block_idx += BLOCKS_PER_STREAM) {
734                 int prev_k = 128;
735
736                 for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
737                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS;
738                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS;
739                         int k = coeff_y[(yb * 8) * WIDTH + (xb * 8)];
740
741                         coeff_y[(yb * 8) * WIDTH + (xb * 8)] = k - prev_k;
742
743                         prev_k = k;
744                 }
745         }
746         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; block_idx += BLOCKS_PER_STREAM) {
747                 int prev_k_cb = 0;
748                 int prev_k_cr = 0;
749
750                 for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
751                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS_CHROMA;
752                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS_CHROMA;
753                         int k_cb = coeff_cb[(yb * 8) * WIDTH/2 + (xb * 8)];
754                         int k_cr = coeff_cr[(yb * 8) * WIDTH/2 + (xb * 8)];
755
756                         coeff_cb[(yb * 8) * WIDTH/2 + (xb * 8)] = k_cb - prev_k_cb;
757                         coeff_cr[(yb * 8) * WIDTH/2 + (xb * 8)] = k_cr - prev_k_cr;
758
759                         prev_k_cb = k_cb;
760                         prev_k_cr = k_cr;
761                 }
762         }
763
764         FILE *fp = fopen("reconstructed.pgm", "wb");
765         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
766         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
767         fclose(fp);
768
769         fp = fopen("reconstructed.pnm", "wb");
770         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
771         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
772                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
773                         int y = pix_y[(yb * WIDTH) + xb];
774                         int cb, cr;
775                         int c0 = yb * (WIDTH/2) + xb/2;
776                         if (xb % 2 == 0) {
777                                 cb = pix_cb[c0] - 128.0;
778                                 cr = pix_cr[c0] - 128.0;
779                         } else {
780                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
781                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
782                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
783                         }
784
785                         double r = y + 1.5748 * cr;
786                         double g = y - 0.1873 * cb - 0.4681 * cr;
787                         double b = y + 1.8556 * cb;
788
789                         putc(clamp(lrint(r)), fp);
790                         putc(clamp(lrint(g)), fp);
791                         putc(clamp(lrint(b)), fp);
792                 }
793         }
794         fclose(fp);
795
796         // For each coefficient, make some tables.
797         size_t extra_bits = 0;
798         for (unsigned i = 0; i < 64; ++i) {
799                 stats[i].clear();
800         }
801         for (unsigned y = 0; y < 8; ++y) {
802                 for (unsigned x = 0; x < 8; ++x) {
803                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
804                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
805
806                         // Luma
807                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
808                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
809                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
810                                         if (k >= ESCAPE_LIMIT) {
811                                                 k = ESCAPE_LIMIT;
812                                                 extra_bits += 12;  // escape this one
813                                         }
814                                         ++s_luma.freqs[(k - 1) & (NUM_SYMS - 1)];
815                                 }
816                         }
817                         // Chroma
818                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
819                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
820                                         unsigned short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
821                                         unsigned short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
822                                         if (k_cb >= ESCAPE_LIMIT) {
823                                                 k_cb = ESCAPE_LIMIT;
824                                                 extra_bits += 12;  // escape this one
825                                         }
826                                         if (k_cr >= ESCAPE_LIMIT) {
827                                                 k_cr = ESCAPE_LIMIT;
828                                                 extra_bits += 12;  // escape this one
829                                         }
830                                         ++s_chroma.freqs[(k_cb - 1) & (NUM_SYMS - 1)];
831                                         ++s_chroma.freqs[(k_cr - 1) & (NUM_SYMS - 1)];
832                                 }
833                         }
834                 }
835         }
836
837 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
838         printf("Luma:\n");
839         find_optimal_stream_assignment(0);
840         printf("Chroma:\n");
841         find_optimal_stream_assignment(64);
842         exit(0);
843 #endif
844
845         for (unsigned i = 0; i < 64; ++i) {
846                 stats[i].freqs[NUM_SYMS - 1] /= 2;  // zero, has no sign bits (yes, this is trickery)
847                 stats[i].normalize_freqs(prob_scale);
848                 stats[i].cum_freqs[NUM_SYMS] += stats[i].freqs[NUM_SYMS - 1];
849                 stats[i].freqs[NUM_SYMS - 1] *= 2;
850         }
851
852         FILE *codedfp = fopen("coded.dat", "wb");
853         if (codedfp == nullptr) {
854                 perror("coded.dat");
855                 exit(1);
856         }
857
858         // TODO: rather gamma-k or something
859         for (unsigned i = 0; i < 64; ++i) {
860                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
861                         continue;
862                 }
863                 printf("writing table %d\n", i);
864                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
865                         write_varint(stats[i].freqs[j], codedfp);
866                 }
867         }
868
869         RansEncoder rans_encoder;
870
871         size_t tot_bytes = 0;
872
873         // Luma
874         for (unsigned y = 0; y < 8; ++y) {
875                 for (unsigned x = 0; x < 8; ++x) {
876                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
877                         rans_encoder.init_prob(s_luma);
878
879                         // Luma
880                         std::vector<int> lens;
881
882                         // need to reverse later
883                         rans_encoder.clear();
884                         size_t num_bytes = 0;
885                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; ++block_idx) {
886                                 unsigned yb = block_idx / WIDTH_BLOCKS;
887                                 unsigned xb = block_idx % WIDTH_BLOCKS;
888
889                                 int k = coeff_y[(yb * 8 + y) * WIDTH + (xb * 8 + x)];
890                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
891                                 rans_encoder.encode_coeff(k);
892
893                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
894                                         int l = rans_encoder.save_block(codedfp);
895                                         num_bytes += l;
896                                         lens.push_back(l);
897                                 }
898                         }
899                         tot_bytes += num_bytes;
900                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
901
902                         double sum_l = 0.0;
903                         for (int l : lens) {
904                                 sum_l += l;
905                         }
906                         double avg_l = sum_l / lens.size();
907
908                         double sum_sql = 0.0;
909                         for (int l : lens) {
910                                 sum_sql += (l - avg_l) * (l - avg_l);
911                         }
912                         double stddev_l = sqrt(sum_sql / (lens.size() - 1));
913                         printf("coeff %d: avg=%.2f bytes, stddev=%.2f bytes\n", y*8+x, avg_l, stddev_l);
914                 }
915         }
916
917         // Cb
918         for (unsigned y = 0; y < 8; ++y) {
919                 for (unsigned x = 0; x < 8; ++x) {
920                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
921                         rans_encoder.init_prob(s_chroma);
922
923                         rans_encoder.clear();
924                         size_t num_bytes = 0;
925                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
926                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
927                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
928
929                                 int k = coeff_cb[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
930                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
931                                 rans_encoder.encode_coeff(k);
932
933                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
934                                         num_bytes += rans_encoder.save_block(codedfp);
935                                 }
936                         }
937                         tot_bytes += num_bytes;
938                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
939                 }
940         }
941
942         // Cr
943         for (unsigned y = 0; y < 8; ++y) {
944                 for (unsigned x = 0; x < 8; ++x) {
945                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
946                         rans_encoder.init_prob(s_chroma);
947
948                         rans_encoder.clear();
949                         size_t num_bytes = 0;
950                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
951                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
952                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
953
954                                 int k = coeff_cr[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
955                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
956                                 rans_encoder.encode_coeff(k);
957
958                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
959                                         num_bytes += rans_encoder.save_block(codedfp);
960                                 }
961                         }
962                         tot_bytes += num_bytes;
963                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
964                 }
965         }
966
967         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
968                 tot_bytes - extra_bits / 8,
969                 extra_bits,
970                 extra_bits / 8,
971                 tot_bytes);
972
973 #if 0
974         printf("Max coefficient ranges (as a function of x):\n\n");
975         for (unsigned x = 0; x < 8; ++x) {
976                 int range = std::max(max_val_x[x], -min_val_x[x]);
977                 printf("  [%4d, %4d] (%.2f bits)\n", min_val_x[x], max_val_x[x], log2(range * 2 + 1));
978         }
979
980         printf("Max coefficient ranges (as a function of y):\n\n");
981         for (unsigned y = 0; y < 8; ++y) {
982                 int range = std::max(max_val_y[y], -min_val_y[y]);
983                 printf("  [%4d, %4d] (%.2f bits)\n", min_val_y[y], max_val_y[y], log2(range * 2 + 1));
984         }
985 #endif
986 }