]> git.sesse.net Git - narabu/commitdiff
Add color support.
authorSteinar H. Gunderson <sgunderson@bigfoot.com>
Sat, 16 Sep 2017 13:03:57 +0000 (15:03 +0200)
committerSteinar H. Gunderson <sgunderson@bigfoot.com>
Sat, 16 Sep 2017 13:58:08 +0000 (15:58 +0200)
qdc.cpp

diff --git a/qdc.cpp b/qdc.cpp
index a0aaf002888cc1b9044dcf09d90cf68de1ef50a7..43fcba0de2dc343fbdee87d762992fa00c9ff95b 100644 (file)
--- a/qdc.cpp
+++ b/qdc.cpp
@@ -6,6 +6,7 @@
 
 //#include "ryg_rans/rans64.h"
 #include "ryg_rans/rans_byte.h"
+#include "ryg_rans/renormalize.h"
 
 #include <memory>
 
 #define NUM_SYMS 256
 #define ESCAPE_LIMIT (NUM_SYMS - 1)
 
+static constexpr uint32_t prob_bits = 12;
+static constexpr uint32_t prob_scale = 1 << prob_bits;
+
 using namespace std;
 
 void fdct_int32(short *const In);
 void idct_int32(short *const In);
 
-unsigned char pix[WIDTH * HEIGHT];
-short coeff[WIDTH * HEIGHT];
+unsigned char rgb[WIDTH * HEIGHT * 3];
+unsigned char pix_y[WIDTH * HEIGHT];
+unsigned char pix_cb[(WIDTH/2) * HEIGHT];
+unsigned char pix_cr[(WIDTH/2) * HEIGHT];
+unsigned char full_pix_cb[WIDTH * HEIGHT];
+unsigned char full_pix_cr[WIDTH * HEIGHT];
+short coeff_y[WIDTH * HEIGHT], coeff_cb[(WIDTH/2) * HEIGHT], coeff_cr[(WIDTH/2) * HEIGHT];
+
+int clamp(int x)
+{
+       if (x < 0) return 0;
+       if (x > 255) return 255;
+       return x;
+}
 
 static const unsigned char std_luminance_quant_tbl[64] = {
 #if 0
@@ -77,8 +93,123 @@ void SymbolStats::calc_cum_freqs()
         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
 }
 
