// narabu-encoder.cpp -- from the narabu repository on git.sesse.net.
// Revision subject: "More fixes of hard-coded values."
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <math.h>
#include <SDL2/SDL.h>
#include <SDL2/SDL_error.h>
#include <SDL2/SDL_video.h>
#include <epoxy/gl.h>

#include <algorithm>
#include <chrono>
#include <memory>
#include <numeric>
#include <random>
#include <string>
#include <vector>
#include <unordered_map>

#include <movit/util.h>

#include "util.h"
23
// Frame geometry, in pixels.
#define WIDTH 1280
#define HEIGHT 720

// The frame is processed in blocks of 8 pixels. Chroma is stored at half
// horizontal resolution (see pix_cb/pix_cr below), so one chroma block spans
// 16 luma pixels horizontally.
#define WIDTH_BLOCKS (WIDTH/8)
#define WIDTH_BLOCKS_CHROMA (WIDTH/16)
#define HEIGHT_BLOCKS (HEIGHT/8)
#define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
#define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)

// The block counts above use truncating integer division, so a frame size
// that is not a whole number of blocks would silently lose its right/bottom
// edge. Catch that at compile time instead.
static_assert(WIDTH % 16 == 0, "WIDTH must be a multiple of 16 (whole luma and chroma blocks)");
static_assert(HEIGHT % 8 == 0, "HEIGHT must be a multiple of 8 (whole blocks)");

// Number of entries in each symbol table that is written to the output file.
#define NUM_SYMS 256
// Not referenced in this file; presumably the symbol value reserved for
// escapes by the shaders -- TODO confirm against the shader source.
#define ESCAPE_LIMIT (NUM_SYMS - 1)
#define BLOCKS_PER_STREAM 320
#define STREAM_BUF_SIZE 1024  // In bytes.

// NUM_BLOCKS / BLOCKS_PER_STREAM, rounded up.
#define NUM_STREAMS ((NUM_BLOCKS + BLOCKS_PER_STREAM - 1) / BLOCKS_PER_STREAM)

// Not referenced in this file; presumably the rANS probability resolution
// used by the shaders and the decoder -- TODO confirm.
static constexpr uint32_t prob_bits = 12;
static constexpr uint32_t prob_scale = uint32_t{1} << prob_bits;

// Input frame as interleaved 8-bit R, G, B triplets, filled by readpix().
unsigned char rgb[WIDTH * HEIGHT * 3];
// Planar output of convert_ycbcr(): full-resolution luma, and chroma at half
// horizontal resolution.
unsigned char pix_y[WIDTH * HEIGHT];
unsigned char pix_cb[(WIDTH/2) * HEIGHT];
unsigned char pix_cr[(WIDTH/2) * HEIGHT];
46
// CPU-side mirror of the shader storage buffer that main() allocates for the
// raw rANS counts. Both arrays hold four tables of 256 entries each, addressed
// as table * 256 + symbol.
struct RansCountSSBO {
        // Not read on the CPU in this file; presumably the raw per-symbol
        // occurrence counts filled in by the shaders -- TODO confirm.
        unsigned dist[4 * 256];
        // Per-symbol frequencies; main() writes these to the output file.
        unsigned ransfreq[4 * 256];
};

// The buffer is allocated, cleared and mapped using sizeof(RansCountSSBO), so
// any padding here would silently desynchronize the CPU and GPU views.
static_assert(sizeof(RansCountSSBO) == 2 * 4 * 256 * sizeof(unsigned),
              "RansCountSSBO must be two tightly packed arrays of 4 * 256 counters");
51
// CPU-side mirror of the rANS distribution block. main() allocates both a
// shader storage buffer and a uniform buffer of this size and copies the
// former into the latter before the rANS pass.
struct RansDistUBO {
        // One 16-byte entry per (table, symbol), addressed as table * 256 + symbol.
        // main() prints the low 16 bits of rcp_shift_and_cmpl_freq as
        // "rcp_shift" and the high 16 bits as "cmpl_freq".
        struct {
                uint32_t x_max, rcp_freq, bias, rcp_shift_and_cmpl_freq;
        } ransdist[4 * 256];
        // One value per table, padded out to 16 bytes per array element.
        struct {
                uint32_t val;
                uint32_t padding[3];  // std140 layout.
        } sign_biases[4];
};

