]> git.sesse.net Git - narabu/blob - qdd.cpp
Make the encoder 100% GPU. Not working yet, though.
[narabu] / qdd.cpp
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <stdint.h>
4 #include <assert.h>
5 #include <memory>
6
7 #define WIDTH 1280
8 #define HEIGHT 720
9 #define WIDTH_BLOCKS (WIDTH/8)
10 #define WIDTH_BLOCKS_CHROMA (WIDTH/16)
11 #define HEIGHT_BLOCKS (HEIGHT/8)
12 #define NUM_BLOCKS (WIDTH_BLOCKS * HEIGHT_BLOCKS)
13 #define NUM_BLOCKS_CHROMA (WIDTH_BLOCKS_CHROMA * HEIGHT_BLOCKS)
14
15 #define NUM_SYMS 256
16 #define ESCAPE_LIMIT (NUM_SYMS - 1)
17 #define BLOCKS_PER_STREAM 320
18
19 #include "ryg_rans/rans_byte.h"
20
21 using namespace std;
22
23 void fdct_int32(short *const In);
24 void idct_int32(short *const In);
25
26 unsigned char pix[WIDTH * HEIGHT];
27 short coeff[WIDTH * HEIGHT];
28
29 struct RansDecoder {
30 };
31
32 static constexpr uint32_t prob_bits = 12;
33 static constexpr uint32_t prob_scale = 1 << prob_bits;
34
35 struct RansDecodeTable {
36         int cum2sym[prob_scale];
37         RansDecSymbol dsyms[NUM_SYMS];
38 };
39 #define NUM_TABLES 8
40 RansDecodeTable decode_tables[NUM_TABLES];
41
42 static const unsigned char std_luminance_quant_tbl[64] = {
43 #if 0
44         16,  11,  10,  16,  24,  40,  51,  61,
45         12,  12,  14,  19,  26,  58,  60,  55,
46         14,  13,  16,  24,  40,  57,  69,  56,
47         14,  17,  22,  29,  51,  87,  80,  62,
48         18,  22,  37,  56,  68, 109, 103,  77,
49         24,  35,  55,  64,  81, 104, 113,  92,
50         49,  64,  78,  87, 103, 121, 120, 101,
51         72,  92,  95,  98, 112, 100, 103,  99
52 #else
53         // ff_mpeg1_default_intra_matrix
54          8, 16, 19, 22, 26, 27, 29, 34,
55         16, 16, 22, 24, 27, 29, 34, 37,
56         19, 22, 26, 27, 29, 34, 34, 38,
57         22, 22, 26, 27, 29, 34, 37, 40,
58         22, 26, 27, 29, 32, 35, 40, 48,
59         26, 27, 29, 32, 35, 40, 48, 58,
60         26, 27, 29, 34, 38, 46, 56, 69,
61         27, 29, 35, 38, 46, 56, 69, 83
62 #endif
63 };
64
65
66 const int luma_mapping[64] = {
67         0, 0, 1, 1, 2, 2, 3, 3,
68         0, 0, 1, 2, 2, 2, 3, 3,
69         1, 1, 2, 2, 2, 3, 3, 3,
70         1, 1, 2, 2, 2, 3, 3, 3,
71         1, 2, 2, 2, 2, 3, 3, 3,
72         2, 2, 2, 2, 3, 3, 3, 3,
73         2, 2, 3, 3, 3, 3, 3, 3,
74         3, 3, 3, 3, 3, 3, 3, 3,
75 };
76 const int chroma_mapping[64] = {
77         0, 1, 1, 2, 2, 2, 3, 3,
78         1, 1, 2, 2, 2, 3, 3, 3,
79         2, 2, 2, 2, 3, 3, 3, 3,
80         2, 2, 2, 3, 3, 3, 3, 3,
81         2, 3, 3, 3, 3, 3, 3, 3,
82         3, 3, 3, 3, 3, 3, 3, 3,
83         3, 3, 3, 3, 3, 3, 3, 3,
84         3, 3, 3, 3, 3, 3, 3, 3,
85 };
86
87 int pick_stats_for(int x, int y, bool is_chroma)
88 {
89         if (is_chroma) {
90                 return chroma_mapping[y * 8 + x] + 4;
91         } else {
92                 return luma_mapping[y * 8 + x];
93         }
94 }
95
96 uint32_t read_varint(FILE *fp)
97 {
98         uint32_t x = 0;
99         int shift = 0;
100         for ( ;; ) {
101                 int ch = getc(fp);
102                 if (ch == -1) {
103                         fprintf(stderr, "Premature EOF\n");
104                         exit(1);
105                 }
106
107                 x |= (ch & 0x7f) << shift;
108                 if ((ch & 0x80) == 0) return x;
109                 shift += 7;
110         }
111 }
112
113 static constexpr int dc_scalefac = 8;  // Matches the FDCT's gain.
114 static constexpr double quant_scalefac = 4.0;  // whatever?
115
116 static inline int unquantize(int qf, int coeff_idx)
117 {
118         if (coeff_idx == 0) {
119                 return qf * dc_scalefac;
120         }
121         if (qf == 0) {
122                 return 0;
123         }
124
125         const int w = std_luminance_quant_tbl[coeff_idx];
126         const int s = quant_scalefac;
127         return (2 * qf * w * s) / 32;
128 }
129
130 int main(void)
131 {
132         FILE *fp = fopen("coded.dat", "rb");
133         if (fp == nullptr) {
134                 perror("coded.dat");
135                 exit(1);
136         }
137
138         uint32_t sign_bias[NUM_TABLES];
139         for (unsigned table = 0; table < NUM_TABLES; ++table) {
140                 uint32_t cum_freq = 0;
141                 for (unsigned sym = 0; sym < NUM_SYMS; ++sym) {
142                         uint32_t freq = read_varint(fp);
143                         fprintf(stderr, "sym=%u/%u: freq=%u\n", sym, NUM_SYMS, freq);
144                         RansDecSymbolInit(&decode_tables[table].dsyms[(sym + 1) & 255], cum_freq, freq);
145                         for (uint32_t i = 0; i < freq; ++i) {
146                                 if (cum_freq < prob_scale)
147                                         decode_tables[table].cum2sym[cum_freq] = (sym + 1) & 255;
148                                 ++cum_freq;
149                         }
150                 }
151                 sign_bias[table] = cum_freq;
152                 printf("sign_bias=%u (of %d)\n", sign_bias[table], prob_scale * 2);
153         }
154
155         // loop over all coefficients
156         for (unsigned y = 0; y < 8; ++y) {
157                 for (unsigned x = 0; x < 8; ++x) {
158                         unsigned tbl = pick_stats_for(x, y, false);
159                 
160                         RansState rans = 0;
161
162                         unique_ptr<uint8_t[]> rans_bytes;
163                         uint8_t *rans_ptr = nullptr;
164
165                         // loop over all DCT blocks
166                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
167                                 if (yb % 16 == 0) {
168                                         // read a block
169                                         uint32_t num_rans_bytes = read_varint(fp);
170                                         rans_bytes.reset(new uint8_t[num_rans_bytes]);
171                                         fread(rans_bytes.get(), 1, num_rans_bytes, fp);
172
173                                         printf("%d,%d: read %d rANS bytes\n", x, y, num_rans_bytes);
174                                         //printf("first bytes: %02x %02x %02x %02x %02x %02x %02x %02x\n", rans_bytes[0], rans_bytes[1], rans_bytes[2], rans_bytes[3], rans_bytes[4], rans_bytes[5], rans_bytes[6], rans_bytes[7]);
175
176
177                                         // init rANS decoder
178                                         rans_ptr = rans_bytes.get();
179                                         RansDecInit(&rans, &rans_ptr);
180                                 }
181                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
182                                         uint32_t bottom_bits = RansDecGet(&rans, prob_bits + 1);
183                                         uint32_t sign = 0;
184                                         if (bottom_bits >= sign_bias[tbl]) {
185                                                 bottom_bits -= sign_bias[tbl];
186                                                 rans -= sign_bias[tbl];
187                                                 sign = 1;
188                                         }
189                                         uint32_t k = decode_tables[tbl].cum2sym[std::min(bottom_bits, prob_scale - 1)];
190                                         RansDecAdvanceSymbol(&rans, &rans_ptr, &decode_tables[tbl].dsyms[k], prob_bits + 1);
191                                         assert(k <= ESCAPE_LIMIT);
192                                         if (k == ESCAPE_LIMIT) {
193                                                 k = RansDecGet(&rans, prob_bits);
194                                                 assert(k >= ESCAPE_LIMIT);
195                                                 RansDecAdvance(&rans, &rans_ptr, k, 1, prob_bits);
196                                         }
197                                         if (sign) {
198                                                 assert(k != 0);
199                                                 k = -k;
200                                         }
201
202                                         // reverse
203                                         int reversed_yb = yb ^ 8;
204                                         int reversed_xb = WIDTH - 8 - xb;
205                                         coeff[(reversed_yb + y) * WIDTH + (reversed_xb + x)] = k;
206                 //                      printf("coeff %d xb,yb=%d,%d: decoded %d\n", y * 8 + x, reversed_xb, reversed_yb, k);
207                                 }
208                         }
209
210 #if 0
211                         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
212                                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
213 #endif  
214                 }
215         }
216         fclose(fp);
217
218         // DC coefficient pred from the right to left (within each slice)
219         for (unsigned block_idx = 0; block_idx < NUM_BLOCKS; block_idx += BLOCKS_PER_STREAM) {
220                 int prev_k = 128;
221
222                 for (unsigned subblock_idx = BLOCKS_PER_STREAM; subblock_idx --> 0; ) {
223                         unsigned yb = (block_idx + subblock_idx) / WIDTH_BLOCKS;
224                         unsigned xb = (block_idx + subblock_idx) % WIDTH_BLOCKS;
225                         int k = coeff[(yb * 8) * WIDTH + (xb * 8)];
226
227                         prev_k += k;
228                         coeff[(yb * 8) * WIDTH + (xb * 8)] = prev_k;
229                 }
230         }
231
232         // IDCT
233         for (unsigned yb = 0; yb < HEIGHT; yb += 8) {
234                 for (unsigned xb = 0; xb < WIDTH; xb += 8) {
235                         // Read one block
236                         short in[64];
237                         for (unsigned y = 0; y < 8; ++y) {
238                                 for (unsigned x = 0; x < 8; ++x) {
239                                         int k = coeff[(yb + y) * WIDTH + (xb + x)];
240                                         in[y * 8 + x] = unquantize(k, y * 8 + x);
241                                         printf("%3d ", in[y * 8 + x]);
242                                 }
243                                 printf("\n");
244                         }
245                         printf("\n");
246
247                         idct_int32(in);
248
249                         // Clamp and move back
250                         for (unsigned y = 0; y < 8; ++y) {
251                                 for (unsigned x = 0; x < 8; ++x) {
252                                         int k = in[y * 8 + x];
253                                         printf("%3d ", k);
254                                         if (k < 0) k = 0;
255                                         if (k > 255) k = 255;
256                                         pix[(yb + y) * WIDTH + (xb + x)] = k;
257                                 }
258                                 printf("\n");
259                         }
260                         printf("\n");
261                 }
262         }
263
264         fp = fopen("output.pgm", "wb");
265         fprintf(fp, "P5\n%d %d\n255\n", WIDTH, HEIGHT);
266         fwrite(pix, 1, WIDTH * HEIGHT, fp);
267         fclose(fp);
268 }