git.sesse.net Git - narabu/blob - narabu.cpp
Add the GPU decoder itself.
[narabu] / narabu.cpp
1 #include <stdio.h>
2 #include <assert.h>
3 #include <SDL2/SDL.h>
4 #include <SDL2/SDL_error.h>
5 #include <SDL2/SDL_video.h>
6 #include <epoxy/gl.h>
7 #include <string>
8 #include <optional>
9 #include <algorithm>
10 #include <vector>
11 #include <memory>
12
13 #include "util.h"
14
15 using namespace std;
16
// Frame dimensions in pixels; main() sizes the output texture and the PGM
// dump from these.
#define WIDTH 1280
#define HEIGHT 720

// Each rANS decode table has prob_scale (= 2^prob_bits) cumulative-frequency
// slots, NUM_SYMS symbols, and there are NUM_TABLES such tables.
constexpr unsigned prob_bits = 12;
constexpr unsigned prob_scale = 1u << prob_bits;
constexpr unsigned NUM_SYMS = 256;
constexpr unsigned NUM_TABLES = 16;
24
25 struct RansDecSymbol {
26         unsigned sym_start;
27         unsigned sym_freq;
28 };
29 struct RansDecodeTable {
30         int cum2sym[prob_scale];
31         RansDecSymbol dsyms[NUM_SYMS];
32 };
33 RansDecodeTable decode_tables[NUM_TABLES];
34
// Decodes one base-128 varint (7 payload bits per byte, least significant
// group first, high bit set on every byte except the last) from [*ptr, end).
//
// *ptr is advanced past every byte that was consumed, also on failure.
//
// Returns the decoded value, or nullopt if
//  - the input ends before a terminating byte is seen,
//  - the encoding is longer than five bytes, or
//  - the encoded value does not fit in 32 bits.
std::optional<uint32_t> read_varint(const char **ptr, const char *end)
{
        uint32_t x = 0;
        for (unsigned shift = 0; shift < 32; shift += 7) {
                if (*ptr >= end) {
                        return std::nullopt;  // Error: EOF.
                }

                // Go through unsigned char so that a negative plain char cannot
                // sign-extend, and do all arithmetic in uint32_t so that the
                // shift below never overflows a signed int.
                const uint32_t ch = static_cast<unsigned char>(**ptr);
                ++(*ptr);

                const uint32_t payload = ch & 0x7f;
                // Only the fifth byte (shift == 28) can carry bits that would
                // land above bit 31; reject those instead of dropping them.
                if (shift > 32 - 7 && (payload >> (32 - shift)) != 0) {
                        return std::nullopt;  // Error: Value does not fit in 32 bits.
                }

                x |= payload << shift;
                if ((ch & 0x80) == 0) return x;
        }
        return std::nullopt;  // Error: Overlong int.
}
52
// Location of one coded coefficient stream inside the input file. main()
// fills these in and uploads the whole `streams` array verbatim as a shader
// storage buffer, so the member order and size are part of the data layout.
//
// Uses plain `unsigned` (like the other structs in this file) rather than the
// non-standard `uint` typedef, which only some C libraries provide.
struct CoeffStream {
        unsigned src_offset;   // byte offset of the rANS data within the file
        unsigned src_len;      // length of the rANS data in bytes
        unsigned sign_offset;  // byte offset of the sign data within the file
        unsigned sign_len;     // length field of the sign data, shifted right by 3
        unsigned extra_bits;   // low three bits of the sign length field
};
static_assert(sizeof(CoeffStream) == 5 * sizeof(unsigned),
              "CoeffStream must be five tightly packed unsigned ints");