+static double cache[NUM_SYMS + 1][prob_scale + 1];
+static double log2cache[prob_scale + 1];
+static int64_t cachefill = 0;
+
+double find_optimal_cost(const uint32_t *cum_freqs, int sym_to, int available_slots)
+{
+       assert(sym_to >= 0);
+
+       while (sym_to > 0 && cum_freqs[sym_to] == cum_freqs[sym_to - 1]) { --sym_to; }
+       if (cache[sym_to][available_slots] >= 0.0) {
+               //printf("CACHE: %d,%d\n", sym_to, available_slots);
+               return cache[sym_to][available_slots];
+       }
+       if (sym_to == 0) {
+               return 0.0;
+       }
+       if (sym_to == 1) {
+               return cum_freqs[0] * log2cache[available_slots];
+       }
+       if (available_slots == 1) {
+               return cum_freqs[0] * log2cache[1] + find_optimal_cost(cum_freqs, sym_to - 1, 0);
+       }
+
+//     printf("UNCACHE: %d,%d\n", sym_to, available_slots);
+#if 0
+       // ok, test all possible options for the last symbol (TODO: save the choice)
+       double best_so_far = HUGE_VAL;
+       //for (int i = num_syms - 1; i < available_slots; ++i) {
+       double f = freqs[sym_to - 1];
+       for (int i = available_slots; i --> 0; ) {
+               double cost1 = f * log2cache[available_slots - i];
+               double cost2 = find_optimal_cost(freqs, sym_to - 1, i);
+
+               if (sym_to == 3 && available_slots == 838) {
+                       printf("%d %f\n", i, cost1 + cost2);
+               } else
+               if (cost1 + cost2 > best_so_far) {
+                       break;
+               }
+               best_so_far = cost1 + cost2;
+       }
+#elif 1
+       // Minimize the number of total bits spent as a function of how many slots
+       // we assign to this symbol.
+       //
+       // The cost function is convex (I don't know how to prove it, but it makes
+       // intuitively a lot of sense). Find a reasonable guess and see what way
+       // we should search, then iterate until we either hit the end or we start
+       // increasing again.
+       double f = cum_freqs[sym_to - 1] - cum_freqs[sym_to - 2];
+       double start = lrint(available_slots * f / cum_freqs[sym_to - 1]);
+
+       int x1 = std::max<int>(floor(start), 1);
+       int x2 = x1 + 1;
+
+       double f1 = f * log2cache[x1] + find_optimal_cost(cum_freqs, sym_to - 1, available_slots - x1);
+       double f2 = f * log2cache[x2] + find_optimal_cost(cum_freqs, sym_to - 1, available_slots - x2);
+
+       int x, direction;  // -1 or +1
+       double best_so_far = std::min(f1, f2);
+       if (isinf(f1) && isinf(f2)) {
+               // The cost isn't infinite due to the first term, so we need to go downwards
+               // to give the second term more room to breathe.
+               x = x1;
+               direction = -1;
+       } else if (f1 < f2) {
+               x = x1;
+               direction = -1;
+       } else {
+               x = x2;
+               direction = 1;
+       }
+
+       //printf("[%d,%d] freq=%ld cumfreq=%d From %d and %d, chose %d [%f] and direction=%d\n",
+       //      sym_to, available_slots, freqs[sym_to - 1], cum_freqs[sym_to - 1], x1, x2, x, best_so_far, direction);
+
+       while ((x + direction) > 0 && (x + direction) <= available_slots) {
+               x += direction;
+               double fn = f * log2cache[x] + find_optimal_cost(cum_freqs, sym_to - 1, available_slots - x);
+       //      printf("[%d,%d] %d is %f\n", sym_to, available_slots, x, fn);
+               if (fn > best_so_far) {
+                       break;
+               }
+               best_so_far = fn;
+       }
+#endif
+       if (++cachefill % 131072 == 0) {
+       //      printf("%d,%d = %f (cachefill = %.2f%%)\n", sym_to, available_slots, best_so_far,
+       //              100.0 * (cachefill / double((NUM_SYMS + 1) * (prob_scale + 1))));
+       }
+       assert(best_so_far >= 0.0);
+       assert(cache[sym_to][available_slots] < 0.0);
+       cache[sym_to][available_slots] = best_so_far;
+       return best_so_far;
+}
+
+double find_optimal_cost(const uint32_t *cum_freqs, const uint64_t *freqs)
+{
+       for (int j = 0; j <= NUM_SYMS; ++j) {
+               for (unsigned k = 0; k <= prob_scale; ++k) {
+                       cache[j][k] = -1.0;
+               }
+       }
+       for (unsigned k = 0; k <= prob_scale; ++k) {
+               log2cache[k] = -log2(k * (1.0 / prob_scale));
+               //printf("log2cache[%d] = %f\n", k, log2cache[k]);
+       }
+       cachefill = 0;
+       double ret = find_optimal_cost(cum_freqs, NUM_SYMS, prob_scale);
+       printf("Used %ld function invocations\n", cachefill);
+       return ret;
+}
+
 void SymbolStats::normalize_freqs(uint32_t target_total)
 {
+    uint64_t real_freq[NUM_SYMS + 1];  // hack
+
     assert(target_total >= NUM_SYMS);
 
     calc_cum_freqs();
@@ -86,9 +217,23 @@ void SymbolStats::normalize_freqs(uint32_t target_total)
 
     if (cur_total == 0) return;
 
+    double ideal_cost = 0.0;
+    for (int i = 1; i <= NUM_SYMS; i++)
+    {
+      real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
+      if (real_freq[i] > 0)
+        ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
+    }
+
+    OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
+
+#if 0
+    double optimal_cost = find_optimal_cost(cum_freqs + 1, real_freq + 1);
+
     // resample distribution based on cumulative freqs
     for (int i = 1; i <= NUM_SYMS; i++)
-        cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
+        //cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
+        cum_freqs[i] = lrint(cum_freqs[i] * double(target_total) / cur_total);
 
     // if we nuked any non-0 frequency symbol to 0, we need to steal
     // the range to make the frequency nonzero from elsewhere.
@@ -121,6 +266,7 @@ void SymbolStats::normalize_freqs(uint32_t target_total)
             }
         }
     }
