]> git.sesse.net Git - narabu/commitdiff
Add the beginnings of a GPU encoder.
authorSteinar H. Gunderson <sgunderson@bigfoot.com>
Sat, 7 Oct 2017 21:22:14 +0000 (23:22 +0200)
committerSteinar H. Gunderson <sgunderson@bigfoot.com>
Sat, 7 Oct 2017 21:22:14 +0000 (23:22 +0200)
It doesn't really work currently (too buggy), only does DCT
(not the rANS part), and only encodes luma.

Makefile
encoder.shader [new file with mode: 0644]
narabu-encoder.cpp [new file with mode: 0644]

index cad417ad06e73b5d832c5779509da9c3c4f4ab6e..b7e59105852dc799b15291c23acea82331d6ffde 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-all: narabu qdc qdd psnr
+all: narabu narabu-encoder qdc qdd psnr
 CFLAGS=-O2 -g -Wall -std=gnu++17 $(shell pkg-config --cflags movit)
 CXXFLAGS=$(CFLAGS)
 LDFLAGS=$(shell pkg-config --libs movit) -lepoxy -lSDL2
@@ -15,6 +15,9 @@ psnr: psnr.o
 narabu: narabu.o util.o
        $(CXX) $(LDFLAGS) -o $@ $^
 
+narabu-encoder: narabu-encoder.o util.o ryg_rans/renormalize.o
+       $(CXX) $(LDFLAGS) -o $@ $^
+
 psnr.o: psnr.cpp
        $(CXX) $(CXXFLAGS) -fpermissive -o $@ -c $^
 
