]> git.sesse.net Git - narabu/blob - narabu-encoder.cpp
Silence some Mesa warnings.
[narabu] / narabu-encoder.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7 #include <SDL2/SDL.h>
8 #include <SDL2/SDL_error.h>
9 #include <SDL2/SDL_video.h>
10 #include <epoxy/gl.h>
11
12 #include <algorithm>
13 #include <chrono>
14 #include <memory>
15 #include <numeric>
16 #include <random>
17 #include <vector>
18 #include <unordered_map>
19
20 #include <movit/util.h>
21
22 #include "ryg_rans/rans_byte.h"
23 #include "ryg_rans/renormalize.h"
24 #include "util.h"
25
// Frame geometry. The encoder is hard-wired to a single 1280x720 frame,
// split into blocks that are 8 pixels wide and 8 rows tall (chroma blocks
// cover 16 luma pixels horizontally).
//
// These are typed constants rather than object-like macros so that they obey
// scoping rules and show up in the debugger; names and values are unchanged.
constexpr int WIDTH = 1280;
constexpr int HEIGHT = 720;
constexpr int WIDTH_BLOCKS = WIDTH / 8;
constexpr int WIDTH_BLOCKS_CHROMA = WIDTH / 16;
constexpr int HEIGHT_BLOCKS = HEIGHT / 8;
constexpr int NUM_BLOCKS = WIDTH_BLOCKS * HEIGHT_BLOCKS;
constexpr int NUM_BLOCKS_CHROMA = WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS;

// Entropy coder parameters.
constexpr int NUM_SYMS = 256;                // size of the rANS alphabet
constexpr int ESCAPE_LIMIT = NUM_SYMS - 1;   // magnitudes >= this are escaped
constexpr int BLOCKS_PER_STREAM = 320;       // blocks coded before the rANS state is flushed

// Probabilities are quantized to prob_bits bits, i.e. they sum to prob_scale.
static constexpr uint32_t prob_bits = 12;
static constexpr uint32_t prob_scale = 1 << prob_bits;