+#endif
 
     // calculate updated freqs and make sure we didn't screw anything up
     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
@@ -133,6 +279,23 @@ void SymbolStats::normalize_freqs(uint32_t target_total)
         // calc updated freq
         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
     }
+
+    double calc_cost = 0.0;
+    for (int i = 1; i <= NUM_SYMS; i++)
+    {
+      uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
+      if (real_freq[i] > 0)
+        calc_cost -= real_freq[i] * log2(freq / double(target_total));
+    }
+
+    static double total_loss = 0.0;
+    total_loss += calc_cost - ideal_cost;
+    static double total_loss_with_dp = 0.0;
+       double optimal_cost = 0.0;
+    //total_loss_with_dp += optimal_cost - ideal_cost;
+    printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
+               ideal_cost, optimal_cost,
+                calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
 }
 
 SymbolStats stats[64];
@@ -165,9 +328,6 @@ void write_varint(int x, FILE *fp)
 
 class RansEncoder {
 public:
-       static constexpr uint32_t prob_bits = 12;
-       static constexpr uint32_t prob_scale = 1 << prob_bits;
-
        RansEncoder()
        {
                out_buf.reset(new uint8_t[out_max_size]);
@@ -175,11 +335,11 @@ public:
                clear();
        }
 
-       void init_prob(const SymbolStats &s1, const SymbolStats &s2)
+       void init_prob(const SymbolStats &s)
        {
                for (int i = 0; i < NUM_SYMS; i++) {
-                       //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits);
-                       RansEncSymbolInit(&esyms[i], s1.cum_freqs[i], s1.freqs[i], prob_bits);
+                       printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits);
+                       RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits);
                }
        }
 
@@ -217,14 +377,14 @@ public:
 
                clear();
 
