]> git.sesse.net Git - narabu/blob - narabu-encoder.cpp
Make the encoder 100% GPU. Not working yet, though.
[narabu] / narabu-encoder.cpp
1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7 #include <SDL2/SDL.h>
8 #include <SDL2/SDL_error.h>
9 #include <SDL2/SDL_video.h>
10 #include <epoxy/gl.h>
11
12 #include <algorithm>
13 #include <chrono>
14 #include <memory>
15 #include <numeric>
16 #include <random>
17 #include <vector>
18 #include <unordered_map>
19
20 #include <movit/util.h>
21
22 #include "ryg_rans/rans_byte.h"
23 #include "ryg_rans/renormalize.h"
24 #include "util.h"
25
// Frame geometry. Luma is coded in 8x8 blocks; chroma is 4:2:2 (half width,
// full height), so a chroma row has WIDTH/16 blocks of 8x8.
#define WIDTH 1280
#define HEIGHT 720
#define WIDTH_BLOCKS (WIDTH/8)
#define WIDTH_BLOCKS_CHROMA (WIDTH/16)
#define HEIGHT_BLOCKS (HEIGHT/8)
#define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
#define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)

#define NUM_SYMS 256  // Entries per distribution table.
#define ESCAPE_LIMIT (NUM_SYMS - 1)  // NOTE(review): presumably the largest directly-coded symbol; confirm against the shaders.
#define BLOCKS_PER_STREAM 320

// All block counts above are computed with integer division, and main()
// dispatches WIDTH_BLOCKS / 16 DCT work groups per block row and
// NUM_BLOCKS / BLOCKS_PER_STREAM rANS work groups. Make sure none of those
// divisions can silently truncate (and thereby drop blocks) if the constants
// are changed.
static_assert(WIDTH % 16 == 0, "WIDTH must be a multiple of 16 (8x8 blocks on 4:2:2 chroma)");
static_assert(HEIGHT % 8 == 0, "HEIGHT must be a multiple of 8");
static_assert(WIDTH_BLOCKS % 16 == 0, "main() dispatches WIDTH_BLOCKS / 16 DCT work groups per row");
static_assert(NUM_BLOCKS % BLOCKS_PER_STREAM == 0, "main() dispatches NUM_BLOCKS / BLOCKS_PER_STREAM rANS streams");

// rANS probability precision; presumably frequencies are normalized to sum
// to prob_scale (TODO confirm against tally.shader).
static constexpr uint32_t prob_bits = 12;
static constexpr uint32_t prob_scale = 1 << prob_bits;

unsigned char rgb[WIDTH * HEIGHT * 3];  // Input frame, interleaved R, G, B bytes per pixel.
unsigned char pix_y[WIDTH * HEIGHT];  // Full-resolution luma plane.
unsigned char pix_cb[(WIDTH/2) * HEIGHT];  // Half-width Cb plane.
unsigned char pix_cr[(WIDTH/2) * HEIGHT];  // Half-width Cr plane.
45
// Host-side view of the rANS distribution shader storage buffer (bound at
// SSBO binding 9 in main() and mapped there as this struct). The layout must
// match the buffer declaration in the shaders, which is not visible here —
// TODO confirm against encoder.shader / tally.shader / rans.shader.
struct RansDistSSBO {
	unsigned dist[4 * 256];  // Four tables of 256 entries.
	std::pair<unsigned, unsigned> ransdist[4 * 256];  // Four tables of 256 entries; main() writes out .first.
};

// The CPU reinterprets raw GPU memory as ransdist[]. On the GPU side each entry
// is presumably a tightly packed pair of 32-bit values (e.g. a std430 uvec2),
// so std::pair must not introduce any padding.
static_assert(sizeof(std::pair<unsigned, unsigned>) == 2 * sizeof(unsigned),
              "std::pair<unsigned, unsigned> must be tightly packed to mirror the SSBO layout");
// main() allocates 256 * 16 * sizeof(uint32_t) bytes for this buffer.
static_assert(sizeof(RansDistSSBO) <= 256 * 16 * sizeof(uint32_t),
              "RansDistSSBO does not fit in the distribution SSBO allocated in main()");