// Input frame (interleaved 8-bit RGB) and the planes derived from it by
// convert_ycbcr(). Chroma is stored at half horizontal resolution.
unsigned char rgb[WIDTH * HEIGHT * 3];
unsigned char pix_y[WIDTH * HEIGHT];
unsigned char pix_cb[(WIDTH/2) * HEIGHT];
unsigned char pix_cr[(WIDTH/2) * HEIGHT];
45
46 using namespace std;
47 using namespace std::chrono;
48
// Writes x to fp as a base-128 varint: seven payload bits per byte, least
// significant group first, with the high bit set on every byte except the
// last.
//
// The value is encoded from its unsigned 32-bit pattern. Callers in this file
// pass uint32_t quantities through the int parameter; treating the value as
// signed would emit a single truncated byte for anything with the top bit set.
void write_varint(int x, FILE *fp)
{
        uint32_t remaining = static_cast<uint32_t>(x);
        while (remaining >= 0x80) {
                putc(static_cast<int>((remaining & 0x7f) | 0x80), fp);
                remaining >>= 7;
        }
        putc(static_cast<int>(remaining), fp);
}
57
58 void readpix(unsigned char *ptr, const char *filename)
59 {
60         FILE *fp = fopen(filename, "rb");
61         if (fp == nullptr) {
62                 perror(filename);
63                 exit(1);
64         }
65
66         fseek(fp, 0, SEEK_END);
67         long len = ftell(fp);
68         assert(len >= WIDTH * HEIGHT * 3);
69         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
70
71         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
72         fclose(fp);
73 }
74
75 struct SymbolStats
76 {
77     uint32_t freqs[NUM_SYMS];
78     uint32_t cum_freqs[NUM_SYMS + 1];
79
80     void clear();
81     void calc_cum_freqs();
82     void normalize_freqs(uint32_t target_total);
83 };
84
85 void SymbolStats::clear()
86 {
87     for (int i=0; i < NUM_SYMS; i++)
88         freqs[i] = 0;
89 }
90
91 void SymbolStats::calc_cum_freqs()
92 {
93     cum_freqs[0] = 0;
94     for (int i=0; i < NUM_SYMS; i++)
95         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
96 }
97
98 void SymbolStats::normalize_freqs(uint32_t target_total)
99 {
100     uint64_t real_freq[NUM_SYMS + 1];  // hack
101
102     assert(target_total >= NUM_SYMS);
103
104     calc_cum_freqs();
105     uint32_t cur_total = cum_freqs[NUM_SYMS];
106
107     if (cur_total == 0) return;
108
109     double ideal_cost = 0.0;
110     for (int i = 1; i <= NUM_SYMS; i++)
111     {
112       real_freq[i] = cum_freqs[i] - cum_freqs[i - 1];
113       if (real_freq[i] > 0)
114         ideal_cost -= real_freq[i] * log2(real_freq[i] / double(cur_total));
115     }
116
117     OptimalRenormalize(cum_freqs, NUM_SYMS, prob_scale);
118
119     // calculate updated freqs and make sure we didn't screw anything up
120     assert(cum_freqs[0] == 0 && cum_freqs[NUM_SYMS] == target_total);
121     for (int i=0; i < NUM_SYMS; i++) {
122         if (freqs[i] == 0)
123             assert(cum_freqs[i+1] == cum_freqs[i]);
124         else
125             assert(cum_freqs[i+1] > cum_freqs[i]);
126
127         // calc updated freq
128         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
129     }
130
131     double calc_cost = 0.0;
132     for (int i = 1; i <= NUM_SYMS; i++)
133     {
134       uint64_t freq = cum_freqs[i] - cum_freqs[i - 1];
135       if (real_freq[i] > 0)
136         calc_cost -= real_freq[i] * log2(freq / double(target_total));
137     }
138
139     static double total_loss = 0.0;
140     total_loss += calc_cost - ideal_cost;
141     static double total_loss_with_dp = 0.0;
142         double optimal_cost = 0.0;
143     //total_loss_with_dp += optimal_cost - ideal_cost;
144     printf("ideal cost = %.0f bits, DP cost = %.0f bits, calc cost = %.0f bits (loss = %.2f bytes, total loss = %.2f bytes, total loss with DP = %.2f bytes)\n",
145                 ideal_cost, optimal_cost,
146                  calc_cost, (calc_cost - ideal_cost) / 8.0, total_loss / 8.0, total_loss_with_dp / 8.0);
147 }
148
// Global pool of symbol statistics tables. main() clears and normalizes the
// first 64 entries; the table for a coefficient is picked with
// pick_stats_for(), which only returns indices 0..3.
SymbolStats stats[128];
150
// Maps each position in an 8x8 coefficient block (row-major, index y * 8 + x)
// to one of four statistics tables. Positions near the top-left corner use
// table 0; the table number grows towards the bottom-right corner.
const int luma_mapping[64] = {
        0, 0, 1, 1, 2, 2, 3, 3,
        0, 0, 1, 2, 2, 2, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 1, 2, 2, 2, 3, 3, 3,
        1, 2, 2, 2, 2, 3, 3, 3,
        2, 2, 2, 2, 3, 3, 3, 3,
        2, 2, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
};