-               //printf("Saving block: %d rANS bytes, %d sign bytes\n", num_rans_bytes, num_sign_bytes);
+               printf("Saving block: %d rANS bytes, %d sign bytes\n", num_rans_bytes, num_sign_bytes);
                return num_rans_bytes + num_sign_bytes;
                //return num_rans_bytes;
        }
 
        void encode_coeff(short signed_k)
        {
-               //printf("encoding coeff %d\n", signed_k);
+               printf("encoding coeff %d\n", signed_k);
                short k = abs(signed_k);
                if (k >= ESCAPE_LIMIT) {
                        // Put the coefficient as a 1/(2^12) symbol _before_
@@ -294,46 +454,107 @@ static inline int unquantize(int qf, int coeff_idx)
        return (2 * qf * w * s) / 32;
 }
 
-int main(void)
+void readpix(unsigned char *ptr, const char *filename)
 {
-       FILE *fp = fopen("pic.pgm", "rb");
-       fread(pix, 1, WIDTH * HEIGHT, fp);
+       FILE *fp = fopen(filename, "rb");
+       if (fp == nullptr) {
+               perror(filename);
+               exit(1);
+       }
+
+       fseek(fp, 0, SEEK_END);
+       long len = ftell(fp);
+       assert(len >= WIDTH * HEIGHT * 3);
+       fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
+
+       fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
        fclose(fp);
+}
+
+void convert_ycbcr()
+{
+       double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
+       double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
+       double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
+
+       unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
+       unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
+       for (unsigned yb = 0; yb < HEIGHT; ++yb) {
+               for (unsigned xb = 0; xb < WIDTH; ++xb) {
+                       int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
+                       int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
+                       int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
+                       double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
+                       double cb = (b - y) * cb_fac + 128.0;
+                       double cr = (r - y) * cr_fac + 128.0;
+                       pix_y[(yb * WIDTH) + xb] = lrint(y);
+                       temp_cb[(yb * WIDTH) + xb] = cb;
+                       temp_cr[(yb * WIDTH) + xb] = cr;
+                       full_pix_cb[(yb * WIDTH) + xb] = lrint(std::min(std::max(cb, 0.0), 255.0));
+                       full_pix_cr[(yb * WIDTH) + xb] = lrint(std::min(std::max(cr, 0.0), 255.0));
+               }
+       }
+
+       // Simple 4:2:2 subsampling with left convention.
+       for (unsigned yb = 0; yb < HEIGHT; ++yb) {
+               for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
+                       int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
+                       int c1 = yb * WIDTH + xb * 2;
+                       int c2 = yb * WIDTH + xb * 2 + 1;
+                       
+                       double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
+                       double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
+                       cb = std::min(std::max(cb, 0.0), 255.0);
+                       cr = std::min(std::max(cr, 0.0), 255.0);
+                       pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
+                       pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
+               }
+       }
+}
+
+int main(int argc, char **argv)
+{
+       if (argc >= 2)
+               readpix(rgb, argv[1]);
+       else
+               readpix(rgb, "color.pnm");
+       convert_ycbcr();
 
        double sum_sq_err = 0.0;
+       //double last_cb_cfl_fac = 0.0;
+       //double last_cr_cfl_fac = 0.0;
 
+       // DCT and quantize luma
        for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
                for (unsigned xb = 0; xb < WIDTH; xb += 8) {
                        // Read one block
-                       short in[64];
+                       short in_y[64];
                        for (unsigned y = 0; y < 8; ++y) {
                                for (unsigned x = 0; x < 8; ++x) {
-                                       in[y * 8 + x] = pix[(yb + y) * WIDTH + (xb + x)];
+                                       in_y[y * 8 + x] = pix_y[(yb + y) * WIDTH + (xb + x)];
                                }
                        }
 
                        // FDCT it
-                       fdct_int32(in);
+                       fdct_int32(in_y);
 
                        for (unsigned y = 0; y < 8; ++y) {
                                for (unsigned x = 0; x < 8; ++x) {
                                        int coeff_idx = y * 8 + x;
-                                       int k = quantize(in[coeff_idx], coeff_idx);
-                                       coeff[(yb + y) * WIDTH + (xb + x)] = k;
+                                       int k = quantize(in_y[coeff_idx], coeff_idx);
+                                       coeff_y[(yb + y) * WIDTH + (xb + x)] = k;
 
                                        // Store back for reconstruction / PSNR calculation
-                                       in[coeff_idx] = unquantize(k, coeff_idx);
+                                       in_y[coeff_idx] = unquantize(k, coeff_idx);
                                }
                        }
 
-                       idct_int32(in);
+                       idct_int32(in_y);
 
                        for (unsigned y = 0; y < 8; ++y) {
                                for (unsigned x = 0; x < 8; ++x) {
-                                       int k = in[y * 8 + x];
-                                       if (k < 0) k = 0;
-                                       if (k > 255) k = 255;
-                                       uint8_t *ptr = &pix[(yb + y) * WIDTH + (xb + x)];
+                                       int k = clamp(in_y[y * 8 + x]);
+                                       uint8_t *ptr = &pix_y[(yb + y) * WIDTH + (xb + x)];
                                        sum_sq_err += (*ptr - k) * (*ptr - k);
                                        *ptr = k;
                                }
@@ -344,16 +565,216 @@ int main(void)
        double psnr_db = 20 * log10(255.0 / sqrt(mse));
        printf("psnr = %.2f dB\n", psnr_db);
 
+       //double chroma_energy = 0.0, chroma_energy_pred = 0.0;
+
+       // DCT and quantize chroma
+       //double last_cb_cfl_fac = 0.0, last_cr_cfl_fac = 0.0;
+       for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
+               for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
+#if 0
+                       // TF switch: Two 8x8 luma blocks -> one 16x8 block, then drop high frequencies
+                       printf("in blocks:\n");
+                       for (unsigned y = 0; y < 8; ++y) {
+                               for (unsigned x = 0; x < 8; ++x) {
+                                       short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
+                                       printf(" %4d", a);
+                               }
+                               printf(" | ");
+                               for (unsigned x = 0; x < 8; ++x) {
+                                       short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
+                                       printf(" %4d", b);
+                               }
+                               printf("\n");
+                       }
+
+                       short in_y[64];
+                       for (unsigned y = 0; y < 8; ++y) {
+                               for (unsigned x = 0; x < 4; ++x) {
+                                       short a = coeff_y[(yb + y) * WIDTH + (xb*2 + x)];
+                                       short b = coeff_y[(yb + y) * WIDTH + (xb*2 + x + 8)];
+                                       b = a - b;
+                                       a = 2 * a - b;
+                                       in_y[y * 8 + x * 2 + 0] = a;
+                                       in_y[y * 8 + x * 2 + 1] = b;
+                               }
+                       }
+
+                       printf("tf-ed block:\n");
+                       for (unsigned y = 0; y < 8; ++y) {
+                               for (unsigned x = 0; x < 8; ++x) {
+                                       short a = in_y[y * 8 + x];
+                                       printf(" %4d", a);
+                               }
+                               printf("\n");
+                       }
+#else
+                       // Read Y block with no tf switch (from reconstructed luma)
+                       short in_y[64];
+                       for (unsigned y = 0; y < 8; ++y) {
+                               for (unsigned x = 0; x < 8; ++x) {
+                                       in_y[y * 8 + x] = pix_y[(yb + y) * (WIDTH) + (xb + x) * 2];
+                               }
+                       }
+                       fdct_int32(in_y);
+#endif
+
+                       // Read one block
+                       short in_cb[64], in_cr[64];
+                       for (unsigned y = 0; y < 8; ++y) {
+                               for (unsigned x = 0; x < 8; ++x) {
+                                       in_cb[y * 8 + x] = pix_cb[(yb + y) * (WIDTH/2) + (xb + x)];
+                                       in_cr[y * 8 + x] = pix_cr[(yb + y) * (WIDTH/2) + (xb + x)];
+                               }
+                       }
+
+                       // FDCT it
+                       fdct_int32(in_cb);
+                       fdct_int32(in_cr);
+
+#if 0
+                       // Chroma from luma
+                       double x0 = in_y[1];
+                       double x1 = in_y[8];
+                       double x2 = in_y[9];
+                       double denom = (x0 * x0 + x1 * x1 + x2 * x2);
+                       //double denom = (x1 * x1);
+       
+                       double y0 = in_cb[1];
+                       double y1 = in_cb[8];
+                       double y2 = in_cb[9];
+                       double cb_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
+                       //double cb_cfl_fac = (x1 * y1) / denom;
+
+                       for (unsigned y = 0; y < 8; ++y) {
+                               for (unsigned x = 0; x < 8; ++x) {
+                                       short a = in_y[y * 8 + x];
+                                       printf(" %4d", a);
+                               }
+                               printf(" | ");
+                               for (unsigned x = 0; x < 8; ++x) {
+                                       short a = in_cb[y * 8 + x];
+                                       printf(" %4d", a);
+                               }
+                               printf("\n");
+                       }
+                       printf("(%d,%d,%d) -> (%d,%d,%d) gives %f\n",
+                               in_y[1], in_y[8], in_y[9], 
+                               in_cb[1], in_cb[8], in_cb[9],
+                               cb_cfl_fac);
+
+                       y0 = in_cr[1];
+                       y1 = in_cr[8];
+                       y2 = in_cr[9];
+                       double cr_cfl_fac = (x0 * y0 + x1 * y1 + x2 * y2) / denom;
+                       //double cr_cfl_fac = (x1 * y1) / denom;
+                       printf("cb CfL = %7.3f  dc = %5d    cr CfL = %7.3f  dc = %d\n",
+                               cb_cfl_fac, in_cb[0] - in_y[0],
+                               cr_cfl_fac, in_cr[0] - in_y[0]);
+
+                       if (denom == 0.0) { cb_cfl_fac = cr_cfl_fac = 0.0; }
+
+                       // CHEAT
+                       //last_cb_cfl_fac = cb_cfl_fac;
+                       //last_cr_cfl_fac = cr_cfl_fac;
+
+                       for (unsigned coeff_idx = 1; coeff_idx < 64; ++coeff_idx) {
+                               //printf("%2d: cb = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cb[coeff_idx], last_cb_cfl_fac, in_y[coeff_idx], last_cb_cfl_fac * in_y[coeff_idx]);
+                               //printf("%2d: cr = %3d prediction = %f * %3d = %7.3f\n", coeff_idx, in_cr[coeff_idx], last_cr_cfl_fac, in_y[coeff_idx], last_cr_cfl_fac * in_y[coeff_idx]);
+                               double cb_pred = last_cb_cfl_fac * in_y[coeff_idx];
+                               chroma_energy += in_cb[coeff_idx] * in_cb[coeff_idx];
+                               chroma_energy_pred += (in_cb[coeff_idx] - cb_pred) * (in_cb[coeff_idx] - cb_pred);
+
+                               //in_cb[coeff_idx] -= lrint(last_cb_cfl_fac * in_y[coeff_idx]);
+                               //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
+                               //in_cr[coeff_idx] -= lrint(last_cr_cfl_fac * in_y[coeff_idx]);
+                               //in_cb[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
+                               //in_cr[coeff_idx] = lrint(in_y[coeff_idx] * (1.0 / sqrt(2)));
+                               //in_cb[coeff_idx] = lrint(in_y[coeff_idx]);
+                               //in_cr[coeff_idx] = lrint(in_y[coeff_idx]);
+                       }
+                       //in_cb[0] += 1024;
+                       //in_cr[0] += 1024;
+                       //in_cb[0] -= in_y[0];
+                       //in_cr[0] -= in_y[0];
+#endif
+
+                       for (unsigned y = 0; y < 8; ++y) {
+                               for (unsigned x = 0; x < 8; ++x) {
+                                       int coeff_idx = y * 8 + x;
+                                       int k_cb = quantize(in_cb[coeff_idx], coeff_idx);
+                                       coeff_cb[(yb + y) * (WIDTH/2) + (xb + x)] = k_cb;
+                                       int k_cr = quantize(in_cr[coeff_idx], coeff_idx);
+                                       coeff_cr[(yb + y) * (WIDTH/2) + (xb + x)] = k_cr;
+
+                                       // Store back for reconstruction / PSNR calculation
+                                       in_cb[coeff_idx] = unquantize(k_cb, coeff_idx);
+                                       in_cr[coeff_idx] = unquantize(k_cr, coeff_idx);
+                               }
+                       }
+
+                       idct_int32(in_y);  // DEBUG
+                       idct_int32(in_cb);
+                       idct_int32(in_cr);
+
+                       for (unsigned y = 0; y < 8; ++y) {
+                               for (unsigned x = 0; x < 8; ++x) {
+                                       pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cb[y * 8 + x]);
+                                       pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = clamp(in_cr[y * 8 + x]);
+
+                       //              pix_cb[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
+                       //              pix_cr[(yb + y) * (WIDTH/2) + (xb + x)] = in_y[y * 8 + x];
+                               }
+                       }
+
+#if 0
+                       last_cb_cfl_fac = cb_cfl_fac;
+                       last_cr_cfl_fac = cr_cfl_fac;
+#endif
+               }
+       }
+
+#if 0
+       printf("chroma_energy = %f, with_pred = %f\n",
+               chroma_energy / (WIDTH * HEIGHT), chroma_energy_pred / (WIDTH * HEIGHT));
+#endif
+
        // DC coefficient pred from the right to left
        for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
                for (unsigned xb = 0; xb < WIDTH - 8; xb += 8) {
-                       coeff[yb * WIDTH + xb] -= coeff[yb * WIDTH + (xb + 8)];
+                       coeff_y[yb * WIDTH + xb] -= coeff_y[yb * WIDTH + (xb + 8)];
                }
        }
 
