// Source: git.sesse.net Git - narabu/blob - narabu-encoder.cpp
// Commit: Speed up the histogram counting immensely by adding via local memory.
// Path: [narabu] / narabu-encoder.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7 #include <SDL2/SDL.h>
8 #include <SDL2/SDL_error.h>
9 #include <SDL2/SDL_video.h>
10 #include <epoxy/gl.h>
11
12 #include <algorithm>
13 #include <chrono>
14 #include <memory>
15 #include <numeric>
16 #include <random>
17 #include <vector>
18 #include <unordered_map>
19
20 #include <movit/util.h>
21
22 #include "ryg_rans/rans_byte.h"
23 #include "ryg_rans/renormalize.h"
24 #include "util.h"
25
26 #define WIDTH 1280
27 #define HEIGHT 720
28 #define WIDTH_BLOCKS (WIDTH/8)
29 #define WIDTH_BLOCKS_CHROMA (WIDTH/16)
30 #define HEIGHT_BLOCKS (HEIGHT/8)
31 #define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
32 #define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)
33
34 #define NUM_SYMS 256
35 #define ESCAPE_LIMIT (NUM_SYMS - 1)
36 #define BLOCKS_PER_STREAM 320
37
38 static constexpr uint32_t prob_bits = 12;
39 static constexpr uint32_t prob_scale = 1 << prob_bits;
40
41 unsigned char rgb[WIDTH * HEIGHT * 3];
42 unsigned char pix_y[WIDTH * HEIGHT];
43 unsigned char pix_cb[(WIDTH/2) * HEIGHT];
44 unsigned char pix_cr[(WIDTH/2) * HEIGHT];
45
46 using namespace std;
47 using namespace std::chrono;
48
// Writes x to fp as a little-endian base-128 varint: seven payload bits per
// byte, least significant group first, with the high bit set on every byte
// except the last one.
//
// The value is encoded through its unsigned 32-bit representation. Comparing
// a negative int against 128 would skip the loop entirely and emit a single
// truncated byte; going through uint32_t instead yields a well-formed
// (five-byte) varint, which matters because callers pass uint32_t values here.
void write_varint(int x, FILE *fp)
{
        uint32_t remaining = static_cast<uint32_t>(x);
        while (remaining >= 128) {
                putc((remaining & 0x7f) | 0x80, fp);
                remaining >>= 7;
        }
        putc(remaining, fp);
}
57
58 void readpix(unsigned char *ptr, const char *filename)
59 {
60         FILE *fp = fopen(filename, "rb");
61         if (fp == nullptr) {
62                 perror(filename);
63                 exit(1);
64         }
65
66         fseek(fp, 0, SEEK_END);
67         long len = ftell(fp);
68         assert(len >= WIDTH * HEIGHT * 3);
69         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
70
71         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
72         fclose(fp);
73 }
74
75 struct SymbolStats
76 {
77     uint32_t freqs[NUM_SYMS];
78     uint32_t cum_freqs[NUM_SYMS + 1];
79
80     void clear();
81     void calc_cum_freqs();
82     void normalize_freqs(uint32_t target_total);
83 };
84
85 void SymbolStats::clear()
86 {
87     for (int i=0; i < NUM_SYMS; i++)
88         freqs[i] = 0;
89 }
90
91 void SymbolStats::calc_cum_freqs()
92 {
93     cum_freqs[0] = 0;
94     for (int i=0; i < NUM_SYMS; i++)
95         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
96 }
97
98 void SymbolStats::normalize_freqs(uint32_t target_total)
99 {
100     uint64_t real_freq[NUM_SYMS + 1];  // hack
101
102     assert(target_total >= NUM_SYMS);
103
104     calc_cum_freqs();
105     uint32_t cur_total = cum_freqs[NUM_SYMS];
106
107     if (cur_total == 0) return;
108
109     double ideal_cost = 0.0;
110     for (int i = 1; i <= NUM_SYMS; i++)
111     {
112       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
113       if (real_freq[i] > 0)
114         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
115     }
116
117     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
118
119     // calculate updated freqs and make sure we didn't screw anything up
120     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
121     for (int i=0; i < NUM_SYMS; i++) {
122         if (freqs[i] == 0)
123             assert(cum_freqs[i+1] == cum_freqs[i]);
124         else
125             assert(cum_freqs[i+1] > cum_freqs[i]);
126
127         // calc updated freq
128         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
129     }
130
131     double calc_cost = 0.0;
132     for (int i = 1; i <= NUM_SYMS; i++)
133     {
134       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
135       if (real_freq[i] > 0)
136         calc_cost -= real_freq[i] * log2(freq / double(target_total));
137     }
138
139     static double total_loss = 0.0;
140     total_loss += calc_cost - ideal_cost;
141     static double total_loss_with_dp = 0.0;
142         double optimal_cost = 0.0;
143     //total_loss_with_dp += optimal_cost - ideal_cost;
144     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
145                 ideal_cost, optimal_cost,
146                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
147 }
148
// Global symbol-statistics tables. main() clears and normalizes entries
// 0..63, while pick_stats_for() only ever returns indices 0..3.
SymbolStats stats[128];
150
// Statistics-table index (0..3) for every position in an 8x8 coefficient
// block, stored row-major: the entry for column x, row y is at y * 8 + x.
const int luma_mapping[64] = {
        /* y=0 */ 0, 0, 1, 1, 2, 2, 3, 3,
        /* y=1 */ 0, 0, 1, 2, 2, 2, 3, 3,
        /* y=2 */ 1, 1, 2, 2, 2, 3, 3, 3,
        /* y=3 */ 1, 1, 2, 2, 2, 3, 3, 3,
        /* y=4 */ 1, 2, 2, 2, 2, 3, 3, 3,
        /* y=5 */ 2, 2, 2, 2, 3, 3, 3, 3,
        /* y=6 */ 2, 2, 3, 3, 3, 3, 3, 3,
        /* y=7 */ 3, 3, 3, 3, 3, 3, 3, 3,
};