// Every array element must occupy exactly 16 bytes for the layout to match
// what the padding above is meant to achieve.
static_assert(sizeof(RansDistUBO) == 4 * 256 * 16 + 4 * 16,
              "RansDistUBO must consist solely of 16-byte array elements");
61
// Only pull in the standard-library names this file spells unqualified,
// instead of importing all of std and std::chrono into the global namespace
// (where they can collide with the C, SDL and OpenGL identifiers used here).
using std::string;
using std::unique_ptr;
using std::chrono::duration;
using std::chrono::steady_clock;
64
// Writes x to fp as a little-endian base-128 varint: seven payload bits per
// byte, least significant group first, with the high bit set on every byte
// except the last.
//
// The value is encoded from its unsigned bit pattern. Callers pass unsigned
// quantities through the int parameter; a value that arrives negative would
// otherwise skip the loop and be emitted as a single truncated byte.
void write_varint(int x, FILE *fp)
{
        unsigned remaining = static_cast<unsigned>(x);
        while (remaining >= 0x80) {
                putc((remaining & 0x7f) | 0x80, fp);
                remaining >>= 7;
        }
        putc(remaining, fp);
}
73
74 void readpix(unsigned char *ptr, const char *filename)
75 {
76         FILE *fp = fopen(filename, "rb");
77         if (fp == nullptr) {
78                 perror(filename);
79                 exit(1);
80         }
81
82         fseek(fp, 0, SEEK_END);
83         long len = ftell(fp);
84         assert(len >= WIDTH * HEIGHT * 3);
85         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
86
87         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
88         fclose(fp);
89 }
90
91 // Should be done on the GPU, of course, but irrelevant for the demonstration.
92 void convert_ycbcr()
93 {
94         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
95         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
96         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
97
98         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
99         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
100         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
101                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
102                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
103                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
104                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
105                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
106                         double cb = (b - y) * cb_fac + 128.0;
107                         double cr = (r - y) * cr_fac + 128.0;
108                         pix_y[(yb * WIDTH) + xb] = lrint(y);
109                         temp_cb[(yb * WIDTH) + xb] = cb;
110                         temp_cr[(yb * WIDTH) + xb] = cr;
111                 }
112         }
113
114         // Simple 4:2:2 subsampling with left convention.
115         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
116                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
117                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
118                         int c1 = yb * WIDTH + xb * 2;
119                         int c2 = yb * WIDTH + xb * 2 + 1;
120                         
121                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
122                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
123                         cb = std::min(std::max(cb, 0.0), 255.0);
124                         cr = std::min(std::max(cr, 0.0), 255.0);
125                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
126                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
127                 }
128         }
129 }
130
131 int main(int argc, char **argv)
132 {
133         // Set up an OpenGL context using SDL.
134         if (SDL_Init(SDL_INIT_VIDEO) == -1) {
135                 fprintf(stderr, "SDL_Init failed: %s\n", SDL_GetError());
136                 exit(1);
137         }
138         SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 0);
139         SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 0);
140         SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
141         SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
142         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 4);
143         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 5);
144
145         SDL_Window *window = SDL_CreateWindow("OpenGL window for unit test",
146                 SDL_WINDOWPOS_UNDEFINED,
147                 SDL_WINDOWPOS_UNDEFINED,
148                 32, 32,
149                 SDL_WINDOW_OPENGL);
150         SDL_GLContext context = SDL_GL_CreateContext(window);
151         assert(context != nullptr);
152
153         if (argc >= 2)
154                 readpix(rgb, argv[1]);
155         else
156                 readpix(rgb, "color.pnm");
157         convert_ycbcr();
158
159         // Compile the DCT shader.
160         string shader_src = ::read_file("encoder.shader");
161         GLuint shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
162         GLuint glsl_program_num = glCreateProgram();
163         glAttachShader(glsl_program_num, shader_num);
164         glLinkProgram(glsl_program_num);
165
166         GLint success;
167         glGetProgramiv(glsl_program_num, GL_LINK_STATUS, &success);
168         if (success == GL_FALSE) {
169                 GLchar error_log[1024] = {0};
170                 glGetProgramInfoLog(glsl_program_num, 1024, nullptr, error_log);
171                 fprintf(stderr, "Error linking program: %s\n", error_log);
172                 exit(1);
173         }
174
175         // Compile the tally shader.
176         shader_src = ::read_file("tally.shader");
177         shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
178         GLuint glsl_tally_program_num = glCreateProgram();
179         glAttachShader(glsl_tally_program_num, shader_num);
180         glLinkProgram(glsl_tally_program_num);
181
182         glGetProgramiv(glsl_tally_program_num, GL_LINK_STATUS, &success);
183         if (success == GL_FALSE) {
184                 GLchar error_log[1024] = {0};
185                 glGetProgramInfoLog(glsl_tally_program_num, 1024, nullptr, error_log);
186                 fprintf(stderr, "Error linking program: %s\n", error_log);
187                 exit(1);
188         }
189
190         // Compile the rANS shader.
191         shader_src = ::read_file("rans.shader");
192         shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
193         GLuint glsl_rans_program_num = glCreateProgram();
194         glAttachShader(glsl_rans_program_num, shader_num);
195         glLinkProgram(glsl_rans_program_num);
196
197         glGetProgramiv(glsl_rans_program_num, GL_LINK_STATUS, &success);
198         if (success == GL_FALSE) {
199                 GLchar error_log[1024] = {0};
200                 glGetProgramInfoLog(glsl_rans_program_num, 1024, nullptr, error_log);
201                 fprintf(stderr, "Error linking program: %s\n", error_log);
202                 exit(1);
203         }
204         check_error();
205
206         // An SSBO for the raw rANS counts.
207         GLuint ssbo;
208         glGenBuffers(1, &ssbo);
209         glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
210         glNamedBufferStorage(ssbo, sizeof(RansCountSSBO), nullptr, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
211         check_error();
212
213         // UBO for the rANS distributions (copied from an SSBO).
214         GLuint dist_ssbo;
215         glGenBuffers(1, &dist_ssbo);
216         glBindBuffer(GL_SHADER_STORAGE_BUFFER, dist_ssbo);
217         glNamedBufferStorage(dist_ssbo, sizeof(RansDistUBO), nullptr, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
218         check_error();
219
220         GLuint dist_ubo;
221         glGenBuffers(1, &dist_ubo);
222         glBindBuffer(GL_UNIFORM_BUFFER, dist_ubo);
223         glNamedBufferStorage(dist_ubo, sizeof(RansDistUBO), nullptr, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
224         check_error();
225
226         // SSBOs for the rANS output (data and offsets).
227         GLuint output_ssbo;
228         glGenBuffers(1, &output_ssbo);
229         glBindBuffer(GL_SHADER_STORAGE_BUFFER, output_ssbo);
230         glNamedBufferStorage(output_ssbo, 64 * NUM_STREAMS * STREAM_BUF_SIZE, nullptr, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
231         check_error();
232
233         GLuint bytes_written_ssbo;
234         glGenBuffers(1, &bytes_written_ssbo);
235         glBindBuffer(GL_SHADER_STORAGE_BUFFER, bytes_written_ssbo);
236         glNamedBufferStorage(bytes_written_ssbo, 64 * NUM_STREAMS * sizeof(uint32_t), nullptr, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
237         check_error();
238
239         // Bind SSBOs.
240         glUseProgram(glsl_program_num);
241         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 9, ssbo);
242
243         glUseProgram(glsl_tally_program_num);
244         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 9, ssbo);
245         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 12, dist_ssbo);
246
247         glUseProgram(glsl_rans_program_num);
248         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 10, output_ssbo);
249         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 11, bytes_written_ssbo);
250         glBindBufferBase(GL_UNIFORM_BUFFER, 13, dist_ubo);
251
252         glUseProgram(glsl_program_num);
253         check_error();
254
255         // Upload luma.
256         GLuint y_tex;
257         glGenTextures(1, &y_tex);
258         glBindTexture(GL_TEXTURE_2D, y_tex);
259         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
260         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
261         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
262         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
263         glTexImage2D(GL_TEXTURE_2D, 0, GL_R8UI, WIDTH, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_BYTE, pix_y);
264         check_error();
265
266         // Make destination textures.
267         GLuint dc_ac7_tex, ac1_ac6_tex, ac2_ac5_tex;
268         for (GLuint *tex : { &dc_ac7_tex, &ac1_ac6_tex, &ac2_ac5_tex }) {
269                 glGenTextures(1, tex);
270                 glBindTexture(GL_TEXTURE_2D, *tex);
271                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
272                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
273                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
274                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
275                 glTexImage2D(GL_TEXTURE_2D, 0, GL_R16UI, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, nullptr);
276                 check_error();
277         }
278
279         GLuint ac3_tex, ac4_tex;
280         for (GLuint *tex : { &ac3_tex, &ac4_tex }) {
281                 glGenTextures(1, tex);
282                 glBindTexture(GL_TEXTURE_2D, *tex);
283                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
284                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
285                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
286                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
287                 glTexImage2D(GL_TEXTURE_2D, 0, GL_R8I, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_BYTE, nullptr);
288                 check_error();
289         }
290
291         glBindImageTexture(0, dc_ac7_tex, 0, GL_FALSE, 0, GL_READ_WRITE, GL_R16UI);
292         glBindImageTexture(1, ac1_ac6_tex, 0, GL_FALSE, 0, GL_READ_WRITE, GL_R16UI);
293         glBindImageTexture(2, ac2_ac5_tex, 0, GL_FALSE, 0, GL_READ_WRITE, GL_R16UI);
294         glBindImageTexture(3, ac3_tex, 0, GL_FALSE, 0, GL_READ_WRITE, GL_R8I);
295         glBindImageTexture(4, ac4_tex, 0, GL_FALSE, 0, GL_READ_WRITE, GL_R8I);
296         glBindImageTexture(5, y_tex, 0, GL_FALSE, 0, GL_READ_ONLY, GL_R8UI);
297         check_error();
298
299         // Bind uniforms.
300         glUseProgram(glsl_program_num);
301         GLint dc_ac7_tex_uniform = glGetUniformLocation(glsl_program_num, "dc_ac7_tex");
302         GLint ac1_ac6_tex_uniform = glGetUniformLocation(glsl_program_num, "ac1_ac6_tex");
303         GLint ac2_ac5_tex_uniform = glGetUniformLocation(glsl_program_num, "ac2_ac5_tex");
304         GLint ac3_tex_uniform = glGetUniformLocation(glsl_program_num, "ac3_tex");
305         GLint ac4_tex_uniform = glGetUniformLocation(glsl_program_num, "ac4_tex");
306         GLint image_tex_uniform = glGetUniformLocation(glsl_program_num, "image_tex");
307         glUniform1i(dc_ac7_tex_uniform, 0);
308         glUniform1i(ac1_ac6_tex_uniform, 1);
309         glUniform1i(ac2_ac5_tex_uniform, 2);
310         glUniform1i(ac3_tex_uniform, 3);
311         glUniform1i(ac4_tex_uniform, 4);
312         glUniform1i(image_tex_uniform, 5);
313
314         glUseProgram(glsl_rans_program_num);
315         dc_ac7_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "dc_ac7_tex");
316         ac1_ac6_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "ac1_ac6_tex");
317         ac2_ac5_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "ac2_ac5_tex");
318         ac3_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "ac3_tex");
319         ac4_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "ac4_tex");
320         image_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "image_tex");
321         glUniform1i(dc_ac7_tex_uniform, 0);
322         glUniform1i(ac1_ac6_tex_uniform, 1);
323         glUniform1i(ac2_ac5_tex_uniform, 2);
324         glUniform1i(ac3_tex_uniform, 3);
325         glUniform1i(ac4_tex_uniform, 4);
326
327         steady_clock::time_point start = steady_clock::now();
328         unsigned num_iterations = 100;
329         for (unsigned i = 0; i < num_iterations; ++i) {
330                 glClearNamedBufferSubData(ssbo, GL_R8, 0, sizeof(RansCountSSBO), GL_RED, GL_UNSIGNED_BYTE, nullptr);
331                 glUseProgram(glsl_program_num);
332                 glDispatchCompute(WIDTH_BLOCKS / 16, HEIGHT_BLOCKS, 1);
333                 glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
334
335                 glUseProgram(glsl_tally_program_num);
336                 glDispatchCompute(4, 1, 1);
337                 glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
338         
339                 glCopyNamedBufferSubData(dist_ssbo, dist_ubo, 0, 0, sizeof(RansDistUBO));
340                 glMemoryBarrier(GL_UNIFORM_BARRIER_BIT);
341
342                 glUseProgram(glsl_rans_program_num);
343                 glDispatchCompute(NUM_STREAMS, 8, 5);
344         }
345         check_error();
346         glFinish();
347         check_error();
348         steady_clock::time_point now = steady_clock::now();
349
350 #if 0
351         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
352                 tot_bytes - extra_bits / 8,
353                 extra_bits,
354                 extra_bits / 8,
355                 tot_bytes);
356
357         printf("\n");
358 #endif
359
360         printf("Each iteration took %.3f ms.\n", 1e3 * duration<double>(now - start).count() / num_iterations);
361
362         FILE *codedfp = fopen("coded.dat", "wb");
363         if (codedfp == nullptr) {
364                 perror("coded.dat");
365                 exit(1);
366         }
367
368         // Write out the distributions.
369         const RansCountSSBO *rans_count = (const RansCountSSBO *)glMapNamedBufferRange(ssbo, 0, sizeof(RansCountSSBO), GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
370         const RansDistUBO *rans_dist = (const RansDistUBO *)glMapNamedBufferRange(dist_ssbo, 0, sizeof(RansDistUBO), GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
371         for (unsigned r = 0; r < 2; ++r) {  // Hack to write fake chroma tables.
372                 // TODO: rather gamma-k or something
373                 for (unsigned i = 0; i < 4; ++i) {
374                         printf("writing table %d\n", i);
375                         for (unsigned j = 0; j < NUM_SYMS; ++j) {
376                                 printf("%d,%d: freq=%d  x_max=%d, rcp_freq=%08x, bias=%d, rcp_shift=%d, cmpl_freq=%d\n",
377                                         i, j, rans_count->ransfreq[i * 256 + j],
378                                         rans_dist->ransdist[i * 256 + j].x_max,
379                                         rans_dist->ransdist[i * 256 + j].rcp_freq,
380                                         rans_dist->ransdist[i * 256 + j].bias,
381                                         rans_dist->ransdist[i * 256 + j].rcp_shift_and_cmpl_freq & 0xffff,
382                                         rans_dist->ransdist[i * 256 + j].rcp_shift_and_cmpl_freq >> 16);
383                                 write_varint(rans_count->ransfreq[i * 256 + j], codedfp);
384                         }
385                 }
386         }
387
388         // Write out the actual data.
389
390         const uint32_t *bytes_written = (const uint32_t *)glMapNamedBufferRange(bytes_written_ssbo, 0, 64 * NUM_STREAMS * sizeof(uint32_t), GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
391 #if 0
392         for (int i = 0; i < HEIGHT_BLOCKS*64; ++i) {
393                 printf("%d,%d,%d: %u\n", i / 64, (i / 8) % 8, i % 8, 1024 * (i + 1) - offsets[i]);
394         }
395 #endif
396
397         const uint8_t *data = (const uint8_t *)glMapNamedBufferRange(output_ssbo, 0, 64 * NUM_STREAMS * STREAM_BUF_SIZE, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
398
399         string last_block;
400         for (unsigned y = 0; y < 8; ++y) {
401                 for (unsigned x = 0; x < 8; ++x) {
402                         for (unsigned int stream_idx = 0; stream_idx < NUM_STREAMS; ++stream_idx) {
403                                 const uint8_t *out_end = data + (stream_idx * 64 + y * 8 + x + 1) * STREAM_BUF_SIZE;
404                                 uint32_t num_rans_bytes = bytes_written[stream_idx * 64 + y * 8 + x];
405                                 const uint8_t *ptr = out_end - num_rans_bytes;
406                                 assert(num_rans_bytes <= STREAM_BUF_SIZE);
407
408                                 if (num_rans_bytes == last_block.size() &&
409                                     memcmp(last_block.data(), ptr, last_block.size()) == 0) {
410                                         write_varint(0, codedfp);
411                                 } else {
412                                         last_block = string((const char *)ptr, num_rans_bytes);
413                                         write_varint(num_rans_bytes, codedfp);
414                                         fwrite(ptr, 1, num_rans_bytes, codedfp);
415                                 }
416                         }
417                 }
418         }
419         fclose(codedfp);
420 }