-       fp = fopen("reconstructed.pgm", "wb");
+       FILE *fp = fopen("reconstructed.pgm", "wb");
        fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
-       fwrite(pix, 1, WIDTH * HEIGHT, fp);
+       fwrite(pix_y, 1, WIDTH * HEIGHT, fp);
+       fclose(fp);
+
+       fp = fopen("reconstructed.pnm", "wb");
+       fprintf(fp, "P6\n%d %d\n255\n", WIDTH, HEIGHT);
+       for (unsigned yb = 0; yb < HEIGHT; ++yb) {
+               for (unsigned xb = 0; xb < WIDTH; ++xb) {
+                       int y = pix_y[(yb * WIDTH) + xb];
+                       int cb, cr;
+                       int c0 = yb * (WIDTH/2) + xb/2;
+                       if (xb % 2 == 0) {
+                               cb = pix_cb[c0] - 128.0;
+                               cr = pix_cr[c0] - 128.0;
+                       } else {
+                               int c1 = yb * (WIDTH/2) + std::min<int>(xb/2 + 1, WIDTH/2 - 1);
+                               cb = 0.5 * (pix_cb[c0] + pix_cb[c1]) - 128.0;
+                               cr = 0.5 * (pix_cr[c0] + pix_cr[c1]) - 128.0;
+                       }
+
+                       double r = y + 1.5748 * cr;
+                       double g = y - 0.1873 * cb - 0.4681 * cr;
+                       double b = y + 1.8556 * cb;
+
+                       putc(clamp(lrint(r)), fp);
+                       putc(clamp(lrint(g)), fp);
+                       putc(clamp(lrint(b)), fp);
+               }
+       }
        fclose(fp);
 
        // For each coefficient, make some tables.
