]> git.sesse.net Git - ffmpeg/blob - libavcodec/cbs_av1.c
vaapi_encode: Support configurable slices
[ffmpeg] / libavcodec / cbs_av1.c
1 /*
2  * This file is part of FFmpeg.
3  *
4  * FFmpeg is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * FFmpeg is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with FFmpeg; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17  */
18
19 #include "libavutil/avassert.h"
20 #include "libavutil/pixfmt.h"
21
22 #include "cbs.h"
23 #include "cbs_internal.h"
24 #include "cbs_av1.h"
25 #include "internal.h"
26
27
28 static int cbs_av1_read_uvlc(CodedBitstreamContext *ctx, GetBitContext *gbc,
29                              const char *name, uint32_t *write_to,
30                              uint32_t range_min, uint32_t range_max)
31 {
32     uint32_t value;
33     int position, zeroes, i, j;
34     char bits[65];
35
36     if (ctx->trace_enable)
37         position = get_bits_count(gbc);
38
39     zeroes = i = 0;
40     while (1) {
41         if (get_bits_left(gbc) < zeroes + 1) {
42             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid uvlc code at "
43                    "%s: bitstream ended.\n", name);
44             return AVERROR_INVALIDDATA;
45         }
46
47         if (get_bits1(gbc)) {
48             bits[i++] = '1';
49             break;
50         } else {
51             bits[i++] = '0';
52             ++zeroes;
53         }
54     }
55
56     if (zeroes >= 32) {
57         value = MAX_UINT_BITS(32);
58     } else {
59         value = get_bits_long(gbc, zeroes);
60
61         for (j = 0; j < zeroes; j++)
62             bits[i++] = (value >> (zeroes - j - 1) & 1) ? '1' : '0';
63
64         value += (1 << zeroes) - 1;
65     }
66
67     if (ctx->trace_enable) {
68         bits[i] = 0;
69         ff_cbs_trace_syntax_element(ctx, position, name, NULL,
70                                     bits, value);
71     }
72
73     if (value < range_min || value > range_max) {
74         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
75                "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
76                name, value, range_min, range_max);
77         return AVERROR_INVALIDDATA;
78     }
79
80     *write_to = value;
81     return 0;
82 }
83
84 static int cbs_av1_write_uvlc(CodedBitstreamContext *ctx, PutBitContext *pbc,
85                               const char *name, uint32_t value,
86                               uint32_t range_min, uint32_t range_max)
87 {
88     uint32_t v;
89     int position, zeroes;
90
91     if (value < range_min || value > range_max) {
92         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
93                "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
94                name, value, range_min, range_max);
95         return AVERROR_INVALIDDATA;
96     }
97
98     if (ctx->trace_enable)
99         position = put_bits_count(pbc);
100
101     if (value == 0) {
102         zeroes = 0;
103         put_bits(pbc, 1, 1);
104     } else {
105         zeroes = av_log2(value + 1);
106         v = value - (1 << zeroes) + 1;
107         put_bits(pbc, zeroes + 1, 1);
108         put_bits(pbc, zeroes, v);
109     }
110
111     if (ctx->trace_enable) {
112         char bits[65];
113         int i, j;
114         i = 0;
115         for (j = 0; j < zeroes; j++)
116             bits[i++] = '0';
117         bits[i++] = '1';
118         for (j = 0; j < zeroes; j++)
119             bits[i++] = (v >> (zeroes - j - 1) & 1) ? '1' : '0';
120         bits[i++] = 0;
121         ff_cbs_trace_syntax_element(ctx, position, name, NULL,
122                                     bits, value);
123     }
124
125     return 0;
126 }
127
128 static int cbs_av1_read_leb128(CodedBitstreamContext *ctx, GetBitContext *gbc,
129                                const char *name, uint64_t *write_to)
130 {
131     uint64_t value;
132     int position, err, i;
133
134     if (ctx->trace_enable)
135         position = get_bits_count(gbc);
136
137     value = 0;
138     for (i = 0; i < 8; i++) {
139         int subscript[2] = { 1, i };
140         uint32_t byte;
141         err = ff_cbs_read_unsigned(ctx, gbc, 8, "leb128_byte[i]", subscript,
142                                    &byte, 0x00, 0xff);
143         if (err < 0)
144             return err;
145
146         value |= (uint64_t)(byte & 0x7f) << (i * 7);
147         if (!(byte & 0x80))
148             break;
149     }
150
151     if (ctx->trace_enable)
152         ff_cbs_trace_syntax_element(ctx, position, name, NULL, "", value);
153
154     *write_to = value;
155     return 0;
156 }
157
158 static int cbs_av1_write_leb128(CodedBitstreamContext *ctx, PutBitContext *pbc,
159                                 const char *name, uint64_t value)
160 {
161     int position, err, len, i;
162     uint8_t byte;
163
164     len = (av_log2(value) + 7) / 7;
165
166     if (ctx->trace_enable)
167         position = put_bits_count(pbc);
168
169     for (i = 0; i < len; i++) {
170         int subscript[2] = { 1, i };
171
172         byte = value >> (7 * i) & 0x7f;
173         if (i < len - 1)
174             byte |= 0x80;
175
176         err = ff_cbs_write_unsigned(ctx, pbc, 8, "leb128_byte[i]", subscript,
177                                     byte, 0x00, 0xff);
178         if (err < 0)
179             return err;
180     }
181
182     if (ctx->trace_enable)
183         ff_cbs_trace_syntax_element(ctx, position, name, NULL, "", value);
184
185     return 0;
186 }
187
188 static int cbs_av1_read_su(CodedBitstreamContext *ctx, GetBitContext *gbc,
189                            int width, const char *name,
190                            const int *subscripts, int32_t *write_to)
191 {
192     uint32_t magnitude;
193     int position, sign;
194     int32_t value;
195
196     if (ctx->trace_enable)
197         position = get_bits_count(gbc);
198
199     if (get_bits_left(gbc) < width + 1) {
200         av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid signed value at "
201                "%s: bitstream ended.\n", name);
202         return AVERROR_INVALIDDATA;
203     }
204
205     magnitude = get_bits(gbc, width);
206     sign      = get_bits1(gbc);
207     value     = sign ? -(int32_t)magnitude : magnitude;
208
209     if (ctx->trace_enable) {
210         char bits[33];
211         int i;
212         for (i = 0; i < width; i++)
213             bits[i] = magnitude >> (width - i - 1) & 1 ? '1' : '0';
214         bits[i] = sign ? '1' : '0';
215         bits[i + 1] = 0;
216
217         ff_cbs_trace_syntax_element(ctx, position,
218                                     name, subscripts, bits, value);
219     }
220
221     *write_to = value;
222     return 0;
223 }
224
225 static int cbs_av1_write_su(CodedBitstreamContext *ctx, PutBitContext *pbc,
226                             int width, const char *name,
227                             const int *subscripts, int32_t value)
228 {
229     uint32_t magnitude;
230     int sign;
231
232     if (put_bits_left(pbc) < width + 1)
233         return AVERROR(ENOSPC);
234
235     sign      = value < 0;
236     magnitude = sign ? -value : value;
237
238     if (ctx->trace_enable) {
239         char bits[33];
240         int i;
241         for (i = 0; i < width; i++)
242             bits[i] = magnitude >> (width - i - 1) & 1 ? '1' : '0';
243         bits[i] = sign ? '1' : '0';
244         bits[i + 1] = 0;
245
246         ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
247                                     name, subscripts, bits, value);
248     }
249
250     put_bits(pbc, width, magnitude);
251     put_bits(pbc, 1, sign);
252
253     return 0;
254 }
255
256 static int cbs_av1_read_ns(CodedBitstreamContext *ctx, GetBitContext *gbc,
257                            uint32_t n, const char *name,
258                            const int *subscripts, uint32_t *write_to)
259 {
260     uint32_t w, m, v, extra_bit, value;
261     int position;
262
263     av_assert0(n > 0);
264
265     if (ctx->trace_enable)
266         position = get_bits_count(gbc);
267
268     w = av_log2(n) + 1;
269     m = (1 << w) - n;
270
271     if (get_bits_left(gbc) < w) {
272         av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid non-symmetric value at "
273                "%s: bitstream ended.\n", name);
274         return AVERROR_INVALIDDATA;
275     }
276
277     if (w - 1 > 0)
278         v = get_bits(gbc, w - 1);
279     else
280         v = 0;
281
282     if (v < m) {
283         value = v;
284     } else {
285         extra_bit = get_bits1(gbc);
286         value = (v << 1) - m + extra_bit;
287     }
288
289     if (ctx->trace_enable) {
290         char bits[33];
291         int i;
292         for (i = 0; i < w - 1; i++)
293             bits[i] = (v >> i & 1) ? '1' : '0';
294         if (v >= m)
295             bits[i++] = extra_bit ? '1' : '0';
296         bits[i] = 0;
297
298         ff_cbs_trace_syntax_element(ctx, position,
299                                     name, subscripts, bits, value);
300     }
301
302     *write_to = value;
303     return 0;
304 }
305
306 static int cbs_av1_write_ns(CodedBitstreamContext *ctx, PutBitContext *pbc,
307                             uint32_t n, const char *name,
308                             const int *subscripts, uint32_t value)
309 {
310     uint32_t w, m, v, extra_bit;
311     int position;
312
313     if (value > n) {
314         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
315                "%"PRIu32", but must be in [0,%"PRIu32"].\n",
316                name, value, n);
317         return AVERROR_INVALIDDATA;
318     }
319
320     if (ctx->trace_enable)
321         position = put_bits_count(pbc);
322
323     w = av_log2(n) + 1;
324     m = (1 << w) - n;
325
326     if (put_bits_left(pbc) < w)
327         return AVERROR(ENOSPC);
328
329     if (value < m) {
330         v = value;
331         put_bits(pbc, w - 1, v);
332     } else {
333         v = m + ((value - m) >> 1);
334         extra_bit = (value - m) & 1;
335         put_bits(pbc, w - 1, v);
336         put_bits(pbc, 1, extra_bit);
337     }
338
339     if (ctx->trace_enable) {
340         char bits[33];
341         int i;
342         for (i = 0; i < w - 1; i++)
343             bits[i] = (v >> i & 1) ? '1' : '0';
344         if (value >= m)
345             bits[i++] = extra_bit ? '1' : '0';
346         bits[i] = 0;
347
348         ff_cbs_trace_syntax_element(ctx, position,
349                                     name, subscripts, bits, value);
350     }
351
352     return 0;
353 }
354
355 static int cbs_av1_read_increment(CodedBitstreamContext *ctx, GetBitContext *gbc,
356                                   uint32_t range_min, uint32_t range_max,
357                                   const char *name, uint32_t *write_to)
358 {
359     uint32_t value;
360     int position, i;
361     char bits[33];
362
363     av_assert0(range_min <= range_max && range_max - range_min < sizeof(bits) - 1);
364     if (ctx->trace_enable)
365         position = get_bits_count(gbc);
366
367     for (i = 0, value = range_min; value < range_max;) {
368         if (get_bits_left(gbc) < 1) {
369             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid increment value at "
370                    "%s: bitstream ended.\n", name);
371             return AVERROR_INVALIDDATA;
372         }
373         if (get_bits1(gbc)) {
374             bits[i++] = '1';
375             ++value;
376         } else {
377             bits[i++] = '0';
378             break;
379         }
380     }
381
382     if (ctx->trace_enable) {
383         bits[i] = 0;
384         ff_cbs_trace_syntax_element(ctx, position,
385                                     name, NULL, bits, value);
386     }
387
388     *write_to = value;
389     return 0;
390 }
391
392 static int cbs_av1_write_increment(CodedBitstreamContext *ctx, PutBitContext *pbc,
393                                    uint32_t range_min, uint32_t range_max,
394                                    const char *name, uint32_t value)
395 {
396     int len;
397
398     av_assert0(range_min <= range_max && range_max - range_min < 32);
399     if (value < range_min || value > range_max) {
400         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
401                "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
402                name, value, range_min, range_max);
403         return AVERROR_INVALIDDATA;
404     }
405
406     if (value == range_max)
407         len = range_max - range_min;
408     else
409         len = value - range_min + 1;
410     if (put_bits_left(pbc) < len)
411         return AVERROR(ENOSPC);
412
413     if (ctx->trace_enable) {
414         char bits[33];
415         int i;
416         for (i = 0; i < len; i++) {
417             if (range_min + i == value)
418                 bits[i] = '0';
419             else
420                 bits[i] = '1';
421         }
422         bits[i] = 0;
423         ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
424                                     name, NULL, bits, value);
425     }
426
427     if (len > 0)
428         put_bits(pbc, len, (1 << len) - 1 - (value != range_max));
429
430     return 0;
431 }
432
433 static int cbs_av1_read_subexp(CodedBitstreamContext *ctx, GetBitContext *gbc,
434                                uint32_t range_max, const char *name,
435                                const int *subscripts, uint32_t *write_to)
436 {
437     uint32_t value;
438     int position, err;
439     uint32_t max_len, len, range_offset, range_bits;
440
441     if (ctx->trace_enable)
442         position = get_bits_count(gbc);
443
444     av_assert0(range_max > 0);
445     max_len = av_log2(range_max - 1) - 3;
446
447     err = cbs_av1_read_increment(ctx, gbc, 0, max_len,
448                                  "subexp_more_bits", &len);
449     if (err < 0)
450         return err;
451
452     if (len) {
453         range_bits   = 2 + len;
454         range_offset = 1 << range_bits;
455     } else {
456         range_bits   = 3;
457         range_offset = 0;
458     }
459
460     if (len < max_len) {
461         err = ff_cbs_read_unsigned(ctx, gbc, range_bits,
462                                    "subexp_bits", NULL, &value,
463                                    0, MAX_UINT_BITS(range_bits));
464         if (err < 0)
465             return err;
466
467     } else {
468         err = cbs_av1_read_ns(ctx, gbc, range_max - range_offset,
469                               "subexp_final_bits", NULL, &value);
470         if (err < 0)
471             return err;
472     }
473     value += range_offset;
474
475     if (ctx->trace_enable)
476         ff_cbs_trace_syntax_element(ctx, position,
477                                     name, subscripts, "", value);
478
479     *write_to = value;
480     return err;
481 }
482
483 static int cbs_av1_write_subexp(CodedBitstreamContext *ctx, PutBitContext *pbc,
484                                 uint32_t range_max, const char *name,
485                                 const int *subscripts, uint32_t value)
486 {
487     int position, err;
488     uint32_t max_len, len, range_offset, range_bits;
489
490     if (value > range_max) {
491         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
492                "%"PRIu32", but must be in [0,%"PRIu32"].\n",
493                name, value, range_max);
494         return AVERROR_INVALIDDATA;
495     }
496
497     if (ctx->trace_enable)
498         position = put_bits_count(pbc);
499
500     av_assert0(range_max > 0);
501     max_len = av_log2(range_max - 1) - 3;
502
503     if (value < 8) {
504         range_bits   = 3;
505         range_offset = 0;
506         len = 0;
507     } else {
508         range_bits = av_log2(value);
509         len = range_bits - 2;
510         if (len > max_len) {
511             // The top bin is combined with the one below it.
512             av_assert0(len == max_len + 1);
513             --range_bits;
514             len = max_len;
515         }
516         range_offset = 1 << range_bits;
517     }
518
519     err = cbs_av1_write_increment(ctx, pbc, 0, max_len,
520                                   "subexp_more_bits", len);
521     if (err < 0)
522         return err;
523
524     if (len < max_len) {
525         err = ff_cbs_write_unsigned(ctx, pbc, range_bits,
526                                     "subexp_bits", NULL,
527                                     value - range_offset,
528                                     0, MAX_UINT_BITS(range_bits));
529         if (err < 0)
530             return err;
531
532     } else {
533         err = cbs_av1_write_ns(ctx, pbc, range_max - range_offset,
534                                "subexp_final_bits", NULL,
535                                value - range_offset);
536         if (err < 0)
537             return err;
538     }
539
540     if (ctx->trace_enable)
541         ff_cbs_trace_syntax_element(ctx, position,
542                                     name, subscripts, "", value);
543
544     return err;
545 }
546
547
548 static int cbs_av1_tile_log2(int blksize, int target)
549 {
550     int k;
551     for (k = 0; (blksize << k) < target; k++);
552     return k;
553 }
554
555 static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
556                                      unsigned int a, unsigned int b)
557 {
558     unsigned int diff, m;
559     if (!seq->enable_order_hint)
560         return 0;
561     diff = a - b;
562     m = 1 << seq->order_hint_bits_minus_1;
563     diff = (diff & (m - 1)) - (diff & m);
564     return diff;
565 }
566
567
568 #define HEADER(name) do { \
569         ff_cbs_trace_header(ctx, name); \
570     } while (0)
571
572 #define CHECK(call) do { \
573         err = (call); \
574         if (err < 0) \
575             return err; \
576     } while (0)
577
578 #define FUNC_NAME(rw, codec, name) cbs_ ## codec ## _ ## rw ## _ ## name
579 #define FUNC_AV1(rw, name) FUNC_NAME(rw, av1, name)
580 #define FUNC(name) FUNC_AV1(READWRITE, name)
581
582 #define SUBSCRIPTS(subs, ...) (subs > 0 ? ((int[subs + 1]){ subs, __VA_ARGS__ }) : NULL)
583
584 #define fb(width, name) \
585         xf(width, name, current->name, 0, MAX_UINT_BITS(width), 0)
586 #define fc(width, name, range_min, range_max) \
587         xf(width, name, current->name, range_min, range_max, 0)
588 #define flag(name) fb(1, name)
589 #define su(width, name) \
590         xsu(width, name, current->name, 0)
591
592 #define fbs(width, name, subs, ...) \
593         xf(width, name, current->name, 0, MAX_UINT_BITS(width), subs, __VA_ARGS__)
594 #define fcs(width, name, range_min, range_max, subs, ...) \
595         xf(width, name, current->name, range_min, range_max, subs, __VA_ARGS__)
596 #define flags(name, subs, ...) \
597         xf(1, name, current->name, 0, 1, subs, __VA_ARGS__)
598 #define sus(width, name, subs, ...) \
599         xsu(width, name, current->name, subs, __VA_ARGS__)
600
601 #define fixed(width, name, value) do { \
602         av_unused uint32_t fixed_value = value; \
603         xf(width, name, fixed_value, value, value, 0); \
604     } while (0)
605
606
607 #define READ
608 #define READWRITE read
609 #define RWContext GetBitContext
610
611 #define xf(width, name, var, range_min, range_max, subs, ...) do { \
612         uint32_t value = range_min; \
613         CHECK(ff_cbs_read_unsigned(ctx, rw, width, #name, \
614                                    SUBSCRIPTS(subs, __VA_ARGS__), \
615                                    &value, range_min, range_max)); \
616         var = value; \
617     } while (0)
618
619 #define xsu(width, name, var, subs, ...) do { \
620         int32_t value = 0; \
621         CHECK(cbs_av1_read_su(ctx, rw, width, #name, \
622                               SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
623         var = value; \
624     } while (0)
625
626 #define uvlc(name, range_min, range_max) do { \
627         uint32_t value = range_min; \
628         CHECK(cbs_av1_read_uvlc(ctx, rw, #name, \
629                                 &value, range_min, range_max)); \
630         current->name = value; \
631     } while (0)
632
633 #define ns(max_value, name, subs, ...) do { \
634         uint32_t value = 0; \
635         CHECK(cbs_av1_read_ns(ctx, rw, max_value, #name, \
636                               SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
637         current->name = value; \
638     } while (0)
639
640 #define increment(name, min, max) do { \
641         uint32_t value = 0; \
642         CHECK(cbs_av1_read_increment(ctx, rw, min, max, #name, &value)); \
643         current->name = value; \
644     } while (0)
645
646 #define subexp(name, max, subs, ...) do { \
647         uint32_t value = 0; \
648         CHECK(cbs_av1_read_subexp(ctx, rw, max, #name, \
649                                   SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
650         current->name = value; \
651     } while (0)
652
653 #define delta_q(name) do { \
654         uint8_t delta_coded; \
655         int8_t delta_q; \
656         xf(1, name.delta_coded, delta_coded, 0, 1, 0); \
657         if (delta_coded) \
658             xsu(1 + 6, name.delta_q, delta_q, 0); \
659         else \
660             delta_q = 0; \
661         current->name = delta_q; \
662     } while (0)
663
664 #define leb128(name) do { \
665         uint64_t value = 0; \
666         CHECK(cbs_av1_read_leb128(ctx, rw, #name, &value)); \
667         current->name = value; \
668     } while (0)
669
670 #define infer(name, value) do { \
671         current->name = value; \
672     } while (0)
673
674 #define byte_alignment(rw) (get_bits_count(rw) % 8)
675
676 #include "cbs_av1_syntax_template.c"
677
678 #undef READ
679 #undef READWRITE
680 #undef RWContext
681 #undef xf
682 #undef xsu
683 #undef uvlc
684 #undef leb128
685 #undef ns
686 #undef increment
687 #undef subexp
688 #undef delta_q
689 #undef leb128
690 #undef infer
691 #undef byte_alignment
692
693
694 #define WRITE
695 #define READWRITE write
696 #define RWContext PutBitContext
697
698 #define xf(width, name, var, range_min, range_max, subs, ...) do { \
699         CHECK(ff_cbs_write_unsigned(ctx, rw, width, #name, \
700                                     SUBSCRIPTS(subs, __VA_ARGS__), \
701                                     var, range_min, range_max)); \
702     } while (0)
703
704 #define xsu(width, name, var, subs, ...) do { \
705         CHECK(cbs_av1_write_su(ctx, rw, width, #name, \
706                                SUBSCRIPTS(subs, __VA_ARGS__), var)); \
707     } while (0)
708
709 #define uvlc(name, range_min, range_max) do { \
710         CHECK(cbs_av1_write_uvlc(ctx, rw, #name, current->name, \
711                                  range_min, range_max)); \
712     } while (0)
713
714 #define ns(max_value, name, subs, ...) do { \
715         CHECK(cbs_av1_write_ns(ctx, rw, max_value, #name, \
716                                SUBSCRIPTS(subs, __VA_ARGS__), \
717                                current->name)); \
718     } while (0)
719
720 #define increment(name, min, max) do { \
721         CHECK(cbs_av1_write_increment(ctx, rw, min, max, #name, \
722                                       current->name)); \
723     } while (0)
724
725 #define subexp(name, max, subs, ...) do { \
726         CHECK(cbs_av1_write_subexp(ctx, rw, max, #name, \
727                                    SUBSCRIPTS(subs, __VA_ARGS__), \
728                                    current->name)); \
729     } while (0)
730
731 #define delta_q(name) do { \
732         xf(1, name.delta_coded, current->name != 0, 0, 1, 0); \
733         if (current->name) \
734             xsu(1 + 6, name.delta_q, current->name, 0); \
735     } while (0)
736
737 #define leb128(name) do { \
738         CHECK(cbs_av1_write_leb128(ctx, rw, #name, current->name)); \
739     } while (0)
740
741 #define infer(name, value) do { \
742         if (current->name != (value)) { \
743             av_log(ctx->log_ctx, AV_LOG_WARNING, "Warning: " \
744                    "%s does not match inferred value: " \
745                    "%"PRId64", but should be %"PRId64".\n", \
746                    #name, (int64_t)current->name, (int64_t)(value)); \
747         } \
748     } while (0)
749
750 #define byte_alignment(rw) (put_bits_count(rw) % 8)
751
752 #include "cbs_av1_syntax_template.c"
753
754 #undef READ
755 #undef READWRITE
756 #undef RWContext
757 #undef xf
758 #undef xsu
759 #undef uvlc
760 #undef leb128
761 #undef ns
762 #undef increment
763 #undef subexp
764 #undef delta_q
765 #undef infer
766 #undef byte_alignment
767
768
769 static int cbs_av1_split_fragment(CodedBitstreamContext *ctx,
770                                   CodedBitstreamFragment *frag,
771                                   int header)
772 {
773     GetBitContext gbc;
774     uint8_t *data;
775     size_t size;
776     uint64_t obu_length;
777     int pos, err, trace;
778
779     // Don't include this parsing in trace output.
780     trace = ctx->trace_enable;
781     ctx->trace_enable = 0;
782
783     data = frag->data;
784     size = frag->data_size;
785
786     if (INT_MAX / 8 < size) {
787         av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid fragment: "
788                "too large (%zu bytes).\n", size);
789         err = AVERROR_INVALIDDATA;
790         goto fail;
791     }
792
793     while (size > 0) {
794         AV1RawOBUHeader header;
795         uint64_t obu_size;
796
797         init_get_bits(&gbc, data, 8 * size);
798
799         err = cbs_av1_read_obu_header(ctx, &gbc, &header);
800         if (err < 0)
801             goto fail;
802
803         if (!header.obu_has_size_field) {
804             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU for raw "
805                    "stream: size field must be present.\n");
806             err = AVERROR_INVALIDDATA;
807             goto fail;
808         }
809
810         if (get_bits_left(&gbc) < 8) {
811             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU: fragment "
812                    "too short (%zu bytes).\n", size);
813             err = AVERROR_INVALIDDATA;
814             goto fail;
815         }
816
817         err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
818         if (err < 0)
819             goto fail;
820
821         pos = get_bits_count(&gbc);
822         av_assert0(pos % 8 == 0 && pos / 8 <= size);
823
824         obu_length = pos / 8 + obu_size;
825
826         if (size < obu_length) {
827             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
828                    "%"PRIu64", but only %zu bytes remaining in fragment.\n",
829                    obu_length, size);
830             err = AVERROR_INVALIDDATA;
831             goto fail;
832         }
833
834         err = ff_cbs_insert_unit_data(ctx, frag, -1, header.obu_type,
835                                       data, obu_length, frag->data_ref);
836         if (err < 0)
837             goto fail;
838
839         data += obu_length;
840         size -= obu_length;
841     }
842
843     err = 0;
844 fail:
845     ctx->trace_enable = trace;
846     return err;
847 }
848
849 static void cbs_av1_free_tile_data(AV1RawTileData *td)
850 {
851     av_buffer_unref(&td->data_ref);
852 }
853
854 static void cbs_av1_free_metadata(AV1RawMetadata *md)
855 {
856     switch (md->metadata_type) {
857     case AV1_METADATA_TYPE_ITUT_T35:
858         av_buffer_unref(&md->metadata.itut_t35.payload_ref);
859         break;
860     }
861 }
862
863 static void cbs_av1_free_obu(void *unit, uint8_t *content)
864 {
865     AV1RawOBU *obu = (AV1RawOBU*)content;
866
867     switch (obu->header.obu_type) {
868     case AV1_OBU_TILE_GROUP:
869         cbs_av1_free_tile_data(&obu->obu.tile_group.tile_data);
870         break;
871     case AV1_OBU_FRAME:
872         cbs_av1_free_tile_data(&obu->obu.frame.tile_group.tile_data);
873         break;
874     case AV1_OBU_TILE_LIST:
875         cbs_av1_free_tile_data(&obu->obu.tile_list.tile_data);
876         break;
877     case AV1_OBU_METADATA:
878         cbs_av1_free_metadata(&obu->obu.metadata);
879         break;
880     }
881
882     av_freep(&obu);
883 }
884
885 static int cbs_av1_ref_tile_data(CodedBitstreamContext *ctx,
886                                  CodedBitstreamUnit *unit,
887                                  GetBitContext *gbc,
888                                  AV1RawTileData *td)
889 {
890     int pos;
891
892     pos = get_bits_count(gbc);
893     if (pos >= 8 * unit->data_size) {
894         av_log(ctx->log_ctx, AV_LOG_ERROR, "Bitstream ended before "
895                "any data in tile group (%d bits read).\n", pos);
896         return AVERROR_INVALIDDATA;
897     }
898     // Must be byte-aligned at this point.
899     av_assert0(pos % 8 == 0);
900
901     td->data_ref = av_buffer_ref(unit->data_ref);
902     if (!td->data_ref)
903         return AVERROR(ENOMEM);
904
905     td->data      = unit->data      + pos / 8;
906     td->data_size = unit->data_size - pos / 8;
907
908     return 0;
909 }
910
911 static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
912                              CodedBitstreamUnit *unit)
913 {
914     CodedBitstreamAV1Context *priv = ctx->priv_data;
915     AV1RawOBU *obu;
916     GetBitContext gbc;
917     int err, start_pos, end_pos;
918
919     err = ff_cbs_alloc_unit_content(ctx, unit, sizeof(*obu),
920                                     &cbs_av1_free_obu);
921     if (err < 0)
922         return err;
923     obu = unit->content;
924
925     err = init_get_bits(&gbc, unit->data, 8 * unit->data_size);
926     if (err < 0)
927         return err;
928
929     err = cbs_av1_read_obu_header(ctx, &gbc, &obu->header);
930     if (err < 0)
931         return err;
932     av_assert0(obu->header.obu_type == unit->type);
933
934     if (obu->header.obu_has_size_field) {
935         uint64_t obu_size;
936         err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
937         if (err < 0)
938             return err;
939         obu->obu_size = obu_size;
940     } else {
941         if (unit->data_size < 1 + obu->header.obu_extension_flag) {
942             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
943                    "unit too short (%zu).\n", unit->data_size);
944             return AVERROR_INVALIDDATA;
945         }
946         obu->obu_size = unit->data_size - 1 - obu->header.obu_extension_flag;
947     }
948
949     start_pos = get_bits_count(&gbc);
950
951     if (obu->header.obu_extension_flag) {
952         priv->temporal_id = obu->header.temporal_id;
953         priv->spatial_id  = obu->header.temporal_id;
954
955         if (obu->header.obu_type != AV1_OBU_SEQUENCE_HEADER &&
956             obu->header.obu_type != AV1_OBU_TEMPORAL_DELIMITER &&
957             priv->operating_point_idc) {
958             int in_temporal_layer =
959                 (priv->operating_point_idc >>  priv->temporal_id    ) & 1;
960             int in_spatial_layer  =
961                 (priv->operating_point_idc >> (priv->spatial_id + 8)) & 1;
962             if (!in_temporal_layer || !in_spatial_layer) {
963                 // Decoding will drop this OBU at this operating point.
964             }
965         }
966     } else {
967         priv->temporal_id = 0;
968         priv->spatial_id  = 0;
969     }
970
971     switch (obu->header.obu_type) {
972     case AV1_OBU_SEQUENCE_HEADER:
973         {
974             err = cbs_av1_read_sequence_header_obu(ctx, &gbc,
975                                                    &obu->obu.sequence_header);
976             if (err < 0)
977                 return err;
978
979             av_buffer_unref(&priv->sequence_header_ref);
980             priv->sequence_header = NULL;
981
982             priv->sequence_header_ref = av_buffer_ref(unit->content_ref);
983             if (!priv->sequence_header_ref)
984                 return AVERROR(ENOMEM);
985             priv->sequence_header = &obu->obu.sequence_header;
986         }
987         break;
988     case AV1_OBU_TEMPORAL_DELIMITER:
989         {
990             err = cbs_av1_read_temporal_delimiter_obu(ctx, &gbc);
991             if (err < 0)
992                 return err;
993         }
994         break;
995     case AV1_OBU_FRAME_HEADER:
996     case AV1_OBU_REDUNDANT_FRAME_HEADER:
997         {
998             err = cbs_av1_read_frame_header_obu(ctx, &gbc,
999                                                 &obu->obu.frame_header);
1000             if (err < 0)
1001                 return err;
1002         }
1003         break;
1004     case AV1_OBU_TILE_GROUP:
1005         {
1006             err = cbs_av1_read_tile_group_obu(ctx, &gbc,
1007                                               &obu->obu.tile_group);
1008             if (err < 0)
1009                 return err;
1010
1011             err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
1012                                         &obu->obu.tile_group.tile_data);
1013             if (err < 0)
1014                 return err;
1015         }
1016         break;
1017     case AV1_OBU_FRAME:
1018         {
1019             err = cbs_av1_read_frame_obu(ctx, &gbc, &obu->obu.frame);
1020             if (err < 0)
1021                 return err;
1022
1023             err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
1024                                         &obu->obu.frame.tile_group.tile_data);
1025             if (err < 0)
1026                 return err;
1027         }
1028         break;
1029     case AV1_OBU_TILE_LIST:
1030         {
1031             err = cbs_av1_read_tile_list_obu(ctx, &gbc,
1032                                              &obu->obu.tile_list);
1033             if (err < 0)
1034                 return err;
1035
1036             err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
1037                                         &obu->obu.tile_list.tile_data);
1038             if (err < 0)
1039                 return err;
1040         }
1041         break;
1042     case AV1_OBU_METADATA:
1043         {
1044             err = cbs_av1_read_metadata_obu(ctx, &gbc, &obu->obu.metadata);
1045             if (err < 0)
1046                 return err;
1047         }
1048         break;
1049     case AV1_OBU_PADDING:
1050     default:
1051         return AVERROR(ENOSYS);
1052     }
1053
1054     end_pos = get_bits_count(&gbc);
1055     av_assert0(end_pos <= unit->data_size * 8);
1056
1057     if (obu->obu_size > 0 &&
1058         obu->header.obu_type != AV1_OBU_TILE_GROUP &&
1059         obu->header.obu_type != AV1_OBU_FRAME) {
1060         err = cbs_av1_read_trailing_bits(ctx, &gbc,
1061                                          obu->obu_size * 8 + start_pos - end_pos);
1062         if (err < 0)
1063             return err;
1064     }
1065
1066     return 0;
1067 }
1068
1069 static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
1070                              CodedBitstreamUnit *unit,
1071                              PutBitContext *pbc)
1072 {
1073     CodedBitstreamAV1Context *priv = ctx->priv_data;
1074     AV1RawOBU *obu = unit->content;
1075     PutBitContext pbc_tmp;
1076     AV1RawTileData *td;
1077     size_t header_size;
1078     int err, start_pos, end_pos, data_pos;
1079
1080     // OBUs in the normal bitstream format must contain a size field
1081     // in every OBU (in annex B it is optional, but we don't support
1082     // writing that).
1083     obu->header.obu_has_size_field = 1;
1084
1085     err = cbs_av1_write_obu_header(ctx, pbc, &obu->header);
1086     if (err < 0)
1087         return err;
1088
1089     if (obu->header.obu_has_size_field) {
1090         pbc_tmp = *pbc;
1091         // Add space for the size field to fill later.
1092         put_bits32(pbc, 0);
1093         put_bits32(pbc, 0);
1094     }
1095
1096     td = NULL;
1097     start_pos = put_bits_count(pbc);
1098
1099     switch (obu->header.obu_type) {
1100     case AV1_OBU_SEQUENCE_HEADER:
1101         {
1102             err = cbs_av1_write_sequence_header_obu(ctx, pbc,
1103                                                     &obu->obu.sequence_header);
1104             if (err < 0)
1105                 return err;
1106
1107             av_buffer_unref(&priv->sequence_header_ref);
1108             priv->sequence_header = NULL;
1109
1110             priv->sequence_header_ref = av_buffer_ref(unit->content_ref);
1111             if (!priv->sequence_header_ref)
1112                 return AVERROR(ENOMEM);
1113             priv->sequence_header = &obu->obu.sequence_header;
1114         }
1115         break;
1116     case AV1_OBU_TEMPORAL_DELIMITER:
1117         {
1118             err = cbs_av1_write_temporal_delimiter_obu(ctx, pbc);
1119             if (err < 0)
1120                 return err;
1121         }
1122         break;
1123     case AV1_OBU_FRAME_HEADER:
1124     case AV1_OBU_REDUNDANT_FRAME_HEADER:
1125         {
1126             err = cbs_av1_write_frame_header_obu(ctx, pbc,
1127                                                  &obu->obu.frame_header);
1128             if (err < 0)
1129                 return err;
1130         }
1131         break;
1132     case AV1_OBU_TILE_GROUP:
1133         {
1134             err = cbs_av1_write_tile_group_obu(ctx, pbc,
1135                                                &obu->obu.tile_group);
1136             if (err < 0)
1137                 return err;
1138
1139             td = &obu->obu.tile_group.tile_data;
1140         }
1141         break;
1142     case AV1_OBU_FRAME:
1143         {
1144             err = cbs_av1_write_frame_obu(ctx, pbc, &obu->obu.frame);
1145             if (err < 0)
1146                 return err;
1147
1148             td = &obu->obu.frame.tile_group.tile_data;
1149         }
1150         break;
1151     case AV1_OBU_TILE_LIST:
1152         {
1153             err = cbs_av1_write_tile_list_obu(ctx, pbc, &obu->obu.tile_list);
1154             if (err < 0)
1155                 return err;
1156
1157             td = &obu->obu.tile_list.tile_data;
1158         }
1159         break;
1160     case AV1_OBU_METADATA:
1161         {
1162             err = cbs_av1_write_metadata_obu(ctx, pbc, &obu->obu.metadata);
1163             if (err < 0)
1164                 return err;
1165         }
1166         break;
1167     case AV1_OBU_PADDING:
1168     default:
1169         return AVERROR(ENOSYS);
1170     }
1171
1172     end_pos = put_bits_count(pbc);
1173     header_size = (end_pos - start_pos + 7) / 8;
1174     if (td) {
1175         obu->obu_size = header_size + td->data_size;
1176     } else if (header_size > 0) {
1177         // Add trailing bits and recalculate.
1178         err = cbs_av1_write_trailing_bits(ctx, pbc, 8 - end_pos % 8);
1179         if (err < 0)
1180             return err;
1181         end_pos = put_bits_count(pbc);
1182         obu->obu_size = (end_pos - start_pos + 7) / 8;
1183     } else {
1184         // Empty OBU.
1185         obu->obu_size = 0;
1186     }
1187
1188     end_pos = put_bits_count(pbc);
1189     // Must now be byte-aligned.
1190     av_assert0(end_pos % 8 == 0);
1191     flush_put_bits(pbc);
1192     start_pos /= 8;
1193     end_pos   /= 8;
1194
1195     *pbc = pbc_tmp;
1196     err = cbs_av1_write_leb128(ctx, pbc, "obu_size", obu->obu_size);
1197     if (err < 0)
1198         return err;
1199
1200     data_pos = put_bits_count(pbc) / 8;
1201     flush_put_bits(pbc);
1202     av_assert0(data_pos <= start_pos);
1203
1204     if (8 * obu->obu_size > put_bits_left(pbc))
1205         return AVERROR(ENOSPC);
1206
1207     if (obu->obu_size > 0) {
1208         memmove(priv->write_buffer + data_pos,
1209                 priv->write_buffer + start_pos, header_size);
1210         skip_put_bytes(pbc, header_size);
1211
1212         if (td) {
1213             memcpy(priv->write_buffer + data_pos + header_size,
1214                    td->data, td->data_size);
1215             skip_put_bytes(pbc, td->data_size);
1216         }
1217     }
1218
1219     return 0;
1220 }
1221
1222 static int cbs_av1_write_unit(CodedBitstreamContext *ctx,
1223                               CodedBitstreamUnit *unit)
1224 {
1225     CodedBitstreamAV1Context *priv = ctx->priv_data;
1226     PutBitContext pbc;
1227     int err;
1228
1229     if (!priv->write_buffer) {
1230         // Initial write buffer size is 1MB.
1231         priv->write_buffer_size = 1024 * 1024;
1232
1233     reallocate_and_try_again:
1234         err = av_reallocp(&priv->write_buffer, priv->write_buffer_size);
1235         if (err < 0) {
1236             av_log(ctx->log_ctx, AV_LOG_ERROR, "Unable to allocate a "
1237                    "sufficiently large write buffer (last attempt "
1238                    "%zu bytes).\n", priv->write_buffer_size);
1239             return err;
1240         }
1241     }
1242
1243     init_put_bits(&pbc, priv->write_buffer, priv->write_buffer_size);
1244
1245     err = cbs_av1_write_obu(ctx, unit, &pbc);
1246     if (err == AVERROR(ENOSPC)) {
1247         // Overflow.
1248         priv->write_buffer_size *= 2;
1249         goto reallocate_and_try_again;
1250     }
1251     if (err < 0)
1252         return err;
1253
1254     // Overflow but we didn't notice.
1255     av_assert0(put_bits_count(&pbc) <= 8 * priv->write_buffer_size);
1256
1257     // OBU data must be byte-aligned.
1258     av_assert0(put_bits_count(&pbc) % 8 == 0);
1259
1260     unit->data_size = put_bits_count(&pbc) / 8;
1261     flush_put_bits(&pbc);
1262
1263     err = ff_cbs_alloc_unit_data(ctx, unit, unit->data_size);
1264     if (err < 0)
1265         return err;
1266
1267     memcpy(unit->data, priv->write_buffer, unit->data_size);
1268
1269     return 0;
1270 }
1271
1272 static int cbs_av1_assemble_fragment(CodedBitstreamContext *ctx,
1273                                      CodedBitstreamFragment *frag)
1274 {
1275     size_t size, pos;
1276     int i;
1277
1278     size = 0;
1279     for (i = 0; i < frag->nb_units; i++)
1280         size += frag->units[i].data_size;
1281
1282     frag->data_ref = av_buffer_alloc(size + AV_INPUT_BUFFER_PADDING_SIZE);
1283     if (!frag->data_ref)
1284         return AVERROR(ENOMEM);
1285     frag->data = frag->data_ref->data;
1286     memset(frag->data + size, 0, AV_INPUT_BUFFER_PADDING_SIZE);
1287
1288     pos = 0;
1289     for (i = 0; i < frag->nb_units; i++) {
1290         memcpy(frag->data + pos, frag->units[i].data,
1291                frag->units[i].data_size);
1292         pos += frag->units[i].data_size;
1293     }
1294     av_assert0(pos == size);
1295     frag->data_size = size;
1296
1297     return 0;
1298 }
1299
1300 static void cbs_av1_close(CodedBitstreamContext *ctx)
1301 {
1302     CodedBitstreamAV1Context *priv = ctx->priv_data;
1303
1304     av_buffer_unref(&priv->sequence_header_ref);
1305
1306     av_freep(&priv->write_buffer);
1307 }
1308
1309 const CodedBitstreamType ff_cbs_type_av1 = {
1310     .codec_id          = AV_CODEC_ID_AV1,
1311
1312     .priv_data_size    = sizeof(CodedBitstreamAV1Context),
1313
1314     .split_fragment    = &cbs_av1_split_fragment,
1315     .read_unit         = &cbs_av1_read_unit,
1316     .write_unit        = &cbs_av1_write_unit,
1317     .assemble_fragment = &cbs_av1_assemble_fragment,
1318
1319     .close             = &cbs_av1_close,
1320 };