]> git.sesse.net Git - narabu/blob - qdc.cpp
Switch to 64-bit rANS, although probably due for immediate revert (just want to prese...
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 #include "ryg_rans/rans64.h"
9 //#include "ryg_rans/rans_byte.h"
10 #include "ryg_rans/renormalize.h"
11
12 #include <algorithm>
13 #include <memory>
14 #include <numeric>
15 #include <random>
16 #include <vector>
17 #include <unordered_map>
18
// Frame geometry: a fixed 1280x720 frame coded in 8x8 DCT blocks. Chroma is
// subsampled 2:1 horizontally (4:2:2), so each chroma plane is WIDTH/2 wide
// and has WIDTH/16 blocks per row.
#define WIDTH 1280
#define HEIGHT 720
#define WIDTH_BLOCKS (WIDTH/8)
#define WIDTH_BLOCKS_CHROMA (WIDTH/16)
#define HEIGHT_BLOCKS (HEIGHT/8)
#define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
#define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)

// Size of the rANS alphabet. Coefficient magnitudes are mapped to symbols in
// RansEncoder::encode_coeff(); magnitudes >= ESCAPE_LIMIT are sent as an
// escape symbol plus the raw value.
#define NUM_SYMS 256
#define ESCAPE_LIMIT (NUM_SYMS - 1)
// Number of consecutive 8x8 blocks per stream; DC prediction in main()
// restarts at every stream boundary.
// NOTE(review): at 1280x720, NUM_BLOCKS (14400) is an exact multiple of
// BLOCKS_PER_STREAM (320), but NUM_BLOCKS_CHROMA (7200) is not (22.5 streams).
// Loops that always walk a full BLOCKS_PER_STREAM chroma blocks per stream
// (e.g. the chroma DC prediction in main()) index past the last chroma block
// in the final stream -- confirm and clamp the stream length there.
#define BLOCKS_PER_STREAM 320

// If you set this to 1, the program will try to optimize the placement
// of coefficients to rANS probability distributions. This is randomized,
// so you might want to run it a few times.
#define FIND_OPTIMAL_STREAM_ASSIGNMENT 0
// Number of distinct distributions (per plane type) the search above looks for.
#define NUM_CLUSTERS 4

// Precision of the rANS frequency tables: normalized tables sum to
// prob_scale (4096). RansEncoder registers symbols with prob_bits + 1 bits so
// that the coefficient sign can share the same rANS step.
static constexpr uint32_t prob_bits = 12;
static constexpr uint32_t prob_scale = 1 << prob_bits;

// The rest of the file uses std names (unique_ptr, ...) unqualified.
using namespace std;
41
// Integer 8x8 forward / inverse DCT. Callers pass a short[64] block and read
// the result back from the same array (in place). Only declared here; the
// definitions live outside this file.
void fdct_int32(short *const In);
void idct_int32(short *const In);

// Input frame: packed 8-bit R, G, B per pixel, row-major (filled by readpix()).
unsigned char rgb[WIDTH * HEIGHT * 3];
// Luma plane; main() overwrites it in place with the reconstructed luma.
unsigned char pix_y[WIDTH * HEIGHT];
// 4:2:2 chroma planes (WIDTH/2 x HEIGHT); also overwritten with the
// reconstruction in main().
unsigned char pix_cb[(WIDTH/2) * HEIGHT];
unsigned char pix_cr[(WIDTH/2) * HEIGHT];
// Full-resolution chroma, before horizontal subsampling.
unsigned char full_pix_cb[WIDTH * HEIGHT];
unsigned char full_pix_cr[WIDTH * HEIGHT];
// Quantized DCT coefficients in image layout: coefficient (x, y) of the block
// whose top-left pixel is (xb, yb) is stored at (yb + y) * stride + (xb + x).
short coeff_y[WIDTH * HEIGHT], coeff_cb[(WIDTH/2) * HEIGHT], coeff_cr[(WIDTH/2) * HEIGHT];
52
// Saturate x to the range of an 8-bit sample, [0, 255].
int clamp(int x)
{
        if (x >= 0 && x <= 255) {
                return x;
        }
        return (x < 0) ? 0 : 255;
}
59
// Per-coefficient quantization weights, row-major within the 8x8 block.
// Despite the name, quantize()/unquantize() use this table for every plane
// (luma and chroma). Entry 0 (DC) is never read; DC goes through dc_scalefac.
// The disabled branch is the standard JPEG luminance table; the active one is
// MPEG-1's default intra matrix.
static const unsigned char std_luminance_quant_tbl[64] = {
#if 0
        16,  11,  10,  16,  24,  40,  51,  61,
        12,  12,  14,  19,  26,  58,  60,  55,
        14,  13,  16,  24,  40,  57,  69,  56,
        14,  17,  22,  29,  51,  87,  80,  62,
        18,  22,  37,  56,  68, 109, 103,  77,
        24,  35,  55,  64,  81, 104, 113,  92,
        49,  64,  78,  87, 103, 121, 120, 101,
        72,  92,  95,  98, 112, 100, 103,  99
#else
        // ff_mpeg1_default_intra_matrix
         8, 16, 19, 22, 26, 27, 29, 34,
        16, 16, 22, 24, 27, 29, 34, 37,
        19, 22, 26, 27, 29, 34, 34, 38,
        22, 22, 26, 27, 29, 34, 37, 40,
        22, 26, 27, 29, 32, 35, 40, 48,
        26, 27, 29, 32, 35, 40, 48, 58,
        26, 27, 29, 34, 38, 46, 56, 69,
        27, 29, 35, 38, 46, 56, 69, 83
#endif
};
82
// Symbol statistics for one rANS distribution: raw counts while gathering,
// normalized frequencies after normalize_freqs().
struct SymbolStats
{
    uint32_t freqs[NUM_SYMS];          // per-symbol count / frequency
    uint32_t cum_freqs[NUM_SYMS + 1];  // prefix sums of freqs; [NUM_SYMS] is the total

    void clear();                                 // zero freqs[]
    void calc_cum_freqs();                        // rebuild cum_freqs[] from freqs[]
    void normalize_freqs(uint32_t target_total);  // rescale so the total is target_total
};
92
93 void SymbolStats::clear()
94 {
95     for (int i=0; i < NUM_SYMS; i++)
96         freqs[i] = 0;
97 }
98
99 void SymbolStats::calc_cum_freqs()
100 {
101     cum_freqs[0] = 0;
102     for (int i=0; i < NUM_SYMS; i++)
103         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
104 }
105
106 void SymbolStats::normalize_freqs(uint32_t target_total)
107 {
108     uint64_t real_freq[NUM_SYMS + 1];  // hack
109
110     assert(target_total >= NUM_SYMS);
111
112     calc_cum_freqs();
113     uint32_t cur_total = cum_freqs[NUM_SYMS];
114
115     if (cur_total == 0) return;
116
117     double ideal_cost = 0.0;
118     for (int i = 1; i <= NUM_SYMS; i++)
119     {
120       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
121       if (real_freq[i] > 0)
122         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
123     }
124
125     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
126
127     // calculate updated freqs and make sure we didn't screw anything up
128     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
129     for (int i=0; i < NUM_SYMS; i++) {
130         if (freqs[i] == 0)
131             assert(cum_freqs[i+1] == cum_freqs[i]);
132         else
133             assert(cum_freqs[i+1] > cum_freqs[i]);
134
135         // calc updated freq
136         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
137     }
138
139     double calc_cost = 0.0;
140     for (int i = 1; i <= NUM_SYMS; i++)
141     {
142       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
143       if (real_freq[i] > 0)
144         calc_cost -= real_freq[i] * log2(freq / double(target_total));
145     }
146
147     static double total_loss = 0.0;
148     total_loss += calc_cost - ideal_cost;
149     static double total_loss_with_dp = 0.0;
150         double optimal_cost = 0.0;
151     //total_loss_with_dp += optimal_cost - ideal_cost;
152     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
153                 ideal_cost, optimal_cost,
154                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
155 }
156
// Symbol distributions, indexed by pick_stats_for(): 0..7 with the fixed
// luma/chroma mappings (luma 0..3, chroma 4..7), or 0..127 (one per
// coefficient position, luma then chroma) when FIND_OPTIMAL_STREAM_ASSIGNMENT
// is enabled.
SymbolStats stats[128];

#if FIND_OPTIMAL_STREAM_ASSIGNMENT
// Distance from one stream to the other, based on a hacked-up K-L divergence.
// Filled by find_optimal_stream_assignment() for one plane type at a time.
float kl_dist[64][64];
#endif
163
// Fixed assignment of each coefficient position (row-major within the 8x8
// block) to one of NUM_CLUSTERS probability distributions. pick_stats_for()
// uses luma_mapping directly (stats[0..3]) and adds 4 to chroma_mapping
// (stats[4..7]).
// NOTE(review): the layout matches the 8x8 table printed by
// find_optimal_stream_assignment(), so these were presumably generated by it.
const int luma_mapping[64] = {
        0, 0, 1, 1, 2, 2, 3, 3,
        0, 0, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 2, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};
const int chroma_mapping[64] = {
        0, 1, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 2, 3, 3, 3, 3, 3,
        2, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};
184
185 int pick_stats_for(int x, int y, bool is_chroma)
186 {
187 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
188         return y * 8 + x + is_chroma * 64;
189 #else
190         if (is_chroma) {
191                 return chroma_mapping[y * 8 + x] + 4;
192         } else {
193                 return luma_mapping[y * 8 + x];
194         }
195 #endif
196 }
197                 
198
// Write x to fp as sizeof(int) raw bytes in host byte order. Despite the name
// this is a fixed-width write, not a variable-length encoding. The result of
// fwrite() is not checked.
void write_varint(int x, FILE *fp)
{
        const int value = x;
        fwrite(&value, 1, sizeof(value), fp);
}
203
204 class RansEncoder {
205 public:
206         RansEncoder()
207         {
208                 out_buf.reset(new uint8_t[out_max_size]);
209                 clear();
210         }
211
212         void init_prob(SymbolStats &s)
213         {
214                 for (int i = 0; i < NUM_SYMS; i++) {
215                         //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
216                         Rans64EncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
217                 }
218                 sign_bias = s.cum_freqs[NUM_SYMS];
219         }
220
221         void clear()
222         {
223                 out_end = out_buf.get() + out_max_size;
224                 ptr = out_end; // *end* of output buffer
225                 Rans64EncInit(&rans);
226         }
227
228         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
229         {
230                 Rans64EncFlush(&rans, (uint32_t **)&ptr);
231                 //printf("post-flush = %08x\n", rans);
232
233                 uint32_t num_rans_bytes = out_end - ptr;
234 #if 0
235                 if (num_rans_bytes == 4) {
236                         uint32_t block;
237                         memcpy(&block, ptr, 4);
238
239                         if (block == last_block) {
240                                 write_varint(0, codedfp);
241                                 clear();
242                                 return 1;
243                         }
244
245                         last_block = block;
246                 } else {
247                         last_block = 0;
248                 }
249 #endif
250
251                 write_varint(num_rans_bytes, codedfp);
252                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
253                 fwrite(ptr, 1, num_rans_bytes, codedfp);
254
255                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
256
257
258                 clear();
259
260                 //printf("Saving block: %d rANS bytes\n", num_rans_bytes);
261                 return num_rans_bytes;
262                 //return num_rans_bytes;
263         }
264
265         void encode_coeff(short signed_k)
266         {
267                 //printf("encoding coeff %d (sym %d), rans before encoding = %016lx\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
268                 unsigned short k = abs(signed_k);
269                 if (k >= ESCAPE_LIMIT) {
270                         // Put the coefficient as a 1/(2^12) symbol _before_
271                         // the 255 coefficient, since the decoder will read the
272                         // 255 coefficient first.
273                         Rans64EncPut(&rans, (uint32_t **)&ptr, k, 1, prob_bits);
274                         k = ESCAPE_LIMIT;
275                 }
276                 Rans64EncPutSymbol(&rans, (uint32_t **)&ptr, &esyms[(k - 1) & (NUM_SYMS - 1)], prob_bits + 1);
277                 if (signed_k < 0) {
278                         rans += sign_bias;
279                 }
280         }
281
282 private:
283         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
284         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
285
286         unique_ptr<uint8_t[]> out_buf;
287         uint8_t *out_end;
288         uint8_t *ptr;
289         Rans64State rans;
290         Rans64EncSymbol esyms[NUM_SYMS];
291         uint32_t sign_bias;
292
293         uint32_t last_block = 0;  // Not a valid 4-byte rANS block (?)
294 };
295
// DC coefficients are quantized by plain division by dc_scalefac.
static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
// Global multiplier on std_luminance_quant_tbl for AC coefficients; larger
// values mean coarser quantization. quantize()/unquantize() convert it to
// int, so any fractional part is dropped.
static constexpr double quant_scalefac = 4.0;
298
299 static inline int quantize(int f, int coeff_idx)
300 {
301         if (coeff_idx == 0) {
302                 return f / dc_scalefac;
303         }
304         if (f == 0) {
305                 return 0;
306         }
307
308         const int w = std_luminance_quant_tbl[coeff_idx];
309         const int s = quant_scalefac;
310         int sign_f = (f > 0) ? 1 : -1;
311         return (32 * f + sign_f * w * s) / (2 * w * s);
312 }
313
314 static inline int unquantize(int qf, int coeff_idx)
315 {
316         if (coeff_idx == 0) {
317                 return qf * dc_scalefac;
318         }
319         if (qf == 0) {
320                 return 0;
321         }
322
323         const int w = std_luminance_quant_tbl[coeff_idx];
324         const int s = quant_scalefac;
325         return (2 * qf * w * s) / 32;
326 }
327
328 void readpix(unsigned char *ptr, const char *filename)
329 {
330         FILE *fp = fopen(filename, "rb");
331         if (fp == nullptr) {
332                 perror(filename);
333                 exit(1);
334         }
335
336         fseek(fp, 0, SEEK_END);
337         long len = ftell(fp);
338         assert(len >= WIDTH * HEIGHT * 3);
339         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
340
341         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
342         fclose(fp);
343 }
344
345 void convert_ycbcr()
346 {
347         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
348         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
349         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
350
351         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
352         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
353         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
354                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
355                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
356                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
357                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
358                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
359                         double cb = (b - y) * cb_fac + 128.0;
360                         double cr = (r - y) * cr_fac + 128.0;
361                         pix_y[(yb * WIDTH) + xb] = lrint(y);
362                         temp_cb[(yb * WIDTH) + xb] = cb;
363                         temp_cr[(yb * WIDTH) + xb] = cr;
364                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
365                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
366                 }
367         }
368
369         // Simple 4:2:2 subsampling with left convention.
370         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
371                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
372                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
373                         int c1 = yb * WIDTH + xb * 2;
374                         int c2 = yb * WIDTH + xb * 2 + 1;
375                         
376                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
377                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
378                         cb = std::min(std::max(cb, 0.0), 255.0);
379                         cr = std::min(std::max(cr, 0.0), 255.0);
380                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
381                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
382                 }
383         }
384 }
385
386 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
387 double find_best_assignment(const int *medoids, int *assignment)
388 {
389         double current_score = 0.0;
390         for (int i = 0; i < 64; ++i) {
391                 int best_medoid = medoids[0];
392                 float best_medoid_score = kl_dist[i][medoids[0]];
393                 for (int j = 1; j < NUM_CLUSTERS; ++j) {
394                         if (kl_dist[i][medoids[j]] < best_medoid_score) {
395                                 best_medoid = medoids[j];
396                                 best_medoid_score = kl_dist[i][medoids[j]];
397                         }
398                 }
399                 assignment[i] = best_medoid;
400                 current_score += best_medoid_score;
401         }
402         return current_score;
403 }
404
// Cluster the 64 per-position distributions stats[base .. base + 63] into
// NUM_CLUSTERS groups. base is 0 for luma and 64 for chroma, matching
// pick_stats_for() in FIND_OPTIMAL_STREAM_ASSIGNMENT mode.
//
// Method: a greedy k-medoids search over a smoothed, symmetrized K-L distance.
// The final assignment is printed to stdout as an 8x8 table of cluster ids.
//
// The initial medoids come from std::random_device, so results vary between
// runs. Overwrites kl_dist[][] and prints progress; returns nothing.
void find_optimal_stream_assignment(int base)
{
        // 1 / (total count) per stream, with +0.5 added to every symbol so no
        // probability is zero and the log() terms below stay finite.
        double inv_sum[64];
        for (unsigned i = 0; i < 64; ++i) {
                double s = 0.0;
                for (unsigned k = 0; k < NUM_SYMS; ++k) {
                        s += stats[i + base].freqs[k] + 0.5;
                }
                inv_sum[i] = 1.0 / s;
        }

        // Pairwise distance: KL(p1 || p2) + KL(p2 || p1), so kl_dist is symmetric.
        for (unsigned i = 0; i < 64; ++i) {
                for (unsigned j = 0; j < 64; ++j) {
                        double d = 0.0;
                        for (unsigned k = 0; k < NUM_SYMS; ++k) {
                                double p1 = (stats[i + base].freqs[k] + 0.5) * inv_sum[i];
                                double p2 = (stats[j + base].freqs[k] + 0.5) * inv_sum[j];

                                // K-L divergence is asymmetric; this is a hack.
                                d += p1 * log(p1 / p2);
                                d += p2 * log(p2 / p1);
                        }
                        kl_dist[i][j] = d;
                        //printf("%.3f ", d);
                }
                //printf("\n");
        }

        // k-medoids init
        int medoids[64];  // only the first NUM_CLUSTERS matter
        bool is_medoid[64] = { false };
        std::iota(medoids, medoids + 64, 0);
        std::random_device rd;
        std::mt19937 g(rd());
        std::shuffle(medoids, medoids + 64, g);
        for (int i = 0; i < NUM_CLUSTERS; ++i) {
                printf("%d ", medoids[i]);
                is_medoid[medoids[i]] = true;
        }
        printf("\n");

        // assign each data point to the closest medoid
        int assignment[64];
        double current_score = find_best_assignment(medoids, assignment);

        // Greedy swap phase: try replacing each medoid with each non-medoid
        // stream and keep any swap that lowers the total distance. Stop when a
        // full pass changes nothing, or after 1000 passes.
        for (int i = 0; i < 1000; ++i) {
                printf("iter %d\n", i);
                bool any_changed = false;
                for (int m = 0; m < NUM_CLUSTERS; ++m) {
                        for (int o = 0; o < 64; ++o) {
                                if (is_medoid[o]) continue;
                                int old_medoid = medoids[m];
                                medoids[m] = o;

                                int new_assignment[64];
                                double candidate_score = find_best_assignment(medoids, new_assignment);

                                if (candidate_score < current_score) {
                                        // Accept the swap.
                                        current_score = candidate_score;
                                        memcpy(assignment, new_assignment, sizeof(assignment));

                                        is_medoid[old_medoid] = false;
                                        is_medoid[medoids[m]] = true;
                                        printf("%f: ", current_score);
                                        for (int i = 0; i < 64; ++i) {
                                                printf("%d ", assignment[i]);
                                        }
                                        printf("\n");
                                        any_changed = true;
                                } else {
                                        // Reject: restore the previous medoid.
                                        medoids[m] = old_medoid;
                                }
                        }
                }
                if (!any_changed) break;
        }
        printf("\n");
        // Relabel medoid stream ids to 0, 1, ... in order of first appearance
        // and print them as an 8x8 comma-separated table.
        std::unordered_map<int, int> rmap;
        for (int i = 0; i < 64; ++i) {
                if (i % 8 == 0) printf("\n");
                if (!rmap.count(assignment[i])) {
                        rmap.emplace(assignment[i], rmap.size());
                }
                printf("%d, ", rmap[assignment[i]]);
        }
        printf("\n");
}
492 #endif
493
494 int main(int argc, char **argv)
495 {
496         if (argc >= 2)
497                 readpix(rgb, argv[1]);
498         else
499                 readpix(rgb, "color.pnm");
500         convert_ycbcr();
501
502         double sum_sq_err = 0.0;
503         //double last_cb_cfl_fac = 0.0;
504         //double last_cr_cfl_fac = 0.0;
505
506         // DCT and quantize luma
507         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
508                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
509                         // Read one block
510                         short in_y[64];
511                         for (unsigned y = 0; y < 8; ++y) {
512                                 for (unsigned x = 0; x < 8; ++x) {
513                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
514                                 }
515                         }
516
517                         // FDCT it
518                         fdct_int32(in_y);
519
520                         for (unsigned y = 0; y < 8; ++y) {
521                                 for (unsigned x = 0; x < 8; ++x) {
522                                         int coeff_idx = y * 8 + x;
523                                         int k = quantize(in_y[coeff_idx], coeff_idx);
524                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
525
526                                         // Store back for reconstruction / PSNR calculation
527                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
528                                 }
529                         }
530
531                         idct_int32(in_y);
532
533                         for (unsigned y = 0; y < 8; ++y) {
534                                 for (unsigned x = 0; x < 8; ++x) {
535                                         int k = clamp(in_y[y * 8 + x]);
536                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
537                                         sum_sq_err += (*ptr - k) * (*ptr - k);
538                                         *ptr = k;
539                                 }
540                         }
541                 }
542         }
543         double mse = sum_sq_err / double(WIDTH * HEIGHT);
544         double psnr_db = 20 * log10(255.0 / sqrt(mse));
545         printf("psnr = %.2f dB\n", psnr_db);
546
547         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
548
549         // DCT and quantize chroma
550         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
551         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
552                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
553 #if 0
554                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
555                         printf("in blocks:\n");
556                         for (unsigned y = 0; y < 8; ++y) {
557                                 for (unsigned x = 0; x < 8; ++x) {
558                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
559                                         printf(" %4d", a);
560                                 }
561                                 printf(" | ");
562                                 for (unsigned x = 0; x < 8; ++x) {
563                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
564                                         printf(" %4d", b);
565                                 }
566                                 printf("\n");
567                         }
568
569                         short in_y[64];
570                         for (unsigned y = 0; y < 8; ++y) {
571                                 for (unsigned x = 0; x < 4; ++x) {
572                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
573                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
574                                         b = a - b;
575                                         a = 2 * a - b;
576                                         in_y[y * 8 + x * 2 + 0] = a;
577                                         in_y[y * 8 + x * 2 + 1] = b;
578                                 }
579                         }
580
581                         printf("tf-ed block:\n");
582                         for (unsigned y = 0; y < 8; ++y) {
583                                 for (unsigned x = 0; x < 8; ++x) {
584                                         short a = in_y[y * 8 + x];
585                                         printf(" %4d", a);
586                                 }
587                                 printf("\n");
588                         }
589 #else
590                         // Read Y block with no tf switch (from reconstructed luma)
591                         short in_y[64];
592                         for (unsigned y = 0; y < 8; ++y) {
593                                 for (unsigned x = 0; x < 8; ++x) {
594                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
595                                 }
596                         }
597                         fdct_int32(in_y);
598 #endif
599
600                         // Read one block
601                         short in_cb[64], in_cr[64];
602                         for (unsigned y = 0; y < 8; ++y) {
603                                 for (unsigned x = 0; x < 8; ++x) {
604                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
605                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
606                                 }
607                         }
608
609                         // FDCT it
610                         fdct_int32(in_cb);
611                         fdct_int32(in_cr);
612
613 #if 0
614                         // Chroma from luma
615                         double x0 = in_y[1];
616                         double x1 = in_y[8];
617                         double x2 = in_y[9];
618                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
619                         //double denom = (x1 * x1);
620         
621                         double y0 = in_cb[1];
622                         double y1 = in_cb[8];
623                         double y2 = in_cb[9];
624                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
625                         //double cb_cfl_fac = (x1 * y1) / denom;
626
627                         for (unsigned y = 0; y < 8; ++y) {
628                                 for (unsigned x = 0; x < 8; ++x) {
629                                         short a = in_y[y * 8 + x];
630                                         printf(" %4d", a);
631                                 }
632                                 printf(" | ");
633                                 for (unsigned x = 0; x < 8; ++x) {
634                                         short a = in_cb[y * 8 + x];
635                                         printf(" %4d", a);
636                                 }
637                                 printf("\n");
638                         }
639                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
640                                 in_y[1], in_y[8], in_y[9], 
641                                 in_cb[1], in_cb[8], in_cb[9],
642                                 cb_cfl_fac);
643
644                         y0 = in_cr[1];
645                         y1 = in_cr[8];
646                         y2 = in_cr[9];
647                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
648                         //double cr_cfl_fac = (x1 * y1) / denom;
649                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
650                                 cb_cfl_fac, in_cb[0] - in_y[0],
651                                 cr_cfl_fac, in_cr[0] - in_y[0]);
652
653                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
654
655                         // CHEAT
656                         //last_cb_cfl_fac = cb_cfl_fac;
657                         //last_cr_cfl_fac = cr_cfl_fac;
658
659                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
660                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
661                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
662                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
663                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
664                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
665
666                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
667                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
668                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
669                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
670                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
671                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
672                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
673                         }
674                         //in_cb[0] += 1024;
675                         //in_cr[0] += 1024;
676                         //in_cb[0] -= in_y[0];
677                         //in_cr[0] -= in_y[0];
678 #endif
679
680                         for (unsigned y = 0; y < 8; ++y) {
681                                 for (unsigned x = 0; x < 8; ++x) {
682                                         int coeff_idx = y * 8 + x;
683                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
684                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
685                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
686                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
687
688                                         // Store back for reconstruction / PSNR calculation
689                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
690                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
691                                 }
692                         }
693
694                         idct_int32(in_y);  // DEBUG
695                         idct_int32(in_cb);
696                         idct_int32(in_cr);
697
698                         for (unsigned y = 0; y < 8; ++y) {
699                                 for (unsigned x = 0; x < 8; ++x) {
700                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
701                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
702
703                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
704                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
705                                 }
706                         }
707
708 #if 0
709                         last_cb_cfl_fac = cb_cfl_fac;
710                         last_cr_cfl_fac = cr_cfl_fac;
711 #endif
712                 }
713         }
714
715 #if 0
716         printf("chroma_energy = %f, with_pred = %f\n",
717                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
718 #endif
719
720         // DC coefficient pred from the right to left (within each slice)
721         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; block_idx += BLOCKS_PER_STREAM) {
722                 int prev_k = 128;
723
724                 for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
725                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS;
726                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS;
727                         int k = coeff_y[(yb * 8) * WIDTH + (xb * 8)];
728
729                         coeff_y[(yb * 8) * WIDTH + (xb * 8)] = k - prev_k;
730
731                         prev_k = k;
732                 }
733         }
734         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; block_idx += BLOCKS_PER_STREAM) {
735                 int prev_k_cb = 0;
736                 int prev_k_cr = 0;
737
738                 for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
739                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS_CHROMA;
740                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS_CHROMA;
741                         int k_cb = coeff_cb[(yb * 8) * WIDTH/2 + (xb * 8)];
742                         int k_cr = coeff_cr[(yb * 8) * WIDTH/2 + (xb * 8)];
743
744                         coeff_cb[(yb * 8) * WIDTH/2 + (xb * 8)] = k_cb - prev_k_cb;
745                         coeff_cr[(yb * 8) * WIDTH/2 + (xb * 8)] = k_cr - prev_k_cr;
746
747                         prev_k_cb = k_cb;
748                         prev_k_cr = k_cr;
749                 }
750         }
751
752         FILE *fp = fopen("reconstructed.pgm", "wb");
753         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
754         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
755         fclose(fp);
756
757         fp = fopen("reconstructed.pnm", "wb");
758         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
759         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
760                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
761                         int y = pix_y[(yb * WIDTH) + xb];
762                         int cb, cr;
763                         int c0 = yb * (WIDTH/2) + xb/2;
764                         if (xb % 2 == 0) {
765                                 cb = pix_cb[c0] - 128.0;
766                                 cr = pix_cr[c0] - 128.0;
767                         } else {
768                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
769                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
770                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
771                         }
772
773                         double r = y + 1.5748 * cr;
774                         double g = y - 0.1873 * cb - 0.4681 * cr;
775                         double b = y + 1.8556 * cb;
776
777                         putc(clamp(lrint(r)), fp);
778                         putc(clamp(lrint(g)), fp);
779                         putc(clamp(lrint(b)), fp);
780                 }
781         }
782         fclose(fp);
783
784         // For each coefficient, make some tables.
785         size_t extra_bits = 0;
786         for (unsigned i = 0; i < 64; ++i) {
787                 stats[i].clear();
788         }
789         for (unsigned y = 0; y < 8; ++y) {
790                 for (unsigned x = 0; x < 8; ++x) {
791                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
792                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
793
794                         // Luma
795                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
796                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
797                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
798                                         if (k >= ESCAPE_LIMIT) {
799                                                 k = ESCAPE_LIMIT;
800                                                 extra_bits += 12;  // escape this one
801                                         }
802                                         ++s_luma.freqs[(k - 1) & (NUM_SYMS - 1)];
803                                 }
804                         }
805                         // Chroma
806                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
807                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
808                                         unsigned short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
809                                         unsigned short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
810                                         if (k_cb >= ESCAPE_LIMIT) {
811                                                 k_cb = ESCAPE_LIMIT;
812                                                 extra_bits += 12;  // escape this one
813                                         }
814                                         if (k_cr >= ESCAPE_LIMIT) {
815                                                 k_cr = ESCAPE_LIMIT;
816                                                 extra_bits += 12;  // escape this one
817                                         }
818                                         ++s_chroma.freqs[(k_cb - 1) & (NUM_SYMS - 1)];
819                                         ++s_chroma.freqs[(k_cr - 1) & (NUM_SYMS - 1)];
820                                 }
821                         }
822                 }
823         }
824
825 #if FIND_OPTIMAL_STREAM_ASSIGNMENT
826         printf("Luma:\n");
827         find_optimal_stream_assignment(0);
828         printf("Chroma:\n");
829         find_optimal_stream_assignment(64);
830         exit(0);
831 #endif
832
833         for (unsigned i = 0; i < 64; ++i) {
834                 stats[i].freqs[NUM_SYMS - 1] /= 2;  // zero, has no sign bits (yes, this is trickery)
835                 stats[i].normalize_freqs(prob_scale);
836                 stats[i].cum_freqs[NUM_SYMS] += stats[i].freqs[NUM_SYMS - 1];
837                 stats[i].freqs[NUM_SYMS - 1] *= 2;
838         }
839
840         FILE *codedfp = fopen("coded.dat", "wb");
841         if (codedfp == nullptr) {
842                 perror("coded.dat");
843                 exit(1);
844         }
845
846         // TODO: rather gamma-k or something
847         for (unsigned i = 0; i < 64; ++i) {
848                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
849                         continue;
850                 }
851                 printf("writing table %d\n", i);
852                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
853                         write_varint(stats[i].freqs[j], codedfp);
854                 }
855         }
856
857         RansEncoder rans_encoder;
858
859         size_t tot_bytes = 0;
860
861         // Luma
862         for (unsigned y = 0; y < 8; ++y) {
863                 for (unsigned x = 0; x < 8; ++x) {
864                         SymbolStats &s_luma = stats[pick_stats_for(x, y, false)];
865                         rans_encoder.init_prob(s_luma);
866
867                         // Luma
868                         std::vector<int> lens;
869
870                         // need to reverse later
871                         rans_encoder.clear();
872                         size_t num_bytes = 0;
873                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; ++block_idx) {
874                                 unsigned yb = block_idx / WIDTH_BLOCKS;
875                                 unsigned xb = block_idx % WIDTH_BLOCKS;
876
877                                 int k = coeff_y[(yb * 8 + y) * WIDTH + (xb * 8 + x)];
878                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
879                                 rans_encoder.encode_coeff(k);
880
881                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
882                                         int l = rans_encoder.save_block(codedfp);
883                                         num_bytes += l;
884                                         lens.push_back(l);
885                                 }
886                         }
887                         tot_bytes += num_bytes;
888                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
889
890                         double sum_l = 0.0;
891                         for (int l : lens) {
892                                 sum_l += l;
893                         }
894                         double avg_l = sum_l / lens.size();
895
896                         double sum_sql = 0.0;
897                         for (int l : lens) {
898                                 sum_sql += (l - avg_l) * (l - avg_l);
899                         }
900                         double stddev_l = sqrt(sum_sql / (lens.size() - 1));
901                         printf("coeff %d: avg=%.2f bytes, stddev=%.2f bytes\n", y*8+x, avg_l, stddev_l);
902                 }
903         }
904
905         // Cb
906         for (unsigned y = 0; y < 8; ++y) {
907                 for (unsigned x = 0; x < 8; ++x) {
908                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
909                         rans_encoder.init_prob(s_chroma);
910
911                         rans_encoder.clear();
912                         size_t num_bytes = 0;
913                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
914                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
915                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
916
917                                 int k = coeff_cb[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
918                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
919                                 rans_encoder.encode_coeff(k);
920
921                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
922                                         num_bytes += rans_encoder.save_block(codedfp);
923                                 }
924                         }
925                         tot_bytes += num_bytes;
926                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
927                 }
928         }
929
930         // Cr
931         for (unsigned y = 0; y < 8; ++y) {
932                 for (unsigned x = 0; x < 8; ++x) {
933                         SymbolStats &s_chroma = stats[pick_stats_for(x, y, true)];
934                         rans_encoder.init_prob(s_chroma);
935
936                         rans_encoder.clear();
937                         size_t num_bytes = 0;
938                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS_CHROMA; ++block_idx) {
939                                 unsigned yb = block_idx / WIDTH_BLOCKS_CHROMA;
940                                 unsigned xb = block_idx % WIDTH_BLOCKS_CHROMA;
941
942                                 int k = coeff_cr[(yb * 8 + y) * WIDTH/2 + (xb * 8 + x)];
943                                 //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
944                                 rans_encoder.encode_coeff(k);
945
946                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
947                                         num_bytes += rans_encoder.save_block(codedfp);
948                                 }
949                         }
950                         tot_bytes += num_bytes;
951                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
952                 }
953         }
954
955         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
956                 tot_bytes - extra_bits / 8,
957                 extra_bits,
958                 extra_bits / 8,
959                 tot_bytes);
960 }