@@ -363,32 +784,45 @@ int main(void)
        }
        for (unsigned y = 0; y < 8; ++y) {
                for (unsigned x = 0; x < 8; ++x) {
-                       SymbolStats &s = stats[pick_stats_for(x, y)];
+                       SymbolStats &s_luma = stats[pick_stats_for(x, y)];
+                       SymbolStats &s_chroma = stats[pick_stats_for(x, y) + 8];  // HACK
+                       //SymbolStats &s_chroma = stats[pick_stats_for(x, y)];
 
+                       // Luma
                        for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
                                for (unsigned xb = 0; xb < WIDTH; xb += 8) {
-                                       short k = abs(coeff[(yb + y) * WIDTH + (xb + x)]);
+                                       short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
                                        if (k >= ESCAPE_LIMIT) {
-                                               //printf("coeff (%d,%d) had value %d\n", y, x, k);
                                                k = ESCAPE_LIMIT;
                                                extra_bits += 12;  // escape this one
                                        }
-                                       //if (y != 0 || x != 0) ++sign_bits;
                                        if (k != 0) ++sign_bits;
-                                       ++s.freqs[k];
+                                       ++s_luma.freqs[k];
+                               }
+                       }
+                       // Chroma
+                       for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
+                               for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
+                                       short k_cb = abs(coeff_cb[(yb + y) * WIDTH/2 + (xb + x)]);
+                                       short k_cr = abs(coeff_cr[(yb + y) * WIDTH/2 + (xb + x)]);
+                                       if (k_cb >= ESCAPE_LIMIT) {
+                                               k_cb = ESCAPE_LIMIT;
+                                               extra_bits += 12;  // escape this one
+                                       }
+                                       if (k_cr >= ESCAPE_LIMIT) {
+                                               k_cr = ESCAPE_LIMIT;
+                                               extra_bits += 12;  // escape this one
+                                       }
+                                       if (k_cb != 0) ++sign_bits;
+                                       if (k_cr != 0) ++sign_bits;
+                                       ++s_chroma.freqs[k_cb];
+                                       ++s_chroma.freqs[k_cr];
                                }
                        }
                }
        }
        for (unsigned i = 0; i < 64; ++i) {
-#if 0
-               printf("coeff %i:", i);
-               for (unsigned j = 0; j <= ESCAPE_LIMIT; ++j) {
-                       printf(" %d", stats[i].freqs[j]);
-               }
-               printf("\n");
-#endif
-               stats[i].normalize_freqs(RansEncoder::prob_scale);
+               stats[i].normalize_freqs(prob_scale);
        }
 
        FILE *codedfp = fopen("coded.dat", "wb");