// One entry per (coefficient, 16-row block) pair: main() indexes this as
// coeff_num * (HEIGHT / 16) + yb / 16 with coeff_num < 64, i.e. 45 * 64
// entries for HEIGHT == 720.
CoeffStream streams[45 * 64];  // HACK
57
58 int main(int argc, char **argv)
59 {
60         // Set up an OpenGL context using SDL.
61         if (SDL_Init(SDL_INIT_VIDEO) == -1) {
62                 fprintf(stderr, "SDL_Init failed: %s\n", SDL_GetError());
63                 exit(1);
64         }
65         SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 0);
66         SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 0);
67         SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
68         SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
69         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 4);
70         SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 5);
71
72         SDL_Window *window = SDL_CreateWindow("OpenGL window for unit test",
73                 SDL_WINDOWPOS_UNDEFINED,
74                 SDL_WINDOWPOS_UNDEFINED,
75                 32, 32,
76                 SDL_WINDOW_OPENGL);
77         SDL_GLContext context = SDL_GL_CreateContext(window);
78         assert(context != nullptr);
79
80         //char buf[16] = { 0 };
81
82         GLint size;
83         glGetIntegerv(GL_MAX_COMPUTE_SHARED_MEMORY_SIZE, &size);
84         printf("shared_memory_size=%u\n", size);
85
86         string shader_src = read_file("decoder-pre-sign.shader");
87         GLuint shader_num = compile_shader(shader_src, GL_COMPUTE_SHADER);
88         GLuint glsl_program_num = glCreateProgram();
89         glAttachShader(glsl_program_num, shader_num);
90         glLinkProgram(glsl_program_num);
91
92         GLint success;
93         glGetProgramiv(glsl_program_num, GL_LINK_STATUS, &success);
94         if (success == GL_FALSE) {
95                 GLchar error_log[1024] = {0};
96                 glGetProgramInfoLog(glsl_program_num, 1024, nullptr, error_log);
97                 fprintf(stderr, "Error linking program: %s\n", error_log);
98                 exit(1);
99         }
100
101         glUseProgram(glsl_program_num);
102
103         string coded = read_file(argc >= 2 ? argv[1] : "coded.dat");
104         const char *ptr = &coded[0];
105         const char *end = ptr + coded.size();
106
107 //      printf("first few bytes offs=%zu: %d %d %d %d %d %d %d %d\n", ptr - coded.data(),
108 //              (uint8_t)ptr[0], (uint8_t)ptr[1], (uint8_t)ptr[2], (uint8_t)ptr[3],
109 //              (uint8_t)ptr[4], (uint8_t)ptr[5], (uint8_t)ptr[6], (uint8_t)ptr[7]);
110
111         // read the rANS tables
112         for (unsigned table = 0; table < NUM_TABLES; ++table) {
113                 uint32_t cum_freq = 0;
114                 for (unsigned sym = 0; sym < NUM_SYMS; ++sym) {
115                         optional<uint32_t> freq = read_varint(&ptr, end);
116                         if (!freq) {
117                                 fprintf(stderr, "Error parsing varint for table %d symbol %d\n", table, sym);
118                                 exit(1);
119                         }
120
121                         decode_tables[table].dsyms[sym].sym_start = cum_freq;
122                         decode_tables[table].dsyms[sym].sym_freq = *freq;
123                         for (uint32_t i = 0; i < freq; ++i) {
124                                 decode_tables[table].cum2sym[cum_freq++] = sym;
125                         }
126                 }
127         }
128
129         // Make cum2sym texture.
130         unique_ptr<uint8_t[]> cum2sym_data(new uint8_t[prob_scale * NUM_TABLES]);
131         for (unsigned table = 0; table < NUM_TABLES; ++table) {
132                 for (unsigned i = 0; i < prob_scale; ++i) {
133                         cum2sym_data[prob_scale * table + i] = decode_tables[table].cum2sym[i];
134                 }
135         }
136         GLuint cum2sym_tex;
137         glGenTextures(1, &cum2sym_tex);
138         glBindTexture(GL_TEXTURE_2D, cum2sym_tex);
139         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
140         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
141         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
142         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
143         glTexImage2D(GL_TEXTURE_2D, 0, GL_R8UI, prob_scale, NUM_TABLES, 0, GL_RED_INTEGER, GL_UNSIGNED_BYTE, cum2sym_data.get());
144
145         // Make dsyms texture.
146         unique_ptr<pair<uint16_t, uint16_t>[]> dsyms_data(new pair<uint16_t, uint16_t>[NUM_SYMS * NUM_TABLES]);
147         for (unsigned table = 0; table < NUM_TABLES; ++table) {
148                 for (unsigned sym = 0; sym < NUM_SYMS; ++sym) {
149                         dsyms_data[NUM_SYMS * table + sym].first = decode_tables[table].dsyms[sym].sym_start;
150                         dsyms_data[NUM_SYMS * table + sym].second = decode_tables[table].dsyms[sym].sym_freq;
151                 }
152         }
153         GLuint dsyms_tex;
154         glGenTextures(1, &dsyms_tex);
155         glBindTexture(GL_TEXTURE_2D, dsyms_tex);
156         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
157         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
158         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
159         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
160         glTexImage2D(GL_TEXTURE_2D, 0, GL_RG16UI, NUM_SYMS, NUM_TABLES, 0, GL_RG_INTEGER, GL_UNSIGNED_SHORT, dsyms_data.get());
161
162         GLuint out_tex;
163         glGenTextures(1, &out_tex);
164         glBindTexture(GL_TEXTURE_2D, out_tex);
165         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
166         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
167         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
168         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
169         glTexImage2D(GL_TEXTURE_2D, 0, GL_R8, 1280, 720, 0, GL_RED, GL_UNSIGNED_BYTE, nullptr);
170         //glTexImage2D(GL_TEXTURE_2D, 0, GL_R32F, 1280, 720, 0, GL_RED, GL_FLOAT, nullptr);
171
172         //GLint src_offset_pos = glGetUniformLocation(glsl_program_num, "src_offset");
173         //GLint sign_offset_pos = glGetUniformLocation(glsl_program_num, "sign_offset");
174         //GLint extra_bits_pos = glGetUniformLocation(glsl_program_num, "extra_bits");
175         GLint cum2sym_tex_pos = glGetUniformLocation(glsl_program_num, "cum2sym_tex");
176         GLint dsyms_tex_pos = glGetUniformLocation(glsl_program_num, "dsyms_tex");
177         GLint out_tex_pos = glGetUniformLocation(glsl_program_num, "out_tex");
178         printf("%d err=0x%x pos=%d,%d,%d\n", __LINE__, glGetError(), cum2sym_tex_pos, dsyms_tex_pos, out_tex_pos);
179
180         // Bind the textures.
181         glUniform1i(cum2sym_tex_pos, 0);
182         glUniform1i(dsyms_tex_pos, 1);
183         glUniform1i(out_tex_pos, 2);
184         glBindImageTexture(0, cum2sym_tex, 0, GL_FALSE, 0, GL_READ_ONLY, GL_R8UI);
185         glBindImageTexture(1, dsyms_tex, 0, GL_FALSE, 0, GL_READ_ONLY, GL_RG16UI);
186         glBindImageTexture(2, out_tex, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R8);
187         printf("%d err=0x%x\n", __LINE__, glGetError());
188
189         // Decode all luma blocks.
190         unsigned num_blocks = (HEIGHT / 16);
191         for (unsigned y = 0; y < 8; ++y) {
192                 for (unsigned x = 0; x < 8; ++x) {
193                         unsigned coeff_num = y * 8 + x;
194
195                         for (unsigned yb = 0; yb < HEIGHT; yb += 16) {
196                                 optional<uint32_t> num_rans_bytes = read_varint(&ptr, end);
197                                 if (!num_rans_bytes) {
198                                         fprintf(stderr, "Error parsing varint for block %d rANS bytes\n", yb);
199                                         exit(1);
200                                 }
201
202                                 CoeffStream *stream = &streams[coeff_num * num_blocks + (yb/16)];
203                                 stream->src_offset = ptr - coded.data();
204                                 stream->src_len = *num_rans_bytes;
205
206                                 // TODO: check len
207                                 ptr += *num_rans_bytes;
208
209                                 optional<uint32_t> num_sign_bytes = read_varint(&ptr, end);
210                                 if (!num_sign_bytes) {
211                                         fprintf(stderr, "Error parsing varint for block %d rANS bytes\n", yb);
212                                         exit(1);
213                                 }
214
215                                 stream->sign_offset = ptr - coded.data();
216                                 stream->sign_len = *num_sign_bytes >> 3;
217                                 stream->extra_bits = *num_sign_bytes & 0x7;
218
219                                 // TODO: check len
220                                 // TODO: free bits
221                                 ptr += *num_sign_bytes >> 3;
222
223                                 //printf("read %d rANS bytes, %d sign bytes\n", *num_rans_bytes, *num_sign_bytes);
224                         }
225                 }
226         }
227
228         // put the coded data (as a whole) into an SSBO
229         printf("%d err=0x%x bufsize=%zu\n", __LINE__, glGetError(), coded.size());
230
231         GLuint ssbo_stream, ssbo, ssbo_out;
232
233         glGenBuffers(1, &ssbo_stream);
234         glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo_stream);
235         glBufferData(GL_SHADER_STORAGE_BUFFER, sizeof(streams), streams, GL_STREAM_DRAW);
236         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, ssbo_stream);
237         printf("%d err=0x%x bufsize=%zu\n", __LINE__, glGetError(), coded.size());
238
239         glGenBuffers(1, &ssbo);
240         glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
241         glBufferData(GL_SHADER_STORAGE_BUFFER, coded.size(), coded.data(), GL_STREAM_DRAW);
242         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 9, ssbo);
243         printf("%d err=0x%x bufsize=%zu\n", __LINE__, glGetError(), coded.size());
244
245         glGenBuffers(1, &ssbo_out);
246         glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo_out);
247         glBufferData(GL_SHADER_STORAGE_BUFFER, 16384, nullptr, GL_STREAM_DRAW);  // ??
248         glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 10, ssbo_out);
249
250         for (int i = 0; i < 10000; ++i)
251         glDispatchCompute(1, 45, 1);
252
253         unsigned *timing = (unsigned *)glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, 16384, GL_MAP_READ_BIT);
254         //setlocale(LC_ALL, "nb_NO.UTF-8");
255
256         string phases[] = {
257                 "init",
258                 "loop overhead",
259                 "rANS decode",
260                 "barrier after rANS decode",
261                 "horizontal IDCT",
262                 "barrier after horizontal IDCT",
263                 "vertical IDCT",
264                 "store to texture",
265                 "barrier after store to texture",
266                 "dummy timer for overhead measurement",
267         };
268         printf("\n");
269         for (int i = 0; i < 10; ++i) {
270                 //printf("%d: %'18.0f  [%s]\n", i, double((uint64_t(timing[i * 2 + 1]) << 32) | timing[i * 2]), phases[i].c_str());
271                 printf("%d,%s", i, phases[i].c_str());
272                 for (int j = 0; j < 64; ++j) {
273                         int idx = (j * 10 + i) * 2;
274                         uint64_t val = (uint64_t(timing[idx + 1]) << 32) | timing[idx];
275                 //      printf(" %'18.0f", double(val));
276                 //      printf(" %'6.0f", double(val) * 1e-6);
277                         printf(",%.0f", double(val) * 1e-6);
278                 }
279                 printf("\n");
280                 //printf("  [%s]\n", phases[i].c_str());
281         }
282         printf("\n");
283
284         unsigned char *data = new unsigned char[1280 * 720];
285         glGetTexImage(GL_TEXTURE_2D, 0, GL_RED, GL_UNSIGNED_BYTE, data);
286         printf("%d err=0x%x bufsize=%zu\n", __LINE__, glGetError(), coded.size());
287
288 #if 0
289         for (int k = 0; k < 4; ++k) {
290                 for (int y = 0; y < 8; ++y) {
291                         for (int x = 0; x < 8; ++x) {
292                                 printf("%3d ", data[y * 1280 + x + k*8]);
293                         }
294                         printf("\n");
295                 }
296                 printf("\n");
297         }
298         printf("\n");
299 #else
300         for (int k = 0; k < 4; ++k) {
301                 for (int y = 0; y < 8; ++y) {
302                         for (int x = 0; x < 8; ++x) {
303                                 //printf("%5.2f ", data[(y+8) * 1280 + x + (1272-k*8)]);
304                                 printf("%3d ", data[y * 1280 + x + k*8]);
305                         }
306                         printf("\n");
307                 }
308                 printf("\n");
309         }
310         printf("\n");
311 #endif
312
313         FILE *fp = fopen("narabu.pgm", "wb");
314         fprintf(fp, "P5\n1280 720\n255\n");
315         for (int y = 0; y < 720; ++y) {
316                 for (int x = 0; x < 1280; ++x) {
317                         int k = lrintf(data[y * 1280 + x]);
318                         if (k < 0) k = 0;
319                         if (k > 255) k = 255;
320                         putc(k, fp);
321                 }
322         }
323         fclose(fp);
324         
325         glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); // unbind
326         
327         printf("foo = 0x%x\n", glGetError());
328 }