]> git.sesse.net Git - narabu/blob - qdc.cpp
Add some parallel slicing code (not really a win).
[narabu] / qdc.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7
8 //#include "ryg_rans/rans64.h"
9 #include "ryg_rans/rans_byte.h"
10 #include "ryg_rans/renormalize.h"
11
12 #include <memory>
13 #include <vector>
14
// Frame geometry in pixels. The input is one packed 8-bit RGB frame of this
// size; luma is coded in 8x8 blocks and 4:2:2 chroma (half width) likewise.
static constexpr int WIDTH = 1280;
static constexpr int HEIGHT = 720;

// Size of each rANS alphabet. Magnitude k is coded as symbol (k - 1) & 255,
// so symbol NUM_SYMS - 1 stands for a zero coefficient.
static constexpr int NUM_SYMS = 256;

// Magnitudes at or above this are coded as the escape symbol
// (ESCAPE_LIMIT - 1) plus the raw magnitude in a separate uniform symbol.
static constexpr int ESCAPE_LIMIT = NUM_SYMS - 1;

// Probabilities are quantized to prob_bits bits (total prob_scale).
static constexpr uint32_t prob_bits = 12;
static constexpr uint32_t prob_scale = 1 << prob_bits;

// The block loops read whole 8x8 blocks of luma and of half-width chroma
// with no edge handling, so both planes must tile exactly.
static_assert(WIDTH % 16 == 0, "luma and half-width chroma must tile into 8x8 blocks");
static_assert(HEIGHT % 8 == 0, "height must tile into 8x8 blocks");
// Every symbol needs at least one slot after normalization.
static_assert(prob_scale >= uint32_t(NUM_SYMS), "probability scale too small for the alphabet");
22
23 using namespace std;
24
25 void fdct_int32(short *const In);
26 void idct_int32(short *const In);
27
28 unsigned char rgb[WIDTH * HEIGHT * 3];
29 unsigned char pix_y[WIDTH * HEIGHT];
30 unsigned char pix_cb[(WIDTH/2) * HEIGHT];
31 unsigned char pix_cr[(WIDTH/2) * HEIGHT];
32 unsigned char full_pix_cb[WIDTH * HEIGHT];
33 unsigned char full_pix_cr[WIDTH * HEIGHT];
34 short coeff_y[WIDTH * HEIGHT], coeff_cb[(WIDTH/2) * HEIGHT], coeff_cr[(WIDTH/2) * HEIGHT];
35
// Saturates x into the 8-bit pixel range [0, 255].
int clamp(int x)
{
        return x < 0 ? 0 : (x > 255 ? 255 : x);
}
42
// AC quantization weights, indexed by y * 8 + x within an 8x8 DCT block.
//
// Despite the name, the active table is FFmpeg's MPEG-1 default intra
// matrix, and quantize()/unquantize() apply it to chroma as well as luma.
// Entry 0 (DC) is never read: DC uses dc_scalefac instead. The JPEG
// standard luminance table the array is named after is kept under #if 0.
static const unsigned char std_luminance_quant_tbl[64] = {
#if 0
        /* JPEG standard luminance table */
        16,  11,  10,  16,  24,  40,  51,  61,   /* row 0 */
        12,  12,  14,  19,  26,  58,  60,  55,   /* row 1 */
        14,  13,  16,  24,  40,  57,  69,  56,   /* row 2 */
        14,  17,  22,  29,  51,  87,  80,  62,   /* row 3 */
        18,  22,  37,  56,  68, 109, 103,  77,   /* row 4 */
        24,  35,  55,  64,  81, 104, 113,  92,   /* row 5 */
        49,  64,  78,  87, 103, 121, 120, 101,   /* row 6 */
        72,  92,  95,  98, 112, 100, 103,  99    /* row 7 */
#else
        /* ff_mpeg1_default_intra_matrix */
         8,  16,  19,  22,  26,  27,  29,  34,   /* row 0 */
        16,  16,  22,  24,  27,  29,  34,  37,   /* row 1 */
        19,  22,  26,  27,  29,  34,  34,  38,   /* row 2 */
        22,  22,  26,  27,  29,  34,  37,  40,   /* row 3 */
        22,  26,  27,  29,  32,  35,  40,  48,   /* row 4 */
        26,  27,  29,  32,  35,  40,  48,  58,   /* row 5 */
        26,  27,  29,  34,  38,  46,  56,  69,   /* row 6 */
        27,  29,  35,  38,  46,  56,  69,  83    /* row 7 */
#endif
};
65
66 struct SymbolStats
67 {
68     uint32_t freqs[NUM_SYMS];
69     uint32_t cum_freqs[NUM_SYMS + 1];
70
71     void clear();
72     void calc_cum_freqs();
73     void normalize_freqs(uint32_t target_total);
74 };
75
76 void SymbolStats::clear()
77 {
78     for (int i=0; i < NUM_SYMS; i++)
79         freqs[i] = 0;
80 }
81
82 void SymbolStats::calc_cum_freqs()
83 {
84     cum_freqs[0] = 0;
85     for (int i=0; i < NUM_SYMS; i++)
86         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
87 }
88
89 void SymbolStats::normalize_freqs(uint32_t target_total)
90 {
91     uint64_t real_freq[NUM_SYMS + 1];  // hack
92
93     assert(target_total >= NUM_SYMS);
94
95     calc_cum_freqs();
96     uint32_t cur_total = cum_freqs[NUM_SYMS];
97
98     if (cur_total == 0) return;
99
100     double ideal_cost = 0.0;
101     for (int i = 1; i <= NUM_SYMS; i++)
102     {
103       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
104       if (real_freq[i] > 0)
105         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
106     }
107
108     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
109
110     // calculate updated freqs and make sure we didn't screw anything up
111     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
112     for (int i=0; i < NUM_SYMS; i++) {
113         if (freqs[i] == 0)
114             assert(cum_freqs[i+1] == cum_freqs[i]);
115         else
116             assert(cum_freqs[i+1] > cum_freqs[i]);
117
118         // calc updated freq
119         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
120     }
121
122     double calc_cost = 0.0;
123     for (int i = 1; i <= NUM_SYMS; i++)
124     {
125       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
126       if (real_freq[i] > 0)
127         calc_cost -= real_freq[i] * log2(freq / double(target_total));
128     }
129
130     static double total_loss = 0.0;
131     total_loss += calc_cost - ideal_cost;
132     static double total_loss_with_dp = 0.0;
133         double optimal_cost = 0.0;
134     //total_loss_with_dp += optimal_cost - ideal_cost;
135     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
136                 ideal_cost, optimal_cost,
137                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
138 }
139
140 SymbolStats stats[64];
141
// Maps a coefficient position inside an 8x8 block to a context index in
// [0, 7]: coefficients on the same anti-diagonal (same x + y) share a table,
// and everything from the 7th anti-diagonal onward shares the last one.
// The mapping is symmetric in x and y, so main() calling it as
// pick_stats_for(x, y) gets the same result as (y, x).
//
// Other mappings that were left here as dead code: a single table
// (return 0), min(hypot(x, y), 7), one table per coefficient (y * 8 + x),
// and DC-vs-AC only.
int pick_stats_for(int y, int x)
{
        const int diagonal = x + y;
        if (diagonal < 7)
                return diagonal;
        return 7;
}
158                 
159
// Writes x to fp as a little-endian base-128 varint: seven payload bits per
// byte, least significant group first, with the high bit set on every byte
// except the last.
//
// The value is taken through its unsigned int representation. Callers pass
// uint32_t values (frequencies, byte counts), so a value above INT_MAX that
// was implicitly converted to int still round-trips. A negative int is
// written in full (five bytes for 32-bit int) rather than being silently
// truncated to one byte. Non-negative inputs encode exactly as before.
void write_varint(int x, FILE *fp)
{
        unsigned int value = static_cast<unsigned int>(x);
        while (value >= 0x80) {
                putc(static_cast<int>((value & 0x7f) | 0x80), fp);
                value >>= 7;
        }
        putc(static_cast<int>(value), fp);
}
168
169 class RansEncoder {
170 public:
171         RansEncoder()
172         {
173                 out_buf.reset(new uint8_t[out_max_size]);
174                 clear();
175         }
176
177         void init_prob(SymbolStats &s)
178         {
179                 for (int i = 0; i < NUM_SYMS; i++) {
180                         printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
181                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
182                 }
183                 sign_bias = s.cum_freqs[256];
184         }
185
186         void clear()
187         {
188                 out_end = out_buf.get() + out_max_size;
189                 ptr = out_end; // *end* of output buffer
190                 RansEncInit(&rans);
191         }
192
193         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
194         {
195                 RansEncFlush(&rans, &ptr);
196                 //printf("post-flush = %08x\n", rans);
197
198                 uint32_t num_rans_bytes = out_end - ptr;
199 #if 0
200                 if (num_rans_bytes == 4) {
201                         uint32_t block;
202                         memcpy(&block, ptr, 4);
203
204                         if (block == last_block) {
205                                 write_varint(0, codedfp);
206                                 clear();
207                                 return 1;
208                         }
209
210                         last_block = block;
211                 } else {
212                         last_block = 0;
213                 }
214 #endif
215
216                 write_varint(num_rans_bytes, codedfp);
217                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
218                 fwrite(ptr, 1, num_rans_bytes, codedfp);
219
220                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
221
222
223                 clear();
224
225                 printf("Saving block: %d rANS bytes\n", num_rans_bytes);
226                 return num_rans_bytes;
227                 //return num_rans_bytes;
228         }
229
230         void encode_coeff(short signed_k)
231         {
232                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
233                 unsigned short k = abs(signed_k);
234                 if (k >= ESCAPE_LIMIT) {
235                         // Put the coefficient as a 1/(2^12) symbol _before_
236                         // the 255 coefficient, since the decoder will read the
237                         // 255 coefficient first.
238                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
239                         k = ESCAPE_LIMIT;
240                 }
241                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & 255]);
242                 if (signed_k < 0) {
243                         rans += sign_bias;
244                 }
245         }
246
247 private:
248         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
249         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
250
251         unique_ptr<uint8_t[]> out_buf;
252         uint8_t *out_end;
253         uint8_t *ptr;
254         RansState rans;
255         RansEncSymbol esyms[NUM_SYMS];
256         uint32_t sign_bias;
257
258         uint32_t last_block = 0;  // Not a valid 4-byte rANS block (?)
259 };
260
// Flat DC quantization step, chosen to match the FDCT's DC gain.
constexpr int dc_scalefac = 8;
// Global multiplier on the AC quantization matrix (larger = coarser).
// quantize()/unquantize() only use its integer part.
constexpr double quant_scalefac = 4.0;
263
264 static inline int quantize(int f, int coeff_idx)
265 {
266         if (coeff_idx == 0) {
267                 return f / dc_scalefac;
268         }
269         if (f == 0) {
270                 return 0;
271         }
272
273         const int w = std_luminance_quant_tbl[coeff_idx];
274         const int s = quant_scalefac;
275         int sign_f = (f > 0) ? 1 : -1;
276         return (32 * f + sign_f * w * s) / (2 * w * s);
277 }
278
279 static inline int unquantize(int qf, int coeff_idx)
280 {
281         if (coeff_idx == 0) {
282                 return qf * dc_scalefac;
283         }
284         if (qf == 0) {
285                 return 0;
286         }
287
288         const int w = std_luminance_quant_tbl[coeff_idx];
289         const int s = quant_scalefac;
290         return (2 * qf * w * s) / 32;
291 }
292
293 void readpix(unsigned char *ptr, const char *filename)
294 {
295         FILE *fp = fopen(filename, "rb");
296         if (fp == nullptr) {
297                 perror(filename);
298                 exit(1);
299         }
300
301         fseek(fp, 0, SEEK_END);
302         long len = ftell(fp);
303         assert(len >= WIDTH * HEIGHT * 3);
304         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
305
306         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
307         fclose(fp);
308 }
309
310 void convert_ycbcr()
311 {
312         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
313         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
314         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
315
316         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
317         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
318         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
319                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
320                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
321                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
322                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
323                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
324                         double cb = (b - y) * cb_fac + 128.0;
325                         double cr = (r - y) * cr_fac + 128.0;
326                         pix_y[(yb * WIDTH) + xb] = lrint(y);
327                         temp_cb[(yb * WIDTH) + xb] = cb;
328                         temp_cr[(yb * WIDTH) + xb] = cr;
329                         full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
330                         full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
331                 }
332         }
333
334         // Simple 4:2:2 subsampling with left convention.
335         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
336                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
337                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
338                         int c1 = yb * WIDTH + xb * 2;
339                         int c2 = yb * WIDTH + xb * 2 + 1;
340                         
341                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
342                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
343                         cb = std::min(std::max(cb, 0.0), 255.0);
344                         cr = std::min(std::max(cr, 0.0), 255.0);
345                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
346                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
347                 }
348         }
349 }
350
351 int main(int argc, char **argv)
352 {
353         if (argc >= 2)
354                 readpix(rgb, argv[1]);
355         else
356                 readpix(rgb, "color.pnm");
357         convert_ycbcr();
358
359         double sum_sq_err = 0.0;
360         //double last_cb_cfl_fac = 0.0;
361         //double last_cr_cfl_fac = 0.0;
362
363         // DCT and quantize luma
364         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
365                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
366                         // Read one block
367                         short in_y[64];
368                         for (unsigned y = 0; y < 8; ++y) {
369                                 for (unsigned x = 0; x < 8; ++x) {
370                                         in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
371                                 }
372                         }
373
374                         // FDCT it
375                         fdct_int32(in_y);
376
377                         for (unsigned y = 0; y < 8; ++y) {
378                                 for (unsigned x = 0; x < 8; ++x) {
379                                         int coeff_idx = y * 8 + x;
380                                         int k = quantize(in_y[coeff_idx], coeff_idx);
381                                         coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
382
383                                         // Store back for reconstruction / PSNR calculation
384                                         in_y[coeff_idx] = unquantize(k, coeff_idx);
385                                 }
386                         }
387
388                         idct_int32(in_y);
389
390                         for (unsigned y = 0; y < 8; ++y) {
391                                 for (unsigned x = 0; x < 8; ++x) {
392                                         int k = clamp(in_y[y * 8 + x]);
393                                         uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
394                                         sum_sq_err += (*ptr - k) * (*ptr - k);
395                                         *ptr = k;
396                                 }
397                         }
398                 }
399         }
400         double mse = sum_sq_err / double(WIDTH * HEIGHT);
401         double psnr_db = 20 * log10(255.0 / sqrt(mse));
402         printf("psnr = %.2f dB\n", psnr_db);
403
404         //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
405
406         // DCT and quantize chroma
407         //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
408         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
409                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
410 #if 0
411                         // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
412                         printf("in blocks:\n");
413                         for (unsigned y = 0; y < 8; ++y) {
414                                 for (unsigned x = 0; x < 8; ++x) {
415                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
416                                         printf(" %4d", a);
417                                 }
418                                 printf(" | ");
419                                 for (unsigned x = 0; x < 8; ++x) {
420                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
421                                         printf(" %4d", b);
422                                 }
423                                 printf("\n");
424                         }
425
426                         short in_y[64];
427                         for (unsigned y = 0; y < 8; ++y) {
428                                 for (unsigned x = 0; x < 4; ++x) {
429                                         short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
430                                         short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
431                                         b = a - b;
432                                         a = 2 * a - b;
433                                         in_y[y * 8 + x * 2 + 0] = a;
434                                         in_y[y * 8 + x * 2 + 1] = b;
435                                 }
436                         }
437
438                         printf("tf-ed block:\n");
439                         for (unsigned y = 0; y < 8; ++y) {
440                                 for (unsigned x = 0; x < 8; ++x) {
441                                         short a = in_y[y * 8 + x];
442                                         printf(" %4d", a);
443                                 }
444                                 printf("\n");
445                         }
446 #else
447                         // Read Y block with no tf switch (from reconstructed luma)
448                         short in_y[64];
449                         for (unsigned y = 0; y < 8; ++y) {
450                                 for (unsigned x = 0; x < 8; ++x) {
451                                         in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
452                                 }
453                         }
454                         fdct_int32(in_y);
455 #endif
456
457                         // Read one block
458                         short in_cb[64], in_cr[64];
459                         for (unsigned y = 0; y < 8; ++y) {
460                                 for (unsigned x = 0; x < 8; ++x) {
461                                         in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
462                                         in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
463                                 }
464                         }
465
466                         // FDCT it
467                         fdct_int32(in_cb);
468                         fdct_int32(in_cr);
469
470 #if 0
471                         // Chroma from luma
472                         double x0 = in_y[1];
473                         double x1 = in_y[8];
474                         double x2 = in_y[9];
475                         double denom = (x0 * x0 + x1 * x1 + x2 * x2);
476                         //double denom = (x1 * x1);
477         
478                         double y0 = in_cb[1];
479                         double y1 = in_cb[8];
480                         double y2 = in_cb[9];
481                         double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
482                         //double cb_cfl_fac = (x1 * y1) / denom;
483
484                         for (unsigned y = 0; y < 8; ++y) {
485                                 for (unsigned x = 0; x < 8; ++x) {
486                                         short a = in_y[y * 8 + x];
487                                         printf(" %4d", a);
488                                 }
489                                 printf(" | ");
490                                 for (unsigned x = 0; x < 8; ++x) {
491                                         short a = in_cb[y * 8 + x];
492                                         printf(" %4d", a);
493                                 }
494                                 printf("\n");
495                         }
496                         printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
497                                 in_y[1], in_y[8], in_y[9], 
498                                 in_cb[1], in_cb[8], in_cb[9],
499                                 cb_cfl_fac);
500
501                         y0 = in_cr[1];
502                         y1 = in_cr[8];
503                         y2 = in_cr[9];
504                         double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
505                         //double cr_cfl_fac = (x1 * y1) / denom;
506                         printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
507                                 cb_cfl_fac, in_cb[0] - in_y[0],
508                                 cr_cfl_fac, in_cr[0] - in_y[0]);
509
510                         if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
511
512                         // CHEAT
513                         //last_cb_cfl_fac = cb_cfl_fac;
514                         //last_cr_cfl_fac = cr_cfl_fac;
515
516                         for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
517                                 //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
518                                 //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
519                                 double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
520                                 chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
521                                 chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
522
523                                 //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
524                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
525                                 //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
526                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
527                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
528                                 //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
529                                 //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
530                         }
531                         //in_cb[0] += 1024;
532                         //in_cr[0] += 1024;
533                         //in_cb[0] -= in_y[0];
534                         //in_cr[0] -= in_y[0];
535 #endif
536
537                         for (unsigned y = 0; y < 8; ++y) {
538                                 for (unsigned x = 0; x < 8; ++x) {
539                                         int coeff_idx = y * 8 + x;
540                                         int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
541                                         coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
542                                         int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
543                                         coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
544
545                                         // Store back for reconstruction / PSNR calculation
546                                         in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
547                                         in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
548                                 }
549                         }
550
551                         idct_int32(in_y);  // DEBUG
552                         idct_int32(in_cb);
553                         idct_int32(in_cr);
554
555                         for (unsigned y = 0; y < 8; ++y) {
556                                 for (unsigned x = 0; x < 8; ++x) {
557                                         pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
558                                         pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
559
560                         //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
561                         //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
562                                 }
563                         }
564
565 #if 0
566                         last_cb_cfl_fac = cb_cfl_fac;
567                         last_cr_cfl_fac = cr_cfl_fac;
568 #endif
569                 }
570         }
571
572 #if 0
573         printf("chroma_energy = %f, with_pred = %f\n",
574                 chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
575 #endif
576
577         // DC coefficient pred from the right to left
578         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
579                 for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
580                         coeff_y[yb * WIDTH + xb] -= coeff_y[yb * WIDTH + (xb + 8)];
581                 }
582         }
583
584         FILE *fp = fopen("reconstructed.pgm", "wb");
585         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
586         fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
587         fclose(fp);
588
589         fp = fopen("reconstructed.pnm", "wb");
590         fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
591         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
592                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
593                         int y = pix_y[(yb * WIDTH) + xb];
594                         int cb, cr;
595                         int c0 = yb * (WIDTH/2) + xb/2;
596                         if (xb % 2 == 0) {
597                                 cb = pix_cb[c0] - 128.0;
598                                 cr = pix_cr[c0] - 128.0;
599                         } else {
600                                 int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
601                                 cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
602                                 cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
603                         }
604
605                         double r = y + 1.5748 * cr;
606                         double g = y - 0.1873 * cb - 0.4681 * cr;
607                         double b = y + 1.8556 * cb;
608
609                         putc(clamp(lrint(r)), fp);
610                         putc(clamp(lrint(g)), fp);
611                         putc(clamp(lrint(b)), fp);
612                 }
613         }
614         fclose(fp);
615
616         // For each coefficient, make some tables.
617         size_t extra_bits = 0;
618         for (unsigned i = 0; i < 64; ++i) {
619                 stats[i].clear();
620         }
621         for (unsigned y = 0; y < 8; ++y) {
622                 for (unsigned x = 0; x < 8; ++x) {
623                         SymbolStats &s_luma = stats[pick_stats_for(x, y)];
624                         SymbolStats &s_chroma = stats[pick_stats_for(x, y) + 8];  // HACK
625                         //SymbolStats &s_chroma = stats[pick_stats_for(x, y)];
626
627                         // Luma
628                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
629                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
630                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
631                                         if (k >= ESCAPE_LIMIT) {
632                                                 k = ESCAPE_LIMIT;
633                                                 extra_bits += 12;  // escape this one
634                                         }
635                                         ++s_luma.freqs[(k - 1) & 255];
636                                 }
637                         }
638                         // Chroma
639                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
640                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
641                                         unsigned short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
642                                         unsigned short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
643                                         if (k_cb >= ESCAPE_LIMIT) {
644                                                 k_cb = ESCAPE_LIMIT;
645                                                 extra_bits += 12;  // escape this one
646                                         }
647                                         if (k_cr >= ESCAPE_LIMIT) {
648                                                 k_cr = ESCAPE_LIMIT;
649                                                 extra_bits += 12;  // escape this one
650                                         }
651                                         ++s_chroma.freqs[(k_cb - 1) & 255];
652                                         ++s_chroma.freqs[(k_cr - 1) & 255];
653                                 }
654                         }
655                 }
656         }
657         for (unsigned i = 0; i < 64; ++i) {
658                 stats[i].freqs[255] /= 2;  // zero, has no sign bits (yes, this is trickery)
659                 stats[i].normalize_freqs(prob_scale);
660                 stats[i].cum_freqs[256] += stats[i].freqs[255];
661                 stats[i].freqs[255] *= 2;
662         }
663
664         FILE *codedfp = fopen("coded.dat", "wb");
665         if (codedfp == nullptr) {
666                 perror("coded.dat");
667                 exit(1);
668         }
669
670         // TODO: rather gamma-k or something
671         for (unsigned i = 0; i < 64; ++i) {
672                 if (stats[i].cum_freqs[NUM_SYMS] == 0) {
673                         continue;
674                 }
675                 printf("writing table %d\n", i);
676                 for (unsigned j = 0; j < NUM_SYMS; ++j) {
677                         write_varint(stats[i].freqs[j], codedfp);
678                 }
679         }
680
681         RansEncoder rans_encoder;
682
683         size_t tot_bytes = 0;
684
685         // Luma
686         for (unsigned y = 0; y < 8; ++y) {
687                 for (unsigned x = 0; x < 8; ++x) {
688                         SymbolStats &s_luma = stats[pick_stats_for(x, y)];
689                         rans_encoder.init_prob(s_luma);
690
691                         // Luma
692                         std::vector<int> lens;
693
694                         // need to reverse later
695                         rans_encoder.clear();
696                         size_t num_bytes = 0;
697                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
698                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
699                                         int k = coeff_y[(yb + y) * WIDTH + (xb + x)];
700                                         //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
701                                         rans_encoder.encode_coeff(k);
702                                 }
703                                 if (yb % 16 == 8) {
704                                         int l = rans_encoder.save_block(codedfp);
705                                         num_bytes += l;
706                                         lens.push_back(l);
707                                 }
708                         }
709                         if (HEIGHT % 16 != 0) {
710                                 num_bytes += rans_encoder.save_block(codedfp);
711                         }
712                         tot_bytes += num_bytes;
713                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
714
715                         double sum_l = 0.0;
716                         for (int l : lens) {
717                                 sum_l += l;
718                         }
719                         double avg_l = sum_l / lens.size();
720
721                         double sum_sql = 0.0;
722                         for (int l : lens) {
723                                 sum_sql += (l - avg_l) * (l - avg_l);
724                         }
725                         double stddev_l = sqrt(sum_sql / (lens.size() - 1));
726                         printf("coeff %d: avg=%.2f bytes, stddev=%.2f bytes\n", y*8+x, avg_l, stddev_l);
727                 }
728         }
729
730         // Cb
731         for (unsigned y = 0; y < 8; ++y) {
732                 for (unsigned x = 0; x < 8; ++x) {
733                         SymbolStats &s_chroma = stats[pick_stats_for(x, y) + 8];
734                         //SymbolStats &s_chroma = stats[pick_stats_for(x, y)];
735                         rans_encoder.init_prob(s_chroma);
736
737                         rans_encoder.clear();
738                         size_t num_bytes = 0;
739                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
740                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
741                                         int k = coeff_cb[(yb + y) * WIDTH/2 + (xb + x)];
742                                         rans_encoder.encode_coeff(k);
743                                 }
744                                 if (yb % 16 == 8) {
745                                         num_bytes += rans_encoder.save_block(codedfp);
746                                 }
747                         }
748                         if (HEIGHT % 16 != 0) {
749                                 num_bytes += rans_encoder.save_block(codedfp);
750                         }
751                         tot_bytes += num_bytes;
752                         printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
753                 }
754         }
755
756         // Cr
757         for (unsigned y = 0; y < 8; ++y) {
758                 for (unsigned x = 0; x < 8; ++x) {
759                         SymbolStats &s_chroma = stats[pick_stats_for(x, y) + 8];
760                         //SymbolStats &s_chroma = stats[pick_stats_for(x, y)];
761                         rans_encoder.init_prob(s_chroma);
762
763                         rans_encoder.clear();
764                         size_t num_bytes = 0;
765                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
766                                 for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
767                                         int k = coeff_cr[(yb + y) * WIDTH/2 + (xb + x)];
768                                         rans_encoder.encode_coeff(k);
769                                 }
770                                 if (yb % 16 == 8) {
771                                         num_bytes += rans_encoder.save_block(codedfp);
772                                 }
773                         }
774                         if (HEIGHT % 16 != 0) {
775                                 num_bytes += rans_encoder.save_block(codedfp);
776                         }
777                         tot_bytes += num_bytes;
778                         printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
779                 }
780         }
781
782         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
783                 tot_bytes - extra_bits / 8,
784                 extra_bits,
785                 extra_bits / 8,
786                 tot_bytes);
787 }