@@ -397,42 +831,35 @@ int main(void)
                exit(1);
        }
 
-       // TODO: varint or something on the freqs
+       // TODO: rather gamma-k or something
        for (unsigned i = 0; i < 64; ++i) {
                if (stats[i].cum_freqs[NUM_SYMS] == 0) {
                        continue;
                }
                printf("writing table %d\n", i);
-#if 0
-               for (unsigned j = 0; j <= NUM_SYMS; ++j) {
-                       uint16_t freq = stats[i].cum_freqs[j];
-                       fwrite(&freq, 1, sizeof(freq), codedfp);
-                       printf("%d: %d\n", j, stats[i].freqs[j]);
-               }
-#else
-               // TODO: rather gamma-k or something
                for (unsigned j = 0; j < NUM_SYMS; ++j) {
                        write_varint(stats[i].freqs[j], codedfp);
                }
-#endif
        }
 
        RansEncoder rans_encoder;
 
        size_t tot_bytes = 0;
+
+       // Luma
        for (unsigned y = 0; y < 8; ++y) {
                for (unsigned x = 0; x < 8; ++x) {
-                       SymbolStats &s1 = stats[pick_stats_for(x, y)];
-                       SymbolStats &s2 = stats[pick_stats_for(x, y) + 8];
+                       SymbolStats &s_luma = stats[pick_stats_for(x, y)];
+                       rans_encoder.init_prob(s_luma);
 
-                       rans_encoder.init_prob(s1, s2);
+                       // Luma
 
                        // need to reverse later
                        rans_encoder.clear();
                        size_t num_bytes = 0;
                        for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
                                for (unsigned xb = 0; xb < WIDTH; xb += 8) {
-                                       int k = coeff[(yb + y) * WIDTH + (xb + x)];
+                                       int k = coeff_y[(yb + y) * WIDTH + (xb + x)];
                                        //printf("encoding coeff %d xb,yb=%d,%d: %d\n", y*8+x, xb, yb, k);
                                        rans_encoder.encode_coeff(k);
                                }
@@ -444,9 +871,62 @@ int main(void)
                                num_bytes += rans_encoder.save_block(codedfp);
                        }
                        tot_bytes += num_bytes;
-                       printf("coeff %d: %ld bytes\n", y * 8 + x, num_bytes);
+                       printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
+               }
+       }
+
+       // Cb
+       for (unsigned y = 0; y < 8; ++y) {
+               for (unsigned x = 0; x < 8; ++x) {
+                       SymbolStats &s_chroma = stats[pick_stats_for(x, y) + 8];
+                       //SymbolStats &s_chroma = stats[pick_stats_for(x, y)];
+                       rans_encoder.init_prob(s_chroma);
+
+                       rans_encoder.clear();
+                       size_t num_bytes = 0;
+                       for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
+                               for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
+                                       int k = coeff_cb[(yb + y) * WIDTH/2 + (xb + x)];
+                                       rans_encoder.encode_coeff(k);
+                               }
+                               if (yb % 16 == 8) {
+                                       num_bytes += rans_encoder.save_block(codedfp);
+                               }
+                       }
+                       if (HEIGHT % 16 != 0) {
+                               num_bytes += rans_encoder.save_block(codedfp);
+                       }
+                       tot_bytes += num_bytes;
+                       printf("coeff %d Cb: %ld bytes\n", y * 8 + x, num_bytes);
                }
        }