// Returns the statistics-table index for the coefficient at column x, row y
// of an 8x8 block. Both x and y must be in 0..7; no range check is done.
int pick_stats_for(int x, int y)
{
        const int coeff_idx = y * 8 + x;
        return luma_mapping[coeff_idx];
}
166
167 class RansEncoder {
168 public:
169         RansEncoder()
170         {
171                 out_buf.reset(new uint8_t[out_max_size]);
172                 clear();
173         }
174
175         void init_prob(SymbolStats &s)
176         {
177                 for (int i = 0; i < NUM_SYMS; i++) {
178                         //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
179                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
180                 }
181                 sign_bias = s.cum_freqs[NUM_SYMS];
182         }
183
184         void clear()
185         {
186                 out_end = out_buf.get() + out_max_size;
187                 ptr = out_end; // *end* of output buffer
188                 RansEncInit(&rans);
189         }
190
191         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
192         {
193                 RansEncFlush(&rans, &ptr);
194                 //printf("post-flush = %08x\n", rans);
195
196                 uint32_t num_rans_bytes = out_end - ptr;
197                 if (num_rans_bytes == last_block.size() &&
198                     memcmp(last_block.data(), ptr, last_block.size()) == 0) {
199                         write_varint(0, codedfp);
200                         clear();
201                         return 1;
202                 } else {
203                         last_block = string((const char *)ptr, num_rans_bytes);
204                 }
205
206                 write_varint(num_rans_bytes, codedfp);
207                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
208                 fwrite(ptr, 1, num_rans_bytes, codedfp);
209
210                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
211
212
213                 clear();
214
215                 //printf("Saving block: %d rANS bytes\n", num_rans_bytes);
216                 return num_rans_bytes;
217                 //return num_rans_bytes;
218         }
219
220         void encode_coeff(short signed_k)
221         {
222                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
223                 unsigned short k = abs(signed_k);
224                 if (k >= ESCAPE_LIMIT) {
225                         // Put the coefficient as a 1/(2^12) symbol _before_
226                         // the 255 coefficient, since the decoder will read the
227                         // 255 coefficient first.
228                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
229                         k = ESCAPE_LIMIT;
230                 }
231                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & (NUM_SYMS - 1)]);
232                 if (signed_k < 0) {
233                         rans += sign_bias;
234                 }
235         }
236
237 private:
238         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
239         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
240
241         unique_ptr<uint8_t[]> out_buf;
242         uint8_t *out_end;
243         uint8_t *ptr;
244         RansState rans;
245         RansEncSymbol esyms[NUM_SYMS];
246         uint32_t sign_bias;
247
248         std::string last_block;
249 };
250
251 // Should be done on the GPU, of course, but irrelevant for the demonstration.
252 void convert_ycbcr()
253 {
254         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
255         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
256         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
257
258         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
259         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
260         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
261                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
262                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
263                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
264                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
265                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
266                         double cb = (b - y) * cb_fac + 128.0;
267                         double cr = (r - y) * cr_fac + 128.0;
268                         pix_y[(yb * WIDTH) + xb] = lrint(y);
269                         temp_cb[(yb * WIDTH) + xb] = cb;
270                         temp_cr[(yb * WIDTH) + xb] = cr;
271                 }
272         }
273
274         // Simple 4:2:2 subsampling with left convention.
275         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
276                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
277                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
278                         int c1 = yb * WIDTH + xb * 2;
279                         int c2 = yb * WIDTH + xb * 2 + 1;
280                         
281                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
282                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
283                         cb = std::min(std::max(cb, 0.0), 255.0);
284                         cr = std::min(std::max(cr, 0.0), 255.0);
285                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
286                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
287                 }
288         }
289 }
290
291 int main(int argc, char **argv)
292 {
293         // Set up an OpenGL context using SDL.
294         if (SDL_Init(SDL_INIT_VIDEO) == -1) {
295                 fprintf(stderr, "SDL_Init failed: %s\n", SDL_GetError());
296                 exit(1);
297         }
298         SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 0);
299         SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 0);
300         SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
301         SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
302         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 4);
303         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 5);
304
305         SDL_Window *window = SDL_CreateWindow("OpenGL window for unit test",
306                 SDL_WINDOWPOS_UNDEFINED,
307                 SDL_WINDOWPOS_UNDEFINED,
308                 32, 32,
309                 SDL_WINDOW_OPENGL);
310         SDL_GLContext context = SDL_GL_CreateContext(window);
311         assert(context != nullptr);
312
313         if (argc >= 2)
314                 readpix(rgb, argv[1]);
315         else
316                 readpix(rgb, "color.pnm");
317         convert_ycbcr();
318
319         // Compile the shader.
320         string shader_src = ::read_file("encoder.shader");
321         GLuint shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
322         GLuint glsl_program_num = glCreateProgram();
323         glAttachShader(glsl_program_num, shader_num);
324         glLinkProgram(glsl_program_num);
325
326         GLint success;
327         glGetProgramiv(glsl_program_num, GL_LINK_STATUS, &success);
328         if (success == GL_FALSE) {
329                 GLchar error_log[1024] = {0};
330                 glGetProgramInfoLog(glsl_program_num, 1024, nullptr, error_log);
331                 fprintf(stderr, "Error linking program: %s\n", error_log);
332                 exit(1);
333         }
334
335         glUseProgram(glsl_program_num);
336
337         // An SSBO for the rANS distributions.
338         GLuint ssbo;
339         glGenBuffers(1, &ssbo);
340         glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
341         glBufferData(GL_SHADER_STORAGE_BUFFER, 65536 * 4 * sizeof(uint32_t), nullptr, GL_DYNAMIC_COPY);
342         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 9, ssbo);
343
344         // Upload luma.
345         GLuint y_tex;
346         glGenTextures(1, &y_tex);
347         glBindTexture(GL_TEXTURE_2D, y_tex);
348         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
349         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
350         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
351         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
352         glTexImage2D(GL_TEXTURE_2D, 0, GL_R8UI, WIDTH, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_BYTE, pix_y);
353         check_error();
354
355         // Make destination textures.
356         GLuint dc_ac7_tex, ac1_ac6_tex, ac2_ac5_tex;
357         for (GLuint *tex : { &dc_ac7_tex, &ac1_ac6_tex, &ac2_ac5_tex }) {
358                 glGenTextures(1, tex);
359                 glBindTexture(GL_TEXTURE_2D, *tex);
360                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
361                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
362                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
363                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
364                 glTexImage2D(GL_TEXTURE_2D, 0, GL_R16UI, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, nullptr);
365                 check_error();
366         }
367
368         GLuint ac3_tex, ac4_tex;
369         for (GLuint *tex : { &ac3_tex, &ac4_tex }) {
370                 glGenTextures(1, tex);
371                 glBindTexture(GL_TEXTURE_2D, *tex);
372                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
373                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
374                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
375                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
376                 glTexImage2D(GL_TEXTURE_2D, 0, GL_R8I, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_BYTE, nullptr);
377                 check_error();
378         }
379
380         GLint dc_ac7_tex_uniform = glGetUniformLocation(glsl_program_num, "dc_ac7_tex");
381         GLint ac1_ac6_tex_uniform = glGetUniformLocation(glsl_program_num, "ac1_ac6_tex");
382         GLint ac2_ac5_tex_uniform = glGetUniformLocation(glsl_program_num, "ac2_ac5_tex");
383         GLint ac3_tex_uniform = glGetUniformLocation(glsl_program_num, "ac3_tex");
384         GLint ac4_tex_uniform = glGetUniformLocation(glsl_program_num, "ac4_tex");
385         GLint image_tex_uniform = glGetUniformLocation(glsl_program_num, "image_tex");
386
387         glUniform1i(dc_ac7_tex_uniform, 0);
388         glUniform1i(ac1_ac6_tex_uniform, 1);
389         glUniform1i(ac2_ac5_tex_uniform, 2);
390         glUniform1i(ac3_tex_uniform, 3);
391         glUniform1i(ac4_tex_uniform, 4);
392         glUniform1i(image_tex_uniform, 5);
393         glBindImageTexture(0, dc_ac7_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
394         glBindImageTexture(1, ac1_ac6_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
395         glBindImageTexture(2, ac2_ac5_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
396         glBindImageTexture(3, ac3_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R8I);
397         glBindImageTexture(4, ac4_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R8I);
398         glBindImageTexture(5, y_tex, 0, GL_FALSE, 0, GL_READ_ONLY, GL_R8UI);
399         check_error();
400
401         steady_clock::time_point start = steady_clock::now();
402         unsigned num_iterations = 100;
403         for (unsigned i = 0; i < num_iterations; ++i) {
404                 glDispatchCompute(WIDTH_BLOCKS / 16, HEIGHT_BLOCKS, 1);
405         }
406         check_error();
407         glFinish();
408         steady_clock::time_point now = steady_clock::now();
409
410         // CPU part starts here -- will be GPU later.
411         // We only do luma for now.
412
413         int16_t *coeff_y = new int16_t[WIDTH * HEIGHT];
414
415         glBindTexture(GL_TEXTURE_2D, dc_ac7_tex);
416         uint16_t *dc_ac7_data = new uint16_t[(WIDTH/8) * HEIGHT];
417         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, dc_ac7_data);
418         check_error();
419
420         glBindTexture(GL_TEXTURE_2D, ac1_ac6_tex);
421         uint16_t *ac1_ac6_data = new uint16_t[(WIDTH/8) * HEIGHT];
422         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, ac1_ac6_data);
423         check_error();
424
425         glBindTexture(GL_TEXTURE_2D, ac2_ac5_tex);
426         uint16_t *ac2_ac5_data = new uint16_t[(WIDTH/8) * HEIGHT];
427         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, ac2_ac5_data);
428         check_error();
429
430         glBindTexture(GL_TEXTURE_2D, ac3_tex);
431         int8_t *ac3_data = new int8_t[(WIDTH/8) * HEIGHT];
432         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_BYTE, ac3_data);
433         check_error();
434
435         glBindTexture(GL_TEXTURE_2D, ac4_tex);
436         int8_t *ac4_data = new int8_t[(WIDTH/8) * HEIGHT];
437         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_BYTE, ac4_data);
438         check_error();
439
440         for (unsigned y = 0; y < HEIGHT; ++y) {
441                 for (unsigned xb = 0; xb < WIDTH/8; ++xb) {
442                         coeff_y[y * WIDTH + xb*8 + 0] = int(dc_ac7_data[y * (WIDTH/8) + xb] << 23) >> 23;
443                         coeff_y[y * WIDTH + xb*8 + 7] = int(dc_ac7_data[y * (WIDTH/8) + xb] << 16) >> 25;
444                         coeff_y[y * WIDTH + xb*8 + 1] = int(ac1_ac6_data[y * (WIDTH/8) + xb] << 23) >> 23;
445                         coeff_y[y * WIDTH + xb*8 + 6] = int(ac1_ac6_data[y * (WIDTH/8) + xb] << 16) >> 25;
446                         coeff_y[y * WIDTH + xb*8 + 2] = int(ac2_ac5_data[y * (WIDTH/8) + xb] << 23) >> 23;
447                         coeff_y[y * WIDTH + xb*8 + 5] = int(ac2_ac5_data[y * (WIDTH/8) + xb] << 16) >> 25;
448                         coeff_y[y * WIDTH + xb*8 + 3] = ac3_data[y * (WIDTH/8) + xb];
449                         coeff_y[y * WIDTH + xb*8 + 4] = ac4_data[y * (WIDTH/8) + xb];
450                 }
451         }
452
453 #if 1
454         for (unsigned y = 0; y < HEIGHT; ++y) {
455                 for (unsigned xb = 0; xb < WIDTH/8; ++xb) {
456                         printf("%4d %4d %4d %4d %4d %4d %4d %4d | ",
457                                 coeff_y[y * WIDTH + xb*8 + 0],
458                                 coeff_y[y * WIDTH + xb*8 + 1],
459                                 coeff_y[y * WIDTH + xb*8 + 2],
460                                 coeff_y[y * WIDTH + xb*8 + 3],
461                                 coeff_y[y * WIDTH + xb*8 + 4],
462                                 coeff_y[y * WIDTH + xb*8 + 5],
463                                 coeff_y[y * WIDTH + xb*8 + 6],
464                                 coeff_y[y * WIDTH + xb*8 + 7]);
465                         printf("%4d %4d %4d %4d %4d %4d %4d %4d || ",
466                                 pix_y[y * WIDTH + xb*8 + 0],
467                                 pix_y[y * WIDTH + xb*8 + 1],
468                                 pix_y[y * WIDTH + xb*8 + 2],
469                                 pix_y[y * WIDTH + xb*8 + 3],
470                                 pix_y[y * WIDTH + xb*8 + 4],
471                                 pix_y[y * WIDTH + xb*8 + 5],
472                                 pix_y[y * WIDTH + xb*8 + 6],
473                                 pix_y[y * WIDTH + xb*8 + 7]);
474                 }
475                 printf("\n");
476         }
477 #endif
478
479         // DC coefficient pred from the right to left (within each slice)
480         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; block_idx += BLOCKS_PER_STREAM) {
481                 int prev_k = 128;
482
483                 for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
484                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS;
485                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS;
486                         int k = coeff_y[(yb * 8) * WIDTH + xb * 8];
487
488                         coeff_y[(yb * 8) * WIDTH + xb * 8] = k - prev_k;
489
490                         prev_k = k;
491                 }
492         }
493
494         // For each coefficient, make some tables.
495         size_t extra_bits = 0;
496         for (unsigned i = 0; i < 64; ++i) {
497                 stats[i].clear();
498         }
499         for (unsigned y = 0; y < 8; ++y) {
500                 for (unsigned x = 0; x < 8; ++x) {
501                         SymbolStats &s_luma = stats[pick_stats_for(x, y)];
502
503                         // Luma
504                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
505                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
506                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
507                                         if (k >= ESCAPE_LIMIT) {
508                                                 k = ESCAPE_LIMIT;
509                                                 extra_bits += 12;  // escape this one
510                                         }
511                                         ++s_luma.freqs[(k - 1) & (NUM_SYMS - 1)];
512                                 }
513                         }
514                 }
515         }
516
517         for (unsigned i = 0; i < 64; ++i) {
518                 stats[i].freqs[NUM_SYMS - 1] /= 2;  // zero, has no sign bits (yes, this is trickery)
519                 stats[i].normalize_freqs(prob_scale);
520                 stats[i].cum_freqs[NUM_SYMS] += stats[i].freqs[NUM_SYMS - 1];
521                 stats[i].freqs[NUM_SYMS - 1] *= 2;
522         }
523
524         FILE *codedfp = fopen("coded.dat", "wb");
525         if (codedfp == nullptr) {
526                 perror("coded.dat");
527                 exit(1);
528         }
529
530         for (unsigned r = 0; r < 2; ++r) {  // Hack to write fake chroma tables.
531                 // TODO: rather gamma-k or something
532                 for (unsigned i = 0; i < 64; ++i) {
533                         if (stats[i].cum_freqs[NUM_SYMS] == 0) {
534                                 continue;
535                         }
536                         printf("writing table %d\n", i);
537                         for (unsigned j = 0; j < NUM_SYMS; ++j) {
538                                 write_varint(stats[i].freqs[j], codedfp);
539                         }
540                 }
541         }
542
543         RansEncoder rans_encoder;
544
545         size_t tot_bytes = 0;
546
547         // Luma
548         for (unsigned y = 0; y < 8; ++y) {
549                 for (unsigned x = 0; x < 8; ++x) {
550                         SymbolStats &s_luma = stats[pick_stats_for(x, y)];
551                         rans_encoder.init_prob(s_luma);
552
553                         // Luma
554                         std::vector<int> lens;
555
556                         rans_encoder.clear();
557                         size_t num_bytes = 0;
558                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; ++block_idx) {
559                                 unsigned yb = block_idx / WIDTH_BLOCKS;
560                                 unsigned xb = block_idx % WIDTH_BLOCKS;
561
562                                 int k = coeff_y[(yb * 8 + y) * WIDTH + (xb * 8 + x)];
563                                 rans_encoder.encode_coeff(k);
564
565                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
566                                         int l = rans_encoder.save_block(codedfp);
567                                         num_bytes += l;
568                                         lens.push_back(l);
569                                 }
570                         }
571                         tot_bytes += num_bytes;
572                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
573                 }
574         }
575
576         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
577                 tot_bytes - extra_bits / 8,
578                 extra_bits,
579                 extra_bits / 8,
580                 tot_bytes);
581
582         printf("\n");
583         printf("Each iteration took %.3f ms (but note that is DCT only, no rANS).\n", 1e3 * duration<double>(now - start).count() / num_iterations);
584
585 #if 1
586         glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
587         const uint32_t *dist = (const uint32_t *)glMapBuffer(GL_SHADER_STORAGE_BUFFER, GL_READ_ONLY);
588         for (int i = 0; i < 1024; ++i) {
589                 printf("%d,%d: %u\n", i / 256, i % 256, dist[i] / num_iterations);
590         }
591 #endif
592 }