diff --git a/encoder.shader b/encoder.shader
new file mode 100644 (file)
index 0000000..c49a3fe
--- /dev/null
@@ -0,0 +1,163 @@
+#version 440
+#extension GL_ARB_shader_clock : enable
+
+layout(local_size_x = 8) in;
+
+layout(r16ui) uniform restrict writeonly uimage2D dc_ac7_tex;
+layout(r16ui) uniform restrict writeonly uimage2D ac1_ac6_tex;
+layout(r16ui) uniform restrict writeonly uimage2D ac2_ac5_tex;
+layout(r8i) uniform restrict writeonly iimage2D ac3_tex;
+layout(r8i) uniform restrict writeonly iimage2D ac4_tex;
+layout(r8i) uniform restrict readonly iimage2D image_tex;
+
+shared float temp[64];
+
+const float W[64] = {  // Quantization weights for the 64 coefficients of an 8x8 block, eight per row. NOTE(review): values look like the MPEG-2 default intra matrix -- confirm.
+         8, 16, 19, 22, 26, 27, 29, 34,
+        16, 16, 22, 24, 27, 29, 34, 37,
+        19, 22, 26, 27, 29, 34, 34, 38,
+        22, 22, 26, 27, 29, 34, 37, 40,
+        22, 26, 27, 29, 32, 35, 40, 48,
+        26, 27, 29, 32, 35, 40, 48, 58,
+        26, 27, 29, 34, 38, 46, 56, 69,
+        27, 29, 35, 38, 46, 56, 69, 83
+};
+const float S = 4.0;  // whatever? (Global quantizer scale: multiplied into every weight below except entry 0.)
+
+// NOTE: Contains factors to counteract the scaling in the DCT implementation. Each entry is the multiplier applied to a DCT output before rounding.
+const float quant_matrix[64] = {  // Entry 0 is a fixed 1/64 (W[0] and S are not used for it); the rest of the first row and first column use numerator 1.0, all other entries 2.0.
+        1.0 / 64.0,         1.0 / (W[ 1] * S),  1.0 / (W[ 2] * S),  1.0 / (W[ 3] * S),  1.0 / (W[ 4] * S),  1.0 / (W[ 5] * S),  1.0 / (W[ 6] * S),  1.0 / (W[ 7] * S),
+        1.0 / (W[ 8] * S),  2.0 / (W[ 9] * S),  2.0 / (W[10] * S),  2.0 / (W[11] * S),  2.0 / (W[12] * S),  2.0 / (W[13] * S),  2.0 / (W[14] * S),  2.0 / (W[15] * S),
+        1.0 / (W[16] * S),  2.0 / (W[17] * S),  2.0 / (W[18] * S),  2.0 / (W[19] * S),  2.0 / (W[20] * S),  2.0 / (W[21] * S),  2.0 / (W[22] * S),  2.0 / (W[23] * S),
+        1.0 / (W[24] * S),  2.0 / (W[25] * S),  2.0 / (W[26] * S),  2.0 / (W[27] * S),  2.0 / (W[28] * S),  2.0 / (W[29] * S),  2.0 / (W[30] * S),  2.0 / (W[31] * S),
+        1.0 / (W[32] * S),  2.0 / (W[33] * S),  2.0 / (W[34] * S),  2.0 / (W[35] * S),  2.0 / (W[36] * S),  2.0 / (W[37] * S),  2.0 / (W[38] * S),  2.0 / (W[39] * S),
+        1.0 / (W[40] * S),  2.0 / (W[41] * S),  2.0 / (W[42] * S),  2.0 / (W[43] * S),  2.0 / (W[44] * S),  2.0 / (W[45] * S),  2.0 / (W[46] * S),  2.0 / (W[47] * S),
+        1.0 / (W[48] * S),  2.0 / (W[49] * S),  2.0 / (W[50] * S),  2.0 / (W[51] * S),  2.0 / (W[52] * S),  2.0 / (W[53] * S),  2.0 / (W[54] * S),  2.0 / (W[55] * S),
+        1.0 / (W[56] * S),  2.0 / (W[57] * S),  2.0 / (W[58] * S),  2.0 / (W[59] * S),  2.0 / (W[60] * S),  2.0 / (W[61] * S),  2.0 / (W[62] * S),  2.0 / (W[63] * S)
+};
+
+// Packs two signed values into a single 16-bit word. v9 is clamped to
+// [-256, 255] and kept as a 9-bit two's complement field in bits 0..8;
+// v7 is clamped to [-64, 63] and kept as a 7-bit two's complement field
+// in bits 9..15.
+uint pack_9_7(int v9, int v7)
+{
+       const uint low_field = uint(clamp(v9, -256, 255)) & 0x1ffu;
+       const uint high_field = uint(clamp(v7, -64, 63)) & 0x7fu;
+       return low_field | (high_field << 9);
+}
+
+// Scaled 1D forward DCT over eight samples, computed in place with an
+// AAN-style butterfly network. On return, yK holds (scaled) frequency K.
+// Original note: y0 output is scaled by 8, everything else is scaled by 16.
+// NOTE(review): AAN-type transforms usually leave a different scale factor on
+// each output; confirm that the uniform factor assumed by quant_matrix is
+// what is intended.
+void dct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
+{
+       const float a1 = 0.7071067811865474;   // cos(pi/4) = sqrt(2) / 2
+       const float a2 = 0.5411961001461971;   // cos(3/8 pi) * sqrt(2)
+       const float a4 = 1.3065629648763766;   // cos(pi/8) * sqrt(2)
+       // static const float a5 = 0.5 * (a4 - a2);
+       const float a5 = 0.3826834323650897;   // cos(3/8 pi)
+
+       // phase 1: fold the input around its midpoint (sums feed the even
+       // outputs, differences feed the odd outputs).
+       const float p1_0 = y0 + y7;
+       const float p1_1 = y1 + y6;
+       const float p1_2 = y2 + y5;
+       const float p1_3 = y3 + y4;
+       const float p1_4 = y3 - y4;
+       const float p1_5 = y2 - y5;
+       const float p1_6 = y1 - y6;
+       const float p1_7 = y0 - y7;
+
+       // phase 2
+       const float p2_0 = p1_0 + p1_3;
+       const float p2_1 = p1_1 + p1_2;
+       const float p2_2 = p1_1 - p1_2;
+       const float p2_3 = p1_0 - p1_3;
+       const float p2_4 = p1_4 + p1_5;  // Inverted.
+       const float p2_5 = p1_5 + p1_6;
+       const float p2_6 = p1_6 + p1_7;
+
+       // phase 3
+       const float p3_0 = p2_0 + p2_1;
+       const float p3_1 = p2_0 - p2_1;
+       const float p3_2 = p2_2 + p2_3;
+
+       // phase 4: the only multiplications in the transform.
+       const float p4_2 = p3_2 * a1;
+       const float p4_4 = p2_4 * a2 + (p2_4 - p2_6) * a5;
+       const float p4_5 = p2_5 * a1;
+       const float p4_6 = p2_6 * a4 + (p2_4 - p2_6) * a5;
+
+       // phase 5
+       const float p5_2 = p4_2 + p2_3;
+       // Output 6 is p2_3 - p4_2. The operands used to be swapped, which
+       // flipped the sign of that coefficient: its basis function weights
+       // (y0 + y7) - (y3 + y4) positively.
+       const float p5_3 = p2_3 - p4_2;
+       const float p5_5 = p4_5 + p1_7;
+       const float p5_7 = p1_7 - p4_5;
+
+       // phase 6: write the results back in natural frequency order.
+       y0 = p3_0;
+       y4 = p3_1;
+       y2 = p5_2;
+       y6 = p5_3;
+       y5 = p4_4 + p5_7;
+       y1 = p5_5 + p4_6;
+       y7 = p5_5 - p4_6;
+       y3 = p5_7 - p4_4;
+}
+void main()  // Shader entry point: the invocations of one workgroup cooperate on one 8x8 block (vertical DCT, transpose via shared memory, horizontal DCT, quantize, pack, store).
+{
+       uint x = 8 * gl_WorkGroupID.x;  // Left pixel column of this workgroup's block.
+       uint y = 8 * gl_WorkGroupID.y;  // Top pixel row of this workgroup's block.
+       uint n = gl_LocalInvocationID.x;  // This invocation's index: first a column of the block, later a row.
+
+       // Load column: this invocation reads column x + n, rows y .. y + 7. NOTE(review): image_tex is declared r8i, so samples arrive as signed integers -- confirm that matches the uploaded data.
+       float y0 = imageLoad(image_tex, ivec2(x + n, y + 0)).x;
+       float y1 = imageLoad(image_tex, ivec2(x + n, y + 1)).x;
+       float y2 = imageLoad(image_tex, ivec2(x + n, y + 2)).x;
+       float y3 = imageLoad(image_tex, ivec2(x + n, y + 3)).x;
+       float y4 = imageLoad(image_tex, ivec2(x + n, y + 4)).x;
+       float y5 = imageLoad(image_tex, ivec2(x + n, y + 5)).x;
+       float y6 = imageLoad(image_tex, ivec2(x + n, y + 6)).x;
+       float y7 = imageLoad(image_tex, ivec2(x + n, y + 7)).x;
+
+       // Vertical DCT (down the column just loaded).
+       dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
+
+       // Communicate with the other shaders in the group: element k of this column goes to temp[k * 8 + n], i.e. row k, column n.
+       temp[n + 0 * 8] = y0;
+       temp[n + 1 * 8] = y1;
+       temp[n + 2 * 8] = y2;
+       temp[n + 3 * 8] = y3;
+       temp[n + 4 * 8] = y4;
+       temp[n + 5 * 8] = y5;
+       temp[n + 6 * 8] = y6;
+       temp[n + 7 * 8] = y7;
+
+       memoryBarrierShared();  // The writes to temp above must be visible ...
+       barrier();  // ... and all invocations must have made them before anyone reads a row below.
+
+       // Load row (so transpose, in a sense): this invocation now owns row n of the intermediate block.
+       y0 = temp[n * 8 + 0];
+       y1 = temp[n * 8 + 1];
+       y2 = temp[n * 8 + 2];
+       y3 = temp[n * 8 + 3];
+       y4 = temp[n * 8 + 4];
+       y5 = temp[n * 8 + 5];
+       y6 = temp[n * 8 + 6];
+       y7 = temp[n * 8 + 7];
+
+       // Horizontal DCT (along the row).
+       dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
+
+       // Quantize: multiply by the matching quant_matrix entry (row n, column k) and round to an integer.
+       int c0 = int(round(y0 * quant_matrix[n * 8 + 0]));
+       int c1 = int(round(y1 * quant_matrix[n * 8 + 1]));
+       int c2 = int(round(y2 * quant_matrix[n * 8 + 2]));
+       int c3 = int(round(y3 * quant_matrix[n * 8 + 3]));
+       int c4 = int(round(y4 * quant_matrix[n * 8 + 4]));
+       int c5 = int(round(y5 * quant_matrix[n * 8 + 5]));
+       int c6 = int(round(y6 * quant_matrix[n * 8 + 6]));
+       int c7 = int(round(y7 * quant_matrix[n * 8 + 7]));
+
+       // Clamp, pack and store. One texel per block row: output x is the block column, output y is the pixel row. NOTE(review): only the pack_9_7() pairs are clamped here; c3 and c4 are stored as-is -- confirm out-of-range behavior.
+       uint sx = gl_WorkGroupID.x;
+       imageStore(dc_ac7_tex,  ivec2(sx, y + n), uvec4(pack_9_7(c0, c7), 0, 0, 0));
+       imageStore(ac1_ac6_tex, ivec2(sx, y + n), uvec4(pack_9_7(c1, c6), 0, 0, 0));
+       imageStore(ac2_ac5_tex, ivec2(sx, y + n), uvec4(pack_9_7(c2, c5), 0, 0, 0));
+       imageStore(ac3_tex,     ivec2(sx, y + n), ivec4(c3, 0, 0, 0));
+       imageStore(ac4_tex,     ivec2(sx, y + n), ivec4(c4, 0, 0, 0));
+}
+
diff --git a/narabu-encoder.cpp b/narabu-encoder.cpp
new file mode 100644 (file)
index 0000000..9e9f0ab
--- /dev/null
@@ -0,0 +1,569 @@
+#include <stdio.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <assert.h>
+#include <math.h>
+#include <SDL2/SDL.h>
+#include <SDL2/SDL_error.h>
+#include <SDL2/SDL_video.h>
+#include <epoxy/gl.h>
+
+#include <algorithm>
+#include <chrono>
+#include <memory>
+#include <numeric>
+#include <random>
+#include <vector>
+#include <unordered_map>
+
+#include <movit/util.h>
+
+#include "ryg_rans/rans_byte.h"
+#include "ryg_rans/renormalize.h"
+#include "util.h"
+
+#define WIDTH 1280
+#define HEIGHT 720
+#define WIDTH_BLOCKS (WIDTH/8)
+#define WIDTH_BLOCKS_CHROMA (WIDTH/16)
+#define HEIGHT_BLOCKS (HEIGHT/8)
+#define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
+#define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)
+
+#define NUM_SYMS 256
+#define ESCAPE_LIMIT (NUM_SYMS - 1)
+#define BLOCKS_PER_STREAM 320
+
+static constexpr uint32_t prob_bits = 12;
+static constexpr uint32_t prob_scale = 1 << prob_bits;
+
+unsigned char rgb[WIDTH * HEIGHT * 3];
+unsigned char pix_y[WIDTH * HEIGHT];
+unsigned char pix_cb[(WIDTH/2) * HEIGHT];
+unsigned char pix_cr[(WIDTH/2) * HEIGHT];
+
+using namespace std;
+using namespace std::chrono;
+
+// Writes x to fp as a variable-length integer: seven payload bits per byte,
+// least significant group first, with the top bit set on every byte except
+// the last. Callers are expected to pass x >= 0.
+void write_varint(int x, FILE *fp)
+{
+       for ( ; x >= 128; x >>= 7) {
+               putc(0x80 | (x & 0x7f), fp);
+       }
+       putc(x, fp);
+}
+
+void readpix(unsigned char *ptr, const char *filename)
+{
+       FILE *fp = fopen(filename, "rb");
+       if (fp == nullptr) {
+               perror(filename);
+               exit(1);
+       }
+
+       fseek(fp, 0, SEEK_END);
+       long len = ftell(fp);
+       assert(len >= WIDTH * HEIGHT * 3);
+       fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
+
+       fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
+       fclose(fp);
+}
+
+struct SymbolStats  // Frequency table for one alphabet of NUM_SYMS symbols.
+{
+    uint32_t freqs[NUM_SYMS];  // Per-symbol count; replaced by the rescaled frequency after normalize_freqs().
+    uint32_t cum_freqs[NUM_SYMS + 1];  // Running sums of freqs: cum_freqs[i] covers symbols below i, cum_freqs[NUM_SYMS] is the total.
+
+    void clear();  // Zeroes freqs; cum_freqs is left as it was.
+    void calc_cum_freqs();  // Rebuilds cum_freqs from freqs.
+    void normalize_freqs(uint32_t target_total);  // Rescales the table to sum to target_total and prints cost statistics.
+};
+
+// Resets every symbol count to zero. cum_freqs is not touched.
+void SymbolStats::clear()
+{
+    memset(freqs, 0, sizeof(freqs));
+}
+
+// Rebuilds cum_freqs as the running sum of freqs: cum_freqs[i] is the total
+// of all symbols below i, so cum_freqs[0] is 0 and cum_freqs[NUM_SYMS] is the
+// grand total.
+void SymbolStats::calc_cum_freqs()
+{
+    uint32_t running_total = 0;
+    for (int sym = 0; sym < NUM_SYMS; ++sym) {
+        cum_freqs[sym] = running_total;
+        running_total += freqs[sym];
+    }
+    cum_freqs[NUM_SYMS] = running_total;
+}
+
+void SymbolStats::normalize_freqs(uint32_t target_total)
+{
+    uint64_t real_freq[NUM_SYMS + 1];  // hack
+
+    assert(target_total >= NUM_SYMS);
+
+    calc_cum_freqs();
+    uint32_t cur_total = cum_freqs[NUM_SYMS];
+
+    if (cur_total == 0) return;
+
+    double ideal_cost = 0.0;
+    for (int i = 1; i <= NUM_SYMS; i++)
+    {
+      real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
+      if (real_freq[i] > 0)
+        ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
+    }
+
+    OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
+
+    // calculate updated freqs and make sure we didn't screw anything up
+    assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
+    for (int i=0; i < NUM_SYMS; i++) {
+        if (freqs[i] == 0)
+            assert(cum_freqs[i+1] == cum_freqs[i]);
+        else
+            assert(cum_freqs[i+1] > cum_freqs[i]);
+
+        // calc updated freq
+        freqs[i] = cum_freqs[i+1] - cum_freqs[i];
+    }
+
+    double calc_cost = 0.0;
+    for (int i = 1; i <= NUM_SYMS; i++)
+    {
+      uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
+      if (real_freq[i] > 0)
+        calc_cost -= real_freq[i] * log2(freq / double(target_total));
+    }
+
+    static double total_loss = 0.0;
+    total_loss += calc_cost - ideal_cost;
+    static double total_loss_with_dp = 0.0;
+       double optimal_cost = 0.0;
+    //total_loss_with_dp += optimal_cost - ideal_cost;
+    printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
+               ideal_cost, optimal_cost,
+                calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
+}
+
+SymbolStats stats[128];  // Global pool of frequency tables, indexed by the value pick_stats_for() returns.
+
+const int luma_mapping[64] = {  // For each position in an 8x8 coefficient block (eight per row, index y * 8 + x), which stats[] table to use; only values 0..3 occur.
+       0, 0, 1, 1, 2, 2, 3, 3,
+       0, 0, 1, 2, 2, 2, 3, 3,
+       1, 1, 2, 2, 2, 3, 3, 3,
+       1, 1, 2, 2, 2, 3, 3, 3,
+       1, 2, 2, 2, 2, 3, 3, 3,
+       2, 2, 2, 2, 3, 3, 3, 3,
+       2, 2, 3, 3, 3, 3, 3, 3,
+       3, 3, 3, 3, 3, 3, 3, 3,
+};
+
+// Returns the index of the statistics table to use for the coefficient at
+// column x, row y of an 8x8 block. Both arguments must be in 0..7.
+int pick_stats_for(int x, int y)
+{
+       const int coeff_idx = 8 * y + x;
+       return luma_mapping[coeff_idx];
+}
+
+class RansEncoder {
+public:
+       RansEncoder()
+       {
+               out_buf.reset(new uint8_t[out_max_size]);
+               clear();
+       }
+
+       void init_prob(SymbolStats &s)
+       {
+               for (int i = 0; i < NUM_SYMS; i++) {
+                       //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
+                       RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
+               }
+               sign_bias = s.cum_freqs[NUM_SYMS];
+       }
+
+       void clear()
+       {
+               out_end = out_buf.get() + out_max_size;
+               ptr = out_end; // *end* of output buffer
+               RansEncInit(&rans);
+       }
+
+       uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
+       {
+               RansEncFlush(&rans, &ptr);
+               //printf("post-flush = %08x\n", rans);
+
+               uint32_t num_rans_bytes = out_end - ptr;
+               if (num_rans_bytes == last_block.size() &&
+                   memcmp(last_block.data(), ptr, last_block.size()) == 0) {
+                       write_varint(0, codedfp);
+                       clear();
+                       return 1;
+               } else {
+                       last_block = string((const char *)ptr, num_rans_bytes);
+               }
+
+               write_varint(num_rans_bytes, codedfp);
+               //fwrite(&num_rans_bytes, 1, 4, codedfp);
+               fwrite(ptr, 1, num_rans_bytes, codedfp);
+
+               //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
+
+
+               clear();
+
+               //printf("Saving block: %d rANS bytes\n", num_rans_bytes);
+               return num_rans_bytes;
+               //return num_rans_bytes;
+       }
+
+       void encode_coeff(short signed_k)
+       {
+               //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
+               unsigned short k = abs(signed_k);
+               if (k >= ESCAPE_LIMIT) {
+                       // Put the coefficient as a 1/(2^12) symbol _before_
+                       // the 255 coefficient, since the decoder will read the
+                       // 255 coefficient first.
+                       RansEncPut(&rans, &ptr, k, 1, prob_bits);
+                       k = ESCAPE_LIMIT;
+               }
+               RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & (NUM_SYMS - 1)]);
+               if (signed_k < 0) {
+                       rans += sign_bias;
+               }
+       }
+
+private:
+       static constexpr size_t out_max_size = 32 << 20; // 32 MB.
+       static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
+
+       unique_ptr<uint8_t[]> out_buf;
+       uint8_t *out_end;
+       uint8_t *ptr;
+       RansState rans;
+       RansEncSymbol esyms[NUM_SYMS];
+       uint32_t sign_bias;
+
+       std::string last_block;
+};
+
+// Should be done on the GPU, of course, but irrelevant for the demonstration.
+void convert_ycbcr()
+{
+       double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
+       double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
+       double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
+
+       unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
+       unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
+       for (unsigned yb = 0; yb < HEIGHT; ++yb) {
+               for (unsigned xb = 0; xb < WIDTH; ++xb) {
+                       int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
+                       int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
+                       int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
+                       double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
+                       double cb = (b - y) * cb_fac + 128.0;
+                       double cr = (r - y) * cr_fac + 128.0;
+                       pix_y[(yb * WIDTH) + xb] = lrint(y);
+                       temp_cb[(yb * WIDTH) + xb] = cb;
+                       temp_cr[(yb * WIDTH) + xb] = cr;
+               }
+       }
+
+       // Simple 4:2:2 subsampling with left convention.
+       for (unsigned yb = 0; yb < HEIGHT; ++yb) {
+               for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
+                       int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
+                       int c1 = yb * WIDTH + xb * 2;
+                       int c2 = yb * WIDTH + xb * 2 + 1;
+                       
+                       double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
+                       double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
+                       cb = std::min(std::max(cb, 0.0), 255.0);
+                       cr = std::min(std::max(cr, 0.0), 255.0);
+                       pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
+                       pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
+               }
+       }
+}
+
+int main(int argc, char **argv)
+{
+       // Set up an OpenGL context using SDL.
+       if (SDL_Init(SDL_INIT_VIDEO) == -1) {
+               fprintf(stderr, "SDL_Init failed: %s\n", SDL_GetError());
+               exit(1);
+       }
+       SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 0);
+       SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 0);
+       SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
+       SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
+       SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 4);
+       SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 5);
+
+       SDL_Window *window = SDL_CreateWindow("OpenGL window for unit test",
+               SDL_WINDOWPOS_UNDEFINED,
+               SDL_WINDOWPOS_UNDEFINED,
+               32, 32,
+               SDL_WINDOW_OPENGL);
+       SDL_GLContext context = SDL_GL_CreateContext(window);
+       assert(context != nullptr);
+
+       if (argc >= 2)
+               readpix(rgb, argv[1]);
+       else
+               readpix(rgb, "color.pnm");
+       convert_ycbcr();
+
+       // Compile the shader.
+       string shader_src = ::read_file("encoder.shader");
+       GLuint shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
+       GLuint glsl_program_num = glCreateProgram();
+       glAttachShader(glsl_program_num, shader_num);
+       glLinkProgram(glsl_program_num);
+
+       GLint success;
+       glGetProgramiv(glsl_program_num, GL_LINK_STATUS, &success);
+       if (success == GL_FALSE) {
+               GLchar error_log[1024] = {0};
+               glGetProgramInfoLog(glsl_program_num, 1024, nullptr, error_log);
+               fprintf(stderr, "Error linking program: %s\n", error_log);
+               exit(1);
+       }
+
+       glUseProgram(glsl_program_num);
+
+       // Upload luma.
+       GLuint y_tex;
+       glGenTextures(1, &y_tex);
+        glBindTexture(GL_TEXTURE_2D, y_tex);
+        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
+        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
+        glTexImage2D(GL_TEXTURE_2D, 0, GL_R8I, WIDTH, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_BYTE, pix_y);
+       check_error();
+
+       // Make destination textures.
+       GLuint dc_ac7_tex, ac1_ac6_tex, ac2_ac5_tex;
+       for (GLuint *tex : { &dc_ac7_tex, &ac1_ac6_tex, &ac2_ac5_tex }) {
+               glGenTextures(1, tex);
+               glBindTexture(GL_TEXTURE_2D, *tex);
+               glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+               glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+               glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
+               glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
+               glTexImage2D(GL_TEXTURE_2D, 0, GL_R16UI, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, nullptr);
+               check_error();
+       }
+
+       GLuint ac3_tex, ac4_tex;
+       for (GLuint *tex : { &ac3_tex, &ac4_tex }) {
+               glGenTextures(1, tex);
+               glBindTexture(GL_TEXTURE_2D, *tex);
+               glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+               glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+               glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
+               glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
+               glTexImage2D(GL_TEXTURE_2D, 0, GL_R8I, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_BYTE, nullptr);
+               check_error();
+       }
+
+       GLint dc_ac7_tex_uniform = glGetUniformLocation(glsl_program_num, "dc_ac7_tex");
+       GLint ac1_ac6_tex_uniform = glGetUniformLocation(glsl_program_num, "ac1_ac6_tex");
+       GLint ac2_ac5_tex_uniform = glGetUniformLocation(glsl_program_num, "ac2_ac5_tex");
+       GLint ac3_tex_uniform = glGetUniformLocation(glsl_program_num, "ac3_tex");
+       GLint ac4_tex_uniform = glGetUniformLocation(glsl_program_num, "ac4_tex");
+       GLint image_tex_uniform = glGetUniformLocation(glsl_program_num, "image_tex");
+
+       glUniform1i(dc_ac7_tex_uniform, 0);
+       glUniform1i(ac1_ac6_tex_uniform, 1);
+       glUniform1i(ac2_ac5_tex_uniform, 2);
+       glUniform1i(ac3_tex_uniform, 3);
+       glUniform1i(ac4_tex_uniform, 4);
+       glUniform1i(image_tex_uniform, 5);
+       glBindImageTexture(0, dc_ac7_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
+       glBindImageTexture(1, ac1_ac6_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
+       glBindImageTexture(2, ac2_ac5_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
+       glBindImageTexture(3, ac3_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R8I);
+       glBindImageTexture(4, ac4_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R8I);
+       glBindImageTexture(5, y_tex, 0, GL_FALSE, 0, GL_READ_ONLY, GL_R8I);
+       check_error();
+
+       steady_clock::time_point start = steady_clock::now();
+       unsigned num_iterations = 1000;
+       for (unsigned i = 0; i < num_iterations; ++i) {
+               glDispatchCompute(WIDTH_BLOCKS, HEIGHT_BLOCKS, 1);
+       }
+       check_error();
+       glFinish();
+       steady_clock::time_point now = steady_clock::now();
+
+       // CPU part starts here -- will be GPU later.
+       // We only do luma for now.
+
+       int16_t *coeff_y = new int16_t[WIDTH * HEIGHT];
+
+       glBindTexture(GL_TEXTURE_2D, dc_ac7_tex);
+       uint16_t *dc_ac7_data = new uint16_t[(WIDTH/8) * HEIGHT];
+       glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, dc_ac7_data);
+       check_error();
+
+       glBindTexture(GL_TEXTURE_2D, ac1_ac6_tex);
+       uint16_t *ac1_ac6_data = new uint16_t[(WIDTH/8) * HEIGHT];
+       glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, ac1_ac6_data);
+       check_error();
+
+       glBindTexture(GL_TEXTURE_2D, ac2_ac5_tex);
+       uint16_t *ac2_ac5_data = new uint16_t[(WIDTH/8) * HEIGHT];
+       glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, ac2_ac5_data);
+       check_error();
+
+       glBindTexture(GL_TEXTURE_2D, ac3_tex);
+       int8_t *ac3_data = new int8_t[(WIDTH/8) * HEIGHT];
+       glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_BYTE, ac3_data);
+       check_error();
+
+       glBindTexture(GL_TEXTURE_2D, ac4_tex);
+       int8_t *ac4_data = new int8_t[(WIDTH/8) * HEIGHT];
+       glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_BYTE, ac4_data);
+       check_error();
+
+       for (unsigned y = 0; y < HEIGHT; ++y) {
+               for (unsigned xb = 0; xb < WIDTH/8; ++xb) {
+                       coeff_y[y * WIDTH + xb*8 + 0] = int(dc_ac7_data[y * (WIDTH/8) + xb] << 23) >> 23;
+                       coeff_y[y * WIDTH + xb*8 + 7] = int(dc_ac7_data[y * (WIDTH/8) + xb] << 16) >> 25;
+                       coeff_y[y * WIDTH + xb*8 + 1] = int(ac1_ac6_data[y * (WIDTH/8) + xb] << 23) >> 23;
+                       coeff_y[y * WIDTH + xb*8 + 6] = int(ac1_ac6_data[y * (WIDTH/8) + xb] << 16) >> 25;
+                       coeff_y[y * WIDTH + xb*8 + 2] = int(ac2_ac5_data[y * (WIDTH/8) + xb] << 23) >> 23;
+                       coeff_y[y * WIDTH + xb*8 + 5] = int(ac2_ac5_data[y * (WIDTH/8) + xb] << 16) >> 25;
+                       coeff_y[y * WIDTH + xb*8 + 3] = ac3_data[y * (WIDTH/8) + xb];
+                       coeff_y[y * WIDTH + xb*8 + 4] = ac4_data[y * (WIDTH/8) + xb];
+               }
+       }
+
+#if 1
+       for (unsigned y = 0; y < HEIGHT; ++y) {
+               for (unsigned xb = 0; xb < WIDTH/8; ++xb) {
+                       printf("%4d %4d %4d %4d %4d %4d %4d %4d | ",
+                               coeff_y[y * WIDTH + xb*8 + 0],
+                               coeff_y[y * WIDTH + xb*8 + 1],
+                               coeff_y[y * WIDTH + xb*8 + 2],
+                               coeff_y[y * WIDTH + xb*8 + 3],
+                               coeff_y[y * WIDTH + xb*8 + 4],
+                               coeff_y[y * WIDTH + xb*8 + 5],
+                               coeff_y[y * WIDTH + xb*8 + 6],
+                               coeff_y[y * WIDTH + xb*8 + 7]);
+               }
+               printf("\n");
+       }
+#endif
+
+       // DC coefficient pred from the right to left (within each slice)
+       for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; block_idx += BLOCKS_PER_STREAM) {
+               int prev_k = 128;
+
+               for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
+                       unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS;
+                       unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS;
+                       int k = coeff_y[(yb * 8) * WIDTH + xb * 8];
+
+                       coeff_y[(yb * 8) * WIDTH + xb * 8] = k - prev_k;
+
+                       prev_k = k;
+               }
+       }
+
+       // For each coefficient, make some tables.
+       size_t extra_bits = 0;
+       for (unsigned i = 0; i < 64; ++i) {
+               stats[i].clear();
+       }
+       for (unsigned y = 0; y < 8; ++y) {
+               for (unsigned x = 0; x < 8; ++x) {
+                       SymbolStats &s_luma = stats[pick_stats_for(x, y)];
+
+                       // Luma
+                       for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
+                               for (unsigned xb = 0; xb < WIDTH; xb += 8) {
+                                       unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
+                                       if (k >= ESCAPE_LIMIT) {
+                                               k = ESCAPE_LIMIT;
+                                               extra_bits += 12;  // escape this one
+                                       }
+                                       ++s_luma.freqs[(k - 1) & (NUM_SYMS - 1)];
+                               }
+                       }
+               }
+       }
+
+       for (unsigned i = 0; i < 64; ++i) {
+               stats[i].freqs[NUM_SYMS - 1] /= 2;  // zero, has no sign bits (yes, this is trickery)
+               stats[i].normalize_freqs(prob_scale);
+               stats[i].cum_freqs[NUM_SYMS] += stats[i].freqs[NUM_SYMS - 1];
+               stats[i].freqs[NUM_SYMS - 1] *= 2;
+       }
+
+       FILE *codedfp = fopen("coded.dat", "wb");
+       if (codedfp == nullptr) {
+               perror("coded.dat");
+               exit(1);
+       }
+
+       for (unsigned r = 0; r < 2; ++r) {  // Hack to write fake chroma tables.
+               // TODO: rather gamma-k or something
+               for (unsigned i = 0; i < 64; ++i) {
+                       if (stats[i].cum_freqs[NUM_SYMS] == 0) {
+                               continue;
+                       }
+                       printf("writing table %d\n", i);
+                       for (unsigned j = 0; j < NUM_SYMS; ++j) {
+                               write_varint(stats[i].freqs[j], codedfp);
+                       }
+               }
+       }
+
+       RansEncoder rans_encoder;
+
+       size_t tot_bytes = 0;
+
+       // Luma
+       for (unsigned y = 0; y < 8; ++y) {
+               for (unsigned x = 0; x < 8; ++x) {
+                       SymbolStats &s_luma = stats[pick_stats_for(x, y)];
+                       rans_encoder.init_prob(s_luma);
+
+                       // Luma
+                       std::vector<int> lens;
+
+                       rans_encoder.clear();
+                       size_t num_bytes = 0;
+                       for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; ++block_idx) {
+                               unsigned yb = block_idx / WIDTH_BLOCKS;
+                               unsigned xb = block_idx % WIDTH_BLOCKS;
+
+                               int k = coeff_y[(yb * 8 + y) * WIDTH + (xb * 8 + x)];
+                               rans_encoder.encode_coeff(k);
+
+                               if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
+                                       int l = rans_encoder.save_block(codedfp);
+                                       num_bytes += l;
+                                       lens.push_back(l);
+                               }
+                       }
+                       tot_bytes += num_bytes;
+                       printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
+               }
+       }
+
+       printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
+               tot_bytes - extra_bits / 8,
+               extra_bits,
+               extra_bits / 8,
+               tot_bytes);
+
+       printf("\n");
+       printf("Each iteration took %.3f ms (but note that is DCT only, no rANS).\n", 1e3 * duration<double>(now - start).count() / num_iterations);
+
+}