+
+       // Cr
+       for (unsigned y = 0; y < 8; ++y) {
+               for (unsigned x = 0; x < 8; ++x) {
+                       SymbolStats &s_chroma = stats[pick_stats_for(x, y) + 8];
+                       //SymbolStats &s_chroma = stats[pick_stats_for(x, y)];
+                       rans_encoder.init_prob(s_chroma);
+
+                       rans_encoder.clear();
+                       size_t num_bytes = 0;
+                       for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
+                               for (unsigned xb = 0; xb < WIDTH/2; xb += 8) {
+                                       int k = coeff_cr[(yb + y) * WIDTH/2 + (xb + x)];
+                                       rans_encoder.encode_coeff(k);
+                               }
+                               if (yb % 16 == 8) {
+                                       num_bytes += rans_encoder.save_block(codedfp);
+                               }
+                       }
+                       if (HEIGHT % 16 != 0) {
+                               num_bytes += rans_encoder.save_block(codedfp);
+                       }
+                       tot_bytes += num_bytes;
+                       printf("coeff %d Cr: %ld bytes\n", y * 8 + x, num_bytes);
+               }
+       }
+
        printf("%ld bytes + %ld sign bits (%ld) + %ld escape bits (%ld) = %ld total bytes\n",
                tot_bytes - sign_bits / 8 - extra_bits / 8,
                sign_bits,