// Returns the statistics table index (0..3) for the coefficient at column x,
// row y of an 8x8 block. Both coordinates must be in [0, 8); anything else
// would read outside luma_mapping, so it is rejected in debug builds.
int pick_stats_for(int x, int y)
{
        assert(x >= 0 && x < 8);
        assert(y >= 0 && y < 8);
        return luma_mapping[y * 8 + x];
}
166
167 class RansEncoder {
168 public:
169         RansEncoder()
170         {
171                 out_buf.reset(new uint8_t[out_max_size]);
172                 clear();
173         }
174
175         void init_prob(SymbolStats &s)
176         {
177                 for (int i = 0; i < NUM_SYMS; i++) {
178                         //printf("%d: cumfreqs=%d freqs=%d prob_bits=%d\n", i, s.cum_freqs[i], s.freqs[i], prob_bits + 1);
179                         RansEncSymbolInit(&esyms[i], s.cum_freqs[i], s.freqs[i], prob_bits + 1);
180                 }
181                 sign_bias = s.cum_freqs[NUM_SYMS];
182         }
183
184         void clear()
185         {
186                 out_end = out_buf.get() + out_max_size;
187                 ptr = out_end; // *end* of output buffer
188                 RansEncInit(&rans);
189         }
190
191         uint32_t save_block(FILE *codedfp)  // Returns number of bytes.
192         {
193                 RansEncFlush(&rans, &ptr);
194                 //printf("post-flush = %08x\n", rans);
195
196                 uint32_t num_rans_bytes = out_end - ptr;
197                 if (num_rans_bytes == last_block.size() &&
198                     memcmp(last_block.data(), ptr, last_block.size()) == 0) {
199                         write_varint(0, codedfp);
200                         clear();
201                         return 1;
202                 } else {
203                         last_block = string((const char *)ptr, num_rans_bytes);
204                 }
205
206                 write_varint(num_rans_bytes, codedfp);
207                 //fwrite(&num_rans_bytes, 1, 4, codedfp);
208                 fwrite(ptr, 1, num_rans_bytes, codedfp);
209
210                 //printf("first rANS bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", ptr[0], ptr[1], ptr[2], ptr[3], ptr[4], ptr[5], ptr[6], ptr[7]);
211
212
213                 clear();
214
215                 //printf("Saving block: %d rANS bytes\n", num_rans_bytes);
216                 return num_rans_bytes;
217                 //return num_rans_bytes;
218         }
219
220         void encode_coeff(short signed_k)
221         {
222                 //printf("encoding coeff %d (sym %d), rans before encoding = %08x\n", signed_k, ((abs(signed_k) - 1) & 255), rans);
223                 unsigned short k = abs(signed_k);
224                 if (k >= ESCAPE_LIMIT) {
225                         // Put the coefficient as a 1/(2^12) symbol _before_
226                         // the 255 coefficient, since the decoder will read the
227                         // 255 coefficient first.
228                         RansEncPut(&rans, &ptr, k, 1, prob_bits);
229                         k = ESCAPE_LIMIT;
230                 }
231                 RansEncPutSymbol(&rans, &ptr, &esyms[(k - 1) & (NUM_SYMS - 1)]);
232                 if (signed_k < 0) {
233                         rans += sign_bias;
234                 }
235         }
236
237 private:
238         static constexpr size_t out_max_size = 32 << 20; // 32 MB.
239         static constexpr size_t max_num_sign = 1048576;  // Way too big. And actually bytes.
240
241         unique_ptr<uint8_t[]> out_buf;
242         uint8_t *out_end;
243         uint8_t *ptr;
244         RansState rans;
245         RansEncSymbol esyms[NUM_SYMS];
246         uint32_t sign_bias;
247
248         std::string last_block;
249 };
250
251 // Should be done on the GPU, of course, but irrelevant for the demonstration.
252 void convert_ycbcr()
253 {
254         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
255         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
256         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
257
258         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
259         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
260         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
261                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
262                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
263                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
264                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
265                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
266                         double cb = (b - y) * cb_fac + 128.0;
267                         double cr = (r - y) * cr_fac + 128.0;
268                         pix_y[(yb * WIDTH) + xb] = lrint(y);
269                         temp_cb[(yb * WIDTH) + xb] = cb;
270                         temp_cr[(yb * WIDTH) + xb] = cr;
271                 }
272         }
273
274         // Simple 4:2:2 subsampling with left convention.
275         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
276                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
277                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
278                         int c1 = yb * WIDTH + xb * 2;
279                         int c2 = yb * WIDTH + xb * 2 + 1;
280                         
281                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
282                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
283                         cb = std::min(std::max(cb, 0.0), 255.0);
284                         cr = std::min(std::max(cr, 0.0), 255.0);
285                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
286                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
287                 }
288         }
289 }
290
291 int main(int argc, char **argv)
292 {
293         // Set up an OpenGL context using SDL.
294         if (SDL_Init(SDL_INIT_VIDEO) == -1) {
295                 fprintf(stderr, "SDL_Init failed: %s\n", SDL_GetError());
296                 exit(1);
297         }
298         SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 0);
299         SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 0);
300         SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
301         SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
302         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 4);
303         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 5);
304
305         SDL_Window *window = SDL_CreateWindow("OpenGL window for unit test",
306                 SDL_WINDOWPOS_UNDEFINED,
307                 SDL_WINDOWPOS_UNDEFINED,
308                 32, 32,
309                 SDL_WINDOW_OPENGL);
310         SDL_GLContext context = SDL_GL_CreateContext(window);
311         assert(context != nullptr);
312
313         if (argc >= 2)
314                 readpix(rgb, argv[1]);
315         else
316                 readpix(rgb, "color.pnm");
317         convert_ycbcr();
318
319         // Compile the shader.
320         string shader_src = ::read_file("encoder.shader");
321         GLuint shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
322         GLuint glsl_program_num = glCreateProgram();
323         glAttachShader(glsl_program_num, shader_num);
324         glLinkProgram(glsl_program_num);
325
326         GLint success;
327         glGetProgramiv(glsl_program_num, GL_LINK_STATUS, &success);
328         if (success == GL_FALSE) {
329                 GLchar error_log[1024] = {0};
330                 glGetProgramInfoLog(glsl_program_num, 1024, nullptr, error_log);
331                 fprintf(stderr, "Error linking program: %s\n", error_log);
332                 exit(1);
333         }
334
335         // Compile the tally shader.
336         shader_src = ::read_file("tally.shader");
337         shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
338         GLuint glsl_tally_program_num = glCreateProgram();
339         glAttachShader(glsl_tally_program_num, shader_num);
340         glLinkProgram(glsl_tally_program_num);
341
342         glGetProgramiv(glsl_tally_program_num, GL_LINK_STATUS, &success);
343         if (success == GL_FALSE) {
344                 GLchar error_log[1024] = {0};
345                 glGetProgramInfoLog(glsl_tally_program_num, 1024, nullptr, error_log);
346                 fprintf(stderr, "Error linking program: %s\n", error_log);
347                 exit(1);
348         }
349
350         glUseProgram(glsl_program_num);
351
352         // An SSBO for the rANS distributions.
353         GLuint ssbo;
354         glGenBuffers(1, &ssbo);
355         glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
356         glBufferData(GL_SHADER_STORAGE_BUFFER, 256 * 4 * sizeof(uint32_t), nullptr, GL_DYNAMIC_COPY);
357         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 9, ssbo);
358
359         // Upload luma.
360         GLuint y_tex;
361         glGenTextures(1, &y_tex);
362         glBindTexture(GL_TEXTURE_2D, y_tex);
363         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
364         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
365         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
366         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
367         glTexImage2D(GL_TEXTURE_2D, 0, GL_R8UI, WIDTH, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_BYTE, pix_y);
368         check_error();
369
370         // Make destination textures.
371         GLuint dc_ac7_tex, ac1_ac6_tex, ac2_ac5_tex;
372         for (GLuint *tex : { &dc_ac7_tex, &ac1_ac6_tex, &ac2_ac5_tex }) {
373                 glGenTextures(1, tex);
374                 glBindTexture(GL_TEXTURE_2D, *tex);
375                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
376                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
377                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
378                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
379                 glTexImage2D(GL_TEXTURE_2D, 0, GL_R16UI, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, nullptr);
380                 check_error();
381         }
382
383         GLuint ac3_tex, ac4_tex;
384         for (GLuint *tex : { &ac3_tex, &ac4_tex }) {
385                 glGenTextures(1, tex);
386                 glBindTexture(GL_TEXTURE_2D, *tex);
387                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
388                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
389                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
390                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
391                 glTexImage2D(GL_TEXTURE_2D, 0, GL_R8I, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_BYTE, nullptr);
392                 check_error();
393         }
394
395         GLint dc_ac7_tex_uniform = glGetUniformLocation(glsl_program_num, "dc_ac7_tex");
396         GLint ac1_ac6_tex_uniform = glGetUniformLocation(glsl_program_num, "ac1_ac6_tex");
397         GLint ac2_ac5_tex_uniform = glGetUniformLocation(glsl_program_num, "ac2_ac5_tex");
398         GLint ac3_tex_uniform = glGetUniformLocation(glsl_program_num, "ac3_tex");
399         GLint ac4_tex_uniform = glGetUniformLocation(glsl_program_num, "ac4_tex");
400         GLint image_tex_uniform = glGetUniformLocation(glsl_program_num, "image_tex");
401
402         glUniform1i(dc_ac7_tex_uniform, 0);
403         glUniform1i(ac1_ac6_tex_uniform, 1);
404         glUniform1i(ac2_ac5_tex_uniform, 2);
405         glUniform1i(ac3_tex_uniform, 3);
406         glUniform1i(ac4_tex_uniform, 4);
407         glUniform1i(image_tex_uniform, 5);
408         glBindImageTexture(0, dc_ac7_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
409         glBindImageTexture(1, ac1_ac6_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
410         glBindImageTexture(2, ac2_ac5_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
411         glBindImageTexture(3, ac3_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R8I);
412         glBindImageTexture(4, ac4_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R8I);
413         glBindImageTexture(5, y_tex, 0, GL_FALSE, 0, GL_READ_ONLY, GL_R8UI);
414         check_error();
415
416         glUseProgram(glsl_tally_program_num);
417         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 9, ssbo);
418
419         steady_clock::time_point start = steady_clock::now();
420         unsigned num_iterations = 1000;
421         for (unsigned i = 0; i < num_iterations; ++i) {
422                 glClearNamedBufferSubData(ssbo, GL_R8, 0, 256 * 4 * sizeof(uint32_t), GL_RED, GL_UNSIGNED_BYTE, nullptr);
423                 glUseProgram(glsl_program_num);
424                 glDispatchCompute(WIDTH_BLOCKS / 16, HEIGHT_BLOCKS, 1);
425                 glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
426
427                 glUseProgram(glsl_tally_program_num);
428                 glDispatchCompute(4, 1, 1);
429         }
430         check_error();
431         glFinish();
432         steady_clock::time_point now = steady_clock::now();
433
434         // CPU part starts here -- will be GPU later.
435         // We only do luma for now.
436
437         int16_t *coeff_y = new int16_t[WIDTH * HEIGHT];
438
439         glBindTexture(GL_TEXTURE_2D, dc_ac7_tex);
440         uint16_t *dc_ac7_data = new uint16_t[(WIDTH/8) * HEIGHT];
441         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, dc_ac7_data);
442         check_error();
443
444         glBindTexture(GL_TEXTURE_2D, ac1_ac6_tex);
445         uint16_t *ac1_ac6_data = new uint16_t[(WIDTH/8) * HEIGHT];
446         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, ac1_ac6_data);
447         check_error();
448
449         glBindTexture(GL_TEXTURE_2D, ac2_ac5_tex);
450         uint16_t *ac2_ac5_data = new uint16_t[(WIDTH/8) * HEIGHT];
451         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, ac2_ac5_data);
452         check_error();
453
454         glBindTexture(GL_TEXTURE_2D, ac3_tex);
455         int8_t *ac3_data = new int8_t[(WIDTH/8) * HEIGHT];
456         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_BYTE, ac3_data);
457         check_error();
458
459         glBindTexture(GL_TEXTURE_2D, ac4_tex);
460         int8_t *ac4_data = new int8_t[(WIDTH/8) * HEIGHT];
461         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED_INTEGER, GL_BYTE, ac4_data);
462         check_error();
463
464         for (unsigned y = 0; y < HEIGHT; ++y) {
465                 for (unsigned xb = 0; xb < WIDTH/8; ++xb) {
466                         coeff_y[y * WIDTH + xb*8 + 0] = int(dc_ac7_data[y * (WIDTH/8) + xb] << 23) >> 23;
467                         coeff_y[y * WIDTH + xb*8 + 7] = int(dc_ac7_data[y * (WIDTH/8) + xb] << 16) >> 25;
468                         coeff_y[y * WIDTH + xb*8 + 1] = int(ac1_ac6_data[y * (WIDTH/8) + xb] << 23) >> 23;
469                         coeff_y[y * WIDTH + xb*8 + 6] = int(ac1_ac6_data[y * (WIDTH/8) + xb] << 16) >> 25;
470                         coeff_y[y * WIDTH + xb*8 + 2] = int(ac2_ac5_data[y * (WIDTH/8) + xb] << 23) >> 23;
471                         coeff_y[y * WIDTH + xb*8 + 5] = int(ac2_ac5_data[y * (WIDTH/8) + xb] << 16) >> 25;
472                         coeff_y[y * WIDTH + xb*8 + 3] = ac3_data[y * (WIDTH/8) + xb];
473                         coeff_y[y * WIDTH + xb*8 + 4] = ac4_data[y * (WIDTH/8) + xb];
474                 }
475         }
476
477 #if 0
478         for (unsigned y = 0; y < HEIGHT; ++y) {
479                 for (unsigned xb = 0; xb < WIDTH/8; ++xb) {
480                         printf("%4d %4d %4d %4d %4d %4d %4d %4d | ",
481                                 coeff_y[y * WIDTH + xb*8 + 0],
482                                 coeff_y[y * WIDTH + xb*8 + 1],
483                                 coeff_y[y * WIDTH + xb*8 + 2],
484                                 coeff_y[y * WIDTH + xb*8 + 3],
485                                 coeff_y[y * WIDTH + xb*8 + 4],
486                                 coeff_y[y * WIDTH + xb*8 + 5],
487                                 coeff_y[y * WIDTH + xb*8 + 6],
488                                 coeff_y[y * WIDTH + xb*8 + 7]);
489                         printf("%4d %4d %4d %4d %4d %4d %4d %4d || ",
490                                 pix_y[y * WIDTH + xb*8 + 0],
491                                 pix_y[y * WIDTH + xb*8 + 1],
492                                 pix_y[y * WIDTH + xb*8 + 2],
493                                 pix_y[y * WIDTH + xb*8 + 3],
494                                 pix_y[y * WIDTH + xb*8 + 4],
495                                 pix_y[y * WIDTH + xb*8 + 5],
496                                 pix_y[y * WIDTH + xb*8 + 6],
497                                 pix_y[y * WIDTH + xb*8 + 7]);
498                 }
499                 printf("\n");
500         }
501 #endif
502
503         // DC coefficient pred from the right to left (within each slice)
504         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; block_idx += BLOCKS_PER_STREAM) {
505                 int prev_k = 128;
506
507                 for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
508                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS;
509                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS;
510                         int k = coeff_y[(yb * 8) * WIDTH + xb * 8];
511
512                         coeff_y[(yb * 8) * WIDTH + xb * 8] = k - prev_k;
513
514                         prev_k = k;
515                 }
516         }
517
518         // For each coefficient, make some tables.
519         size_t extra_bits = 0;
520         for (unsigned i = 0; i < 64; ++i) {
521                 stats[i].clear();
522         }
523         for (unsigned y = 0; y < 8; ++y) {
524                 for (unsigned x = 0; x < 8; ++x) {
525                         SymbolStats &s_luma = stats[pick_stats_for(x, y)];
526
527                         // Luma
528                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
529                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
530                                         unsigned short k = abs(coeff_y[(yb + y) * WIDTH + (xb + x)]);
531                                         if (k >= ESCAPE_LIMIT) {
532                                                 k = ESCAPE_LIMIT;
533                                                 extra_bits += 12;  // escape this one
534                                         }
535                                         ++s_luma.freqs[(k - 1) & (NUM_SYMS - 1)];
536                                 }
537                         }
538                 }
539         }
540
541         for (unsigned i = 0; i < 64; ++i) {
542                 stats[i].freqs[NUM_SYMS - 1] /= 2;  // zero, has no sign bits (yes, this is trickery)
543                 stats[i].normalize_freqs(prob_scale);
544                 stats[i].cum_freqs[NUM_SYMS] += stats[i].freqs[NUM_SYMS - 1];
545                 stats[i].freqs[NUM_SYMS - 1] *= 2;
546         }
547
548         FILE *codedfp = fopen("coded.dat", "wb");
549         if (codedfp == nullptr) {
550                 perror("coded.dat");
551                 exit(1);
552         }
553
554         for (unsigned r = 0; r < 2; ++r) {  // Hack to write fake chroma tables.
555                 // TODO: rather gamma-k or something
556                 for (unsigned i = 0; i < 64; ++i) {
557                         if (stats[i].cum_freqs[NUM_SYMS] == 0) {
558                                 continue;
559                         }
560                         printf("writing table %d\n", i);
561                         for (unsigned j = 0; j < NUM_SYMS; ++j) {
562                                 write_varint(stats[i].freqs[j], codedfp);
563                         }
564                 }
565         }
566
567         RansEncoder rans_encoder;
568
569         size_t tot_bytes = 0;
570
571         // Luma
572         for (unsigned y = 0; y < 8; ++y) {
573                 for (unsigned x = 0; x < 8; ++x) {
574                         SymbolStats &s_luma = stats[pick_stats_for(x, y)];
575                         rans_encoder.init_prob(s_luma);
576
577                         // Luma
578                         std::vector<int> lens;
579
580                         rans_encoder.clear();
581                         size_t num_bytes = 0;
582                         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; ++block_idx) {
583                                 unsigned yb = block_idx / WIDTH_BLOCKS;
584                                 unsigned xb = block_idx % WIDTH_BLOCKS;
585
586                                 int k = coeff_y[(yb * 8 + y) * WIDTH + (xb * 8 + x)];
587                                 rans_encoder.encode_coeff(k);
588
589                                 if (block_idx % BLOCKS_PER_STREAM == (BLOCKS_PER_STREAM - 1) || block_idx == NUM_BLOCKS - 1) {
590                                         int l = rans_encoder.save_block(codedfp);
591                                         num_bytes += l;
592                                         lens.push_back(l);
593                                 }
594                         }
595                         tot_bytes += num_bytes;
596                         printf("coeff %d Y': %ld bytes\n", y * 8 + x, num_bytes);
597                 }
598         }
599
600         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
601                 tot_bytes - extra_bits / 8,
602                 extra_bits,
603                 extra_bits / 8,
604                 tot_bytes);
605
606         printf("\n");
607         printf("Each iteration took %.3f ms (but note that is DCT only, no rANS).\n", 1e3 * duration<double>(now - start).count() / num_iterations);
608
609 #if 1
610         glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
611         const uint32_t *dist = (const uint32_t *)glMapBuffer(GL_SHADER_STORAGE_BUFFER, GL_READ_ONLY);
612         for (int i = 0; i < 1024; ++i) {
613                 printf("%d,%d: %u\n", i / 256, i % 256, dist[i]);
614         }
615 #endif
616 }