using namespace std;
using namespace std::chrono;
53
// Writes X to FP as an unsigned LEB128-style varint: seven payload bits per
// byte, least significant group first, with the high bit set on every byte
// except the last.
//
// X is encoded via its unsigned bit pattern. Callers pass unsigned counts
// (table entries, byte lengths). Values above INT_MAX therefore arrive here
// as negative ints on the usual two's-complement targets. Converting back
// makes them round-trip correctly. Previously the signed loop emitted a single
// truncated byte for any negative value.
void write_varint(int x, FILE *fp)
{
	unsigned v = static_cast<unsigned>(x);
	while (v >= 128) {
		putc(static_cast<int>((v & 0x7f) | 0x80), fp);
		v >>= 7;
	}
	putc(static_cast<int>(v), fp);
}
62
63 void readpix(unsigned char *ptr, const char *filename)
64 {
65         FILE *fp = fopen(filename, "rb");
66         if (fp == nullptr) {
67                 perror(filename);
68                 exit(1);
69         }
70
71         fseek(fp, 0, SEEK_END);
72         long len = ftell(fp);
73         assert(len >= WIDTH * HEIGHT * 3);
74         fseek(fp, len - WIDTH * HEIGHT * 3, SEEK_SET);
75
76         fread(ptr, 1, WIDTH * HEIGHT * 3, fp);
77         fclose(fp);
78 }
79
80 // Should be done on the GPU, of course, but irrelevant for the demonstration.
81 void convert_ycbcr()
82 {
83         double coeff[3] = { 0.2126, 0.7152, 0.0722 };  // sum = 1.0
84         double cb_fac = 1.0 / (coeff[0] + coeff[1] + 1.0f - coeff[2]);  // 0.539
85         double cr_fac = 1.0 / (1.0f - coeff[0] + coeff[1] + coeff[2]);  // 0.635 
86
87         unique_ptr<float[]> temp_cb(new float[WIDTH * HEIGHT]);
88         unique_ptr<float[]> temp_cr(new float[WIDTH * HEIGHT]);
89         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
90                 for (unsigned xb = 0; xb < WIDTH; ++xb) {
91                         int r = rgb[((yb * WIDTH) + xb) * 3 + 0];
92                         int g = rgb[((yb * WIDTH) + xb) * 3 + 1];
93                         int b = rgb[((yb * WIDTH) + xb) * 3 + 2];
94                         double y = std::min(std::max(coeff[0] * r + coeff[1] * g + coeff[2] * b, 0.0), 255.0);
95                         double cb = (b - y) * cb_fac + 128.0;
96                         double cr = (r - y) * cr_fac + 128.0;
97                         pix_y[(yb * WIDTH) + xb] = lrint(y);
98                         temp_cb[(yb * WIDTH) + xb] = cb;
99                         temp_cr[(yb * WIDTH) + xb] = cr;
100                 }
101         }
102
103         // Simple 4:2:2 subsampling with left convention.
104         for (unsigned yb = 0; yb < HEIGHT; ++yb) {
105                 for (unsigned xb = 0; xb < WIDTH / 2; ++xb) {
106                         int c0 = yb * WIDTH + std::max(int(xb) * 2 - 1, 0);
107                         int c1 = yb * WIDTH + xb * 2;
108                         int c2 = yb * WIDTH + xb * 2 + 1;
109                         
110                         double cb = 0.25 * temp_cb[c0] + 0.5 * temp_cb[c1] + 0.25 * temp_cb[c2];
111                         double cr = 0.25 * temp_cr[c0] + 0.5 * temp_cr[c1] + 0.25 * temp_cr[c2];
112                         cb = std::min(std::max(cb, 0.0), 255.0);
113                         cr = std::min(std::max(cr, 0.0), 255.0);
114                         pix_cb[(yb * WIDTH/2) + xb] = lrint(cb);
115                         pix_cr[(yb * WIDTH/2) + xb] = lrint(cr);
116                 }
117         }
118 }
119
120 int main(int argc, char **argv)
121 {
122         // Set up an OpenGL context using SDL.
123         if (SDL_Init(SDL_INIT_VIDEO) == -1) {
124                 fprintf(stderr, "SDL_Init failed: %s\n", SDL_GetError());
125                 exit(1);
126         }
127         SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 0);
128         SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 0);
129         SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
130         SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
131         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 4);
132         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 5);
133
134         SDL_Window *window = SDL_CreateWindow("OpenGL window for unit test",
135                 SDL_WINDOWPOS_UNDEFINED,
136                 SDL_WINDOWPOS_UNDEFINED,
137                 32, 32,
138                 SDL_WINDOW_OPENGL);
139         SDL_GLContext context = SDL_GL_CreateContext(window);
140         assert(context != nullptr);
141
142         if (argc >= 2)
143                 readpix(rgb, argv[1]);
144         else
145                 readpix(rgb, "color.pnm");
146         convert_ycbcr();
147
148         // Compile the DCT shader.
149         string shader_src = ::read_file("encoder.shader");
150         GLuint shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
151         GLuint glsl_program_num = glCreateProgram();
152         glAttachShader(glsl_program_num, shader_num);
153         glLinkProgram(glsl_program_num);
154
155         GLint success;
156         glGetProgramiv(glsl_program_num, GL_LINK_STATUS, &success);
157         if (success == GL_FALSE) {
158                 GLchar error_log[1024] = {0};
159                 glGetProgramInfoLog(glsl_program_num, 1024, nullptr, error_log);
160                 fprintf(stderr, "Error linking program: %s\n", error_log);
161                 exit(1);
162         }
163
164         // Compile the tally shader.
165         shader_src = ::read_file("tally.shader");
166         shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
167         GLuint glsl_tally_program_num = glCreateProgram();
168         glAttachShader(glsl_tally_program_num, shader_num);
169         glLinkProgram(glsl_tally_program_num);
170
171         glGetProgramiv(glsl_tally_program_num, GL_LINK_STATUS, &success);
172         if (success == GL_FALSE) {
173                 GLchar error_log[1024] = {0};
174                 glGetProgramInfoLog(glsl_tally_program_num, 1024, nullptr, error_log);
175                 fprintf(stderr, "Error linking program: %s\n", error_log);
176                 exit(1);
177         }
178
179         // Compile the rANS shader.
180         shader_src = ::read_file("rans.shader");
181         shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
182         GLuint glsl_rans_program_num = glCreateProgram();
183         glAttachShader(glsl_rans_program_num, shader_num);
184         glLinkProgram(glsl_rans_program_num);
185
186         glGetProgramiv(glsl_rans_program_num, GL_LINK_STATUS, &success);
187         if (success == GL_FALSE) {
188                 GLchar error_log[1024] = {0};
189                 glGetProgramInfoLog(glsl_rans_program_num, 1024, nullptr, error_log);
190                 fprintf(stderr, "Error linking program: %s\n", error_log);
191                 exit(1);
192         }
193         check_error();
194
195         // An SSBO for the rANS distributions.
196         GLuint ssbo;
197         glGenBuffers(1, &ssbo);
198         glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
199         glNamedBufferStorage(ssbo, 256 * 16 * sizeof(uint32_t), nullptr, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
200         check_error();
201
202         // SSBOs for the rANS output (data and offsets).
203         GLuint output_ssbo;
204         glGenBuffers(1, &output_ssbo);
205         glBindBuffer(GL_SHADER_STORAGE_BUFFER, output_ssbo);
206         glNamedBufferStorage(output_ssbo, 45 * 64 * 1024, nullptr, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
207         check_error();
208
209         GLuint output_offset_ssbo;
210         glGenBuffers(1, &output_offset_ssbo);
211         glBindBuffer(GL_SHADER_STORAGE_BUFFER, output_offset_ssbo);
212         glNamedBufferStorage(output_offset_ssbo, 45 * 64 * sizeof(uint32_t), nullptr, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
213         check_error();
214
215         // Bind SSBOs.
216         glUseProgram(glsl_program_num);
217         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 9, ssbo);
218
219         glUseProgram(glsl_tally_program_num);
220         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 9, ssbo);
221
222         glUseProgram(glsl_rans_program_num);
223         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 9, ssbo);
224         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 10, output_ssbo);
225         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 11, output_offset_ssbo);
226
227         glUseProgram(glsl_program_num);
228         check_error();
229
230         // Upload luma.
231         GLuint y_tex;
232         glGenTextures(1, &y_tex);
233         glBindTexture(GL_TEXTURE_2D, y_tex);
234         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
235         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
236         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
237         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
238         glTexImage2D(GL_TEXTURE_2D, 0, GL_R8UI, WIDTH, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_BYTE, pix_y);
239         check_error();
240
241         // Make destination textures.
242         GLuint dc_ac7_tex, ac1_ac6_tex, ac2_ac5_tex;
243         for (GLuint *tex : { &dc_ac7_tex, &ac1_ac6_tex, &ac2_ac5_tex }) {
244                 glGenTextures(1, tex);
245                 glBindTexture(GL_TEXTURE_2D, *tex);
246                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
247                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
248                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
249                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
250                 glTexImage2D(GL_TEXTURE_2D, 0, GL_R16UI, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_UNSIGNED_SHORT, nullptr);
251                 check_error();
252         }
253
254         GLuint ac3_tex, ac4_tex;
255         for (GLuint *tex : { &ac3_tex, &ac4_tex }) {
256                 glGenTextures(1, tex);
257                 glBindTexture(GL_TEXTURE_2D, *tex);
258                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
259                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
260                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
261                 glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
262                 glTexImage2D(GL_TEXTURE_2D, 0, GL_R8I, WIDTH / 8, HEIGHT, 0, GL_RED_INTEGER, GL_BYTE, nullptr);
263                 check_error();
264         }
265
266         glBindImageTexture(0, dc_ac7_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
267         glBindImageTexture(1, ac1_ac6_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
268         glBindImageTexture(2, ac2_ac5_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R16UI);
269         glBindImageTexture(3, ac3_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R8I);
270         glBindImageTexture(4, ac4_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R8I);
271         glBindImageTexture(5, y_tex, 0, GL_FALSE, 0, GL_READ_ONLY, GL_R8UI);
272         check_error();
273
274         // Bind uniforms.
275         glUseProgram(glsl_program_num);
276         GLint dc_ac7_tex_uniform = glGetUniformLocation(glsl_program_num, "dc_ac7_tex");
277         GLint ac1_ac6_tex_uniform = glGetUniformLocation(glsl_program_num, "ac1_ac6_tex");
278         GLint ac2_ac5_tex_uniform = glGetUniformLocation(glsl_program_num, "ac2_ac5_tex");
279         GLint ac3_tex_uniform = glGetUniformLocation(glsl_program_num, "ac3_tex");
280         GLint ac4_tex_uniform = glGetUniformLocation(glsl_program_num, "ac4_tex");
281         GLint image_tex_uniform = glGetUniformLocation(glsl_program_num, "image_tex");
282         glUniform1i(dc_ac7_tex_uniform, 0);
283         glUniform1i(ac1_ac6_tex_uniform, 1);
284         glUniform1i(ac2_ac5_tex_uniform, 2);
285         glUniform1i(ac3_tex_uniform, 3);
286         glUniform1i(ac4_tex_uniform, 4);
287         glUniform1i(image_tex_uniform, 5);
288
289         glUseProgram(glsl_rans_program_num);
290         dc_ac7_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "dc_ac7_tex");
291         ac1_ac6_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "ac1_ac6_tex");
292         ac2_ac5_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "ac2_ac5_tex");
293         ac3_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "ac3_tex");
294         ac4_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "ac4_tex");
295         image_tex_uniform = glGetUniformLocation(glsl_rans_program_num, "image_tex");
296         glUniform1i(dc_ac7_tex_uniform, 0);
297         glUniform1i(ac1_ac6_tex_uniform, 1);
298         glUniform1i(ac2_ac5_tex_uniform, 2);
299         glUniform1i(ac3_tex_uniform, 3);
300         glUniform1i(ac4_tex_uniform, 4);
301
302         steady_clock::time_point start = steady_clock::now();
303         unsigned num_iterations = 100;
304         for (unsigned i = 0; i < num_iterations; ++i) {
305                 glClearNamedBufferSubData(ssbo, GL_R8, 0, 256 * 16 * sizeof(uint32_t), GL_RED, GL_UNSIGNED_BYTE, nullptr);
306                 glUseProgram(glsl_program_num);
307                 glDispatchCompute(WIDTH_BLOCKS / 16, HEIGHT_BLOCKS, 1);
308                 glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
309
310                 glUseProgram(glsl_tally_program_num);
311                 glDispatchCompute(4, 1, 1);
312                 glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
313         
314                 glUseProgram(glsl_rans_program_num);
315                 glDispatchCompute(NUM_BLOCKS / BLOCKS_PER_STREAM, 8, 5);
316         }
317         check_error();
318         glFinish();
319         check_error();
320         steady_clock::time_point now = steady_clock::now();
321
322 #if 0
323         printf("%ld bytes + %ld escape bits (%ld) = %ld total bytes\n",
324                 tot_bytes - extra_bits / 8,
325                 extra_bits,
326                 extra_bits / 8,
327                 tot_bytes);
328
329         printf("\n");
330 #endif
331
332         printf("Each iteration took %.3f ms.\n", 1e3 * duration<double>(now - start).count() / num_iterations);
333
334         FILE *codedfp = fopen("coded.dat", "wb");
335         if (codedfp == nullptr) {
336                 perror("coded.dat");
337                 exit(1);
338         }
339
340         // Write out the distributions.
341         const RansDistSSBO *rans_dist = (const RansDistSSBO *)glMapNamedBufferRange(ssbo, 0, 256 * 16 * sizeof(uint32_t), GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
342         for (unsigned r = 0; r < 2; ++r) {  // Hack to write fake chroma tables.
343                 // TODO: rather gamma-k or something
344                 for (unsigned i = 0; i < 4; ++i) {
345                         printf("writing table %d\n", i);
346                         for (unsigned j = 0; j < NUM_SYMS; ++j) {
347                                 printf("%d,%d: %d\n", i, j, rans_dist->ransdist[i * 256 + j].first);
348                                 write_varint(rans_dist->ransdist[i * 256 + j].first, codedfp);
349                         }
350                 }
351         }
352
353         // Write out the actual data.
354         // TODO: Do the deduplication.
355
356         const uint32_t *offsets = (const uint32_t *)glMapNamedBufferRange(output_offset_ssbo, 0, 45 * 64 * sizeof(uint32_t), GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
357 #if 0
358         for (int i = 0; i < 45*64; ++i) {
359                 printf("%d,%d,%d: %u\n", i / 64, (i / 8) % 8, i % 8, 1024 * (i + 1) - offsets[i]);
360         }
361 #endif
362
363         const uint8_t *data = (const uint8_t *)glMapNamedBufferRange(output_ssbo, 0, 45 * 64 * 1024, GL_MAP_READ_BIT | GL_MAP_PERSISTENT_BIT);
364
365         for (unsigned y = 0; y < 8; ++y) {
366                 for (unsigned x = 0; x < 8; ++x) {
367                         for (unsigned int stream_idx = 0; stream_idx < 45; ++stream_idx) {
368                                 const uint8_t *out_end = data + (stream_idx * 64 + y * 8 + x + 1) * 1024;
369                                 const uint8_t *ptr = data + offsets[stream_idx * 64 + y * 8 + x];
370                                 uint32_t num_rans_bytes = out_end - ptr;
371 #if 0
372                                 if (num_rans_bytes == last_block.size() &&
373                                     memcmp(last_block.data(), ptr, last_block.size()) == 0) {
374                                         write_varint(0, codedfp);
375                                         clear();
376                                         return 1;
377                                 } else {
378                                         last_block = string((const char *)ptr, num_rans_bytes);
379                                 }
380 #endif
381
382                                 write_varint(num_rans_bytes, codedfp);
383                                 fwrite(ptr, 1, num_rans_bytes, codedfp);
384                         }
385                 }
386         }
387         fclose(codedfp);
388 }