]> git.sesse.net Git - ffmpeg/blobdiff - libavcodec/cbs_av1.c
avfilter/transform: Stop exporting internal functions
[ffmpeg] / libavcodec / cbs_av1.c
index 9bac9dde0912a05392fdf09dbf613950c6f23b8f..302e1f38f500325427d7c96f145e9e51eae45ef3 100644 (file)
@@ -17,6 +17,7 @@
  */
 
 #include "libavutil/avassert.h"
+#include "libavutil/opt.h"
 #include "libavutil/pixfmt.h"
 
 #include "cbs.h"
@@ -29,45 +30,67 @@ static int cbs_av1_read_uvlc(CodedBitstreamContext *ctx, GetBitContext *gbc,
                              const char *name, uint32_t *write_to,
                              uint32_t range_min, uint32_t range_max)
 {
-    uint32_t value;
-    int position, zeroes, i, j;
-    char bits[65];
+    uint32_t zeroes, bits_value, value;
+    int position;
 
     if (ctx->trace_enable)
         position = get_bits_count(gbc);
 
-    zeroes = i = 0;
+    zeroes = 0;
     while (1) {
-        if (get_bits_left(gbc) < zeroes + 1) {
+        if (get_bits_left(gbc) < 1) {
             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid uvlc code at "
                    "%s: bitstream ended.\n", name);
             return AVERROR_INVALIDDATA;
         }
 
-        if (get_bits1(gbc)) {
-            bits[i++] = '1';
+        if (get_bits1(gbc))
             break;
-        } else {
-            bits[i++] = '0';
-            ++zeroes;
-        }
+        ++zeroes;
     }
 
     if (zeroes >= 32) {
         value = MAX_UINT_BITS(32);
     } else {
-        value = get_bits_long(gbc, zeroes);
-
-        for (j = 0; j < zeroes; j++)
-            bits[i++] = (value >> (zeroes - j - 1) & 1) ? '1' : '0';
+        if (get_bits_left(gbc) < zeroes) {
+            av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid uvlc code at "
+                   "%s: bitstream ended.\n", name);
+            return AVERROR_INVALIDDATA;
+        }
 
-        value += (1 << zeroes) - 1;
+        bits_value = get_bits_long(gbc, zeroes);
+        value = bits_value + (UINT32_C(1) << zeroes) - 1;
     }
 
     if (ctx->trace_enable) {
+        char bits[65];
+        int i, j, k;
+
+        if (zeroes >= 32) {
+            while (zeroes > 32) {
+                k = FFMIN(zeroes - 32, 32);
+                for (i = 0; i < k; i++)
+                    bits[i] = '0';
+                bits[i] = 0;
+                ff_cbs_trace_syntax_element(ctx, position, name,
+                                            NULL, bits, 0);
+                zeroes -= k;
+                position += k;
+            }
+        }
+
+        for (i = 0; i < zeroes; i++)
+            bits[i] = '0';
+        bits[i++] = '1';
+
+        if (zeroes < 32) {
+            for (j = 0; j < zeroes; j++)
+                bits[i++] = (bits_value >> (zeroes - j - 1) & 1) ? '1' : '0';
+        }
+
         bits[i] = 0;
-        ff_cbs_trace_syntax_element(ctx, position, name, NULL,
-                                    bits, value);
+        ff_cbs_trace_syntax_element(ctx, position, name,
+                                    NULL, bits, value);
     }
 
     if (value < range_min || value > range_max) {
@@ -98,15 +121,11 @@ static int cbs_av1_write_uvlc(CodedBitstreamContext *ctx, PutBitContext *pbc,
     if (ctx->trace_enable)
         position = put_bits_count(pbc);
 
-    if (value == 0) {
-        zeroes = 0;
-        put_bits(pbc, 1, 1);
-    } else {
-        zeroes = av_log2(value + 1);
-        v = value - (1 << zeroes) + 1;
-        put_bits(pbc, zeroes + 1, 1);
-        put_bits(pbc, zeroes, v);
-    }
+    zeroes = av_log2(value + 1);
+    v = value - (1U << zeroes) + 1;
+    put_bits(pbc, zeroes, 0);
+    put_bits(pbc, 1, 1);
+    put_bits(pbc, zeroes, v);
 
     if (ctx->trace_enable) {
         char bits[65];
@@ -148,6 +167,9 @@ static int cbs_av1_read_leb128(CodedBitstreamContext *ctx, GetBitContext *gbc,
             break;
     }
 
+    if (value > UINT32_MAX)
+        return AVERROR_INVALIDDATA;
+
     if (ctx->trace_enable)
         ff_cbs_trace_syntax_element(ctx, position, name, NULL, "", value);
 
@@ -185,80 +207,12 @@ static int cbs_av1_write_leb128(CodedBitstreamContext *ctx, PutBitContext *pbc,
     return 0;
 }
 
-static int cbs_av1_read_su(CodedBitstreamContext *ctx, GetBitContext *gbc,
-                           int width, const char *name,
-                           const int *subscripts, int32_t *write_to)
-{
-    uint32_t magnitude;
-    int position, sign;
-    int32_t value;
-
-    if (ctx->trace_enable)
-        position = get_bits_count(gbc);
-
-    if (get_bits_left(gbc) < width + 1) {
-        av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid signed value at "
-               "%s: bitstream ended.\n", name);
-        return AVERROR_INVALIDDATA;
-    }
-
-    magnitude = get_bits(gbc, width);
-    sign      = get_bits1(gbc);
-    value     = sign ? -(int32_t)magnitude : magnitude;
-
-    if (ctx->trace_enable) {
-        char bits[33];
-        int i;
-        for (i = 0; i < width; i++)
-            bits[i] = magnitude >> (width - i - 1) & 1 ? '1' : '0';
-        bits[i] = sign ? '1' : '0';
-        bits[i + 1] = 0;
-
-        ff_cbs_trace_syntax_element(ctx, position,
-                                    name, subscripts, bits, value);
-    }
-
-    *write_to = value;
-    return 0;
-}
-
-static int cbs_av1_write_su(CodedBitstreamContext *ctx, PutBitContext *pbc,
-                            int width, const char *name,
-                            const int *subscripts, int32_t value)
-{
-    uint32_t magnitude;
-    int sign;
-
-    if (put_bits_left(pbc) < width + 1)
-        return AVERROR(ENOSPC);
-
-    sign      = value < 0;
-    magnitude = sign ? -value : value;
-
-    if (ctx->trace_enable) {
-        char bits[33];
-        int i;
-        for (i = 0; i < width; i++)
-            bits[i] = magnitude >> (width - i - 1) & 1 ? '1' : '0';
-        bits[i] = sign ? '1' : '0';
-        bits[i + 1] = 0;
-
-        ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
-                                    name, subscripts, bits, value);
-    }
-
-    put_bits(pbc, width, magnitude);
-    put_bits(pbc, 1, sign);
-
-    return 0;
-}
-
 static int cbs_av1_read_ns(CodedBitstreamContext *ctx, GetBitContext *gbc,
                            uint32_t n, const char *name,
                            const int *subscripts, uint32_t *write_to)
 {
-    uint32_t w, m, v, extra_bit, value;
-    int position;
+    uint32_t m, v, extra_bit, value;
+    int position, w;
 
     av_assert0(n > 0);
 
@@ -564,6 +518,17 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
     return diff;
 }
 
+static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
+{
+    GetBitContext tmp = *gbc;
+    size_t size = 0;
+    for (int i = 0; get_bits_left(&tmp) >= 8; i++) {
+        if (get_bits(&tmp, 8))
+            size = i;
+    }
+    return size;
+}
+
 
 #define HEADER(name) do { \
         ff_cbs_trace_header(ctx, name); \
@@ -582,12 +547,12 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
 #define SUBSCRIPTS(subs, ...) (subs > 0 ? ((int[subs + 1]){ subs, __VA_ARGS__ }) : NULL)
 
 #define fb(width, name) \
-        xf(width, name, current->name, 0, MAX_UINT_BITS(width), 0)
+        xf(width, name, current->name, 0, MAX_UINT_BITS(width), 0)
 #define fc(width, name, range_min, range_max) \
-        xf(width, name, current->name, range_min, range_max, 0)
+        xf(width, name, current->name, range_min, range_max, 0)
 #define flag(name) fb(1, name)
 #define su(width, name) \
-        xsu(width, name, current->name, 0)
+        xsu(width, name, current->name, 0)
 
 #define fbs(width, name, subs, ...) \
         xf(width, name, current->name, 0, MAX_UINT_BITS(width), subs, __VA_ARGS__)
@@ -600,7 +565,7 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
 
 #define fixed(width, name, value) do { \
         av_unused uint32_t fixed_value = value; \
-        xf(width, name, fixed_value, value, value, 0); \
+        xf(width, name, fixed_value, value, value, 0); \
     } while (0)
 
 
@@ -609,7 +574,7 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
 #define RWContext GetBitContext
 
 #define xf(width, name, var, range_min, range_max, subs, ...) do { \
-        uint32_t value = range_min; \
+        uint32_t value; \
         CHECK(ff_cbs_read_unsigned(ctx, rw, width, #name, \
                                    SUBSCRIPTS(subs, __VA_ARGS__), \
                                    &value, range_min, range_max)); \
@@ -617,34 +582,36 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
     } while (0)
 
 #define xsu(width, name, var, subs, ...) do { \
-        int32_t value = 0; \
-        CHECK(cbs_av1_read_su(ctx, rw, width, #name, \
-                              SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
+        int32_t value; \
+        CHECK(ff_cbs_read_signed(ctx, rw, width, #name, \
+                                 SUBSCRIPTS(subs, __VA_ARGS__), &value, \
+                                 MIN_INT_BITS(width), \
+                                 MAX_INT_BITS(width))); \
         var = value; \
     } while (0)
 
 #define uvlc(name, range_min, range_max) do { \
-        uint32_t value = range_min; \
+        uint32_t value; \
         CHECK(cbs_av1_read_uvlc(ctx, rw, #name, \
                                 &value, range_min, range_max)); \
         current->name = value; \
     } while (0)
 
 #define ns(max_value, name, subs, ...) do { \
-        uint32_t value = 0; \
+        uint32_t value; \
         CHECK(cbs_av1_read_ns(ctx, rw, max_value, #name, \
                               SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
         current->name = value; \
     } while (0)
 
 #define increment(name, min, max) do { \
-        uint32_t value = 0; \
+        uint32_t value; \
         CHECK(cbs_av1_read_increment(ctx, rw, min, max, #name, &value)); \
         current->name = value; \
     } while (0)
 
 #define subexp(name, max, subs, ...) do { \
-        uint32_t value = 0; \
+        uint32_t value; \
         CHECK(cbs_av1_read_subexp(ctx, rw, max, #name, \
                                   SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
         current->name = value; \
@@ -653,16 +620,16 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
 #define delta_q(name) do { \
         uint8_t delta_coded; \
         int8_t delta_q; \
-        xf(1, name.delta_coded, delta_coded, 0, 1, 0); \
+        xf(1, name.delta_coded, delta_coded, 0, 1, 0); \
         if (delta_coded) \
-            xsu(1 + 6, name.delta_q, delta_q, 0); \
+            xsu(1 + 6, name.delta_q, delta_q, 0); \
         else \
             delta_q = 0; \
         current->name = delta_q; \
     } while (0)
 
 #define leb128(name) do { \
-        uint64_t value = 0; \
+        uint64_t value; \
         CHECK(cbs_av1_read_leb128(ctx, rw, #name, &value)); \
         current->name = value; \
     } while (0)
@@ -681,7 +648,6 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
 #undef xf
 #undef xsu
 #undef uvlc
-#undef leb128
 #undef ns
 #undef increment
 #undef subexp
@@ -702,8 +668,10 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
     } while (0)
 
 #define xsu(width, name, var, subs, ...) do { \
-        CHECK(cbs_av1_write_su(ctx, rw, width, #name, \
-                               SUBSCRIPTS(subs, __VA_ARGS__), var)); \
+        CHECK(ff_cbs_write_signed(ctx, rw, width, #name, \
+                                  SUBSCRIPTS(subs, __VA_ARGS__), var, \
+                                  MIN_INT_BITS(width), \
+                                  MAX_INT_BITS(width))); \
     } while (0)
 
 #define uvlc(name, range_min, range_max) do { \
@@ -729,9 +697,9 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
     } while (0)
 
 #define delta_q(name) do { \
-        xf(1, name.delta_coded, current->name != 0, 0, 1, 0); \
+        xf(1, name.delta_coded, current->name != 0, 0, 1, 0); \
         if (current->name) \
-            xsu(1 + 6, name.delta_q, current->name, 0); \
+            xsu(1 + 6, name.delta_q, current->name, 0); \
     } while (0)
 
 #define leb128(name) do { \
@@ -740,10 +708,11 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
 
 #define infer(name, value) do { \
         if (current->name != (value)) { \
-            av_log(ctx->log_ctx, AV_LOG_WARNING, "Warning: " \
+            av_log(ctx->log_ctx, AV_LOG_ERROR, \
                    "%s does not match inferred value: " \
                    "%"PRId64", but should be %"PRId64".\n", \
                    #name, (int64_t)current->name, (int64_t)(value)); \
+            return AVERROR_INVALIDDATA; \
         } \
     } while (0)
 
@@ -751,17 +720,17 @@ static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
 
 #include "cbs_av1_syntax_template.c"
 
-#undef READ
+#undef WRITE
 #undef READWRITE
 #undef RWContext
 #undef xf
 #undef xsu
 #undef uvlc
-#undef leb128
 #undef ns
 #undef increment
 #undef subexp
 #undef delta_q
+#undef leb128
 #undef infer
 #undef byte_alignment
 
@@ -785,11 +754,44 @@ static int cbs_av1_split_fragment(CodedBitstreamContext *ctx,
 
     if (INT_MAX / 8 < size) {
         av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid fragment: "
-               "too large (%zu bytes).\n", size);
+               "too large (%"SIZE_SPECIFIER" bytes).\n", size);
         err = AVERROR_INVALIDDATA;
         goto fail;
     }
 
+    if (header && size && data[0] & 0x80) {
+        // first bit is nonzero, the extradata does not consist purely of
+        // OBUs. Expect MP4/Matroska AV1CodecConfigurationRecord
+        int config_record_version = data[0] & 0x7f;
+
+        if (config_record_version != 1) {
+            av_log(ctx->log_ctx, AV_LOG_ERROR,
+                   "Unknown version %d of AV1CodecConfigurationRecord "
+                   "found!\n",
+                   config_record_version);
+            err = AVERROR_INVALIDDATA;
+            goto fail;
+        }
+
+        if (size <= 4) {
+            if (size < 4) {
+                av_log(ctx->log_ctx, AV_LOG_WARNING,
+                       "Undersized AV1CodecConfigurationRecord v%d found!\n",
+                       config_record_version);
+                err = AVERROR_INVALIDDATA;
+                goto fail;
+            }
+
+            goto success;
+        }
+
+        // In AV1CodecConfigurationRecord v1, actual OBUs start after
+        // four bytes. Thus set the offset as required for properly
+        // parsing them.
+        data += 4;
+        size -= 4;
+    }
+
     while (size > 0) {
         AV1RawOBUHeader header;
         uint64_t obu_size;
@@ -800,23 +802,18 @@ static int cbs_av1_split_fragment(CodedBitstreamContext *ctx,
         if (err < 0)
             goto fail;
 
-        if (!header.obu_has_size_field) {
-            av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU for raw "
-                   "stream: size field must be present.\n");
-            err = AVERROR_INVALIDDATA;
-            goto fail;
-        }
-
-        if (get_bits_left(&gbc) < 8) {
-            av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU: fragment "
-                   "too short (%zu bytes).\n", size);
-            err = AVERROR_INVALIDDATA;
-            goto fail;
-        }
-
-        err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
-        if (err < 0)
-            goto fail;
+        if (header.obu_has_size_field) {
+            if (get_bits_left(&gbc) < 8) {
+                av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU: fragment "
+                       "too short (%"SIZE_SPECIFIER" bytes).\n", size);
+                err = AVERROR_INVALIDDATA;
+                goto fail;
+            }
+            err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
+            if (err < 0)
+                goto fail;
+        } else
+            obu_size = size - 1 - header.obu_extension_flag;
 
         pos = get_bits_count(&gbc);
         av_assert0(pos % 8 == 0 && pos / 8 <= size);
@@ -825,13 +822,13 @@ static int cbs_av1_split_fragment(CodedBitstreamContext *ctx,
 
         if (size < obu_length) {
             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
-                   "%"PRIu64", but only %zu bytes remaining in fragment.\n",
+                   "%"PRIu64", but only %"SIZE_SPECIFIER" bytes remaining in fragment.\n",
                    obu_length, size);
             err = AVERROR_INVALIDDATA;
             goto fail;
         }
 
-        err = ff_cbs_insert_unit_data(ctx, frag, -1, header.obu_type,
+        err = ff_cbs_insert_unit_data(frag, -1, header.obu_type,
                                       data, obu_length, frag->data_ref);
         if (err < 0)
             goto fail;
@@ -840,48 +837,13 @@ static int cbs_av1_split_fragment(CodedBitstreamContext *ctx,
         size -= obu_length;
     }
 
+success:
     err = 0;
 fail:
     ctx->trace_enable = trace;
     return err;
 }
 
-static void cbs_av1_free_tile_data(AV1RawTileData *td)
-{
-    av_buffer_unref(&td->data_ref);
-}
-
-static void cbs_av1_free_metadata(AV1RawMetadata *md)
-{
-    switch (md->metadata_type) {
-    case AV1_METADATA_TYPE_ITUT_T35:
-        av_buffer_unref(&md->metadata.itut_t35.payload_ref);
-        break;
-    }
-}
-
-static void cbs_av1_free_obu(void *unit, uint8_t *content)
-{
-    AV1RawOBU *obu = (AV1RawOBU*)content;
-
-    switch (obu->header.obu_type) {
-    case AV1_OBU_TILE_GROUP:
-        cbs_av1_free_tile_data(&obu->obu.tile_group.tile_data);
-        break;
-    case AV1_OBU_FRAME:
-        cbs_av1_free_tile_data(&obu->obu.frame.tile_group.tile_data);
-        break;
-    case AV1_OBU_TILE_LIST:
-        cbs_av1_free_tile_data(&obu->obu.tile_list.tile_data);
-        break;
-    case AV1_OBU_METADATA:
-        cbs_av1_free_metadata(&obu->obu.metadata);
-        break;
-    }
-
-    av_freep(&obu);
-}
-
 static int cbs_av1_ref_tile_data(CodedBitstreamContext *ctx,
                                  CodedBitstreamUnit *unit,
                                  GetBitContext *gbc,
@@ -916,8 +878,7 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
     GetBitContext gbc;
     int err, start_pos, end_pos;
 
-    err = ff_cbs_alloc_unit_content(ctx, unit, sizeof(*obu),
-                                    &cbs_av1_free_obu);
+    err = ff_cbs_alloc_unit_content2(ctx, unit);
     if (err < 0)
         return err;
     obu = unit->content;
@@ -940,7 +901,7 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
     } else {
         if (unit->data_size < 1 + obu->header.obu_extension_flag) {
             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
-                   "unit too short (%zu).\n", unit->data_size);
+                   "unit too short (%"SIZE_SPECIFIER").\n", unit->data_size);
             return AVERROR_INVALIDDATA;
         }
         obu->obu_size = unit->data_size - 1 - obu->header.obu_extension_flag;
@@ -949,9 +910,6 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
     start_pos = get_bits_count(&gbc);
 
     if (obu->header.obu_extension_flag) {
-        priv->temporal_id = obu->header.temporal_id;
-        priv->spatial_id  = obu->header.temporal_id;
-
         if (obu->header.obu_type != AV1_OBU_SEQUENCE_HEADER &&
             obu->header.obu_type != AV1_OBU_TEMPORAL_DELIMITER &&
             priv->operating_point_idc) {
@@ -960,12 +918,9 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
             int in_spatial_layer  =
                 (priv->operating_point_idc >> (priv->spatial_id + 8)) & 1;
             if (!in_temporal_layer || !in_spatial_layer) {
-                // Decoding will drop this OBU at this operating point.
+                return AVERROR(EAGAIN); // drop_obu()
             }
         }
-    } else {
-        priv->temporal_id = 0;
-        priv->spatial_id  = 0;
     }
 
     switch (obu->header.obu_type) {
@@ -976,6 +931,18 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
             if (err < 0)
                 return err;
 
+            if (priv->operating_point >= 0) {
+                AV1RawSequenceHeader *sequence_header = &obu->obu.sequence_header;
+
+                if (priv->operating_point > sequence_header->operating_points_cnt_minus_1) {
+                    av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid Operating Point %d requested. "
+                                                       "Must not be higher than %u.\n",
+                           priv->operating_point, sequence_header->operating_points_cnt_minus_1);
+                    return AVERROR(EINVAL);
+                }
+                priv->operating_point_idc = sequence_header->operating_point_idc[priv->operating_point];
+            }
+
             av_buffer_unref(&priv->sequence_header_ref);
             priv->sequence_header = NULL;
 
@@ -996,7 +963,10 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
     case AV1_OBU_REDUNDANT_FRAME_HEADER:
         {
             err = cbs_av1_read_frame_header_obu(ctx, &gbc,
-                                                &obu->obu.frame_header);
+                                                &obu->obu.frame_header,
+                                                obu->header.obu_type ==
+                                                AV1_OBU_REDUNDANT_FRAME_HEADER,
+                                                unit->data_ref);
             if (err < 0)
                 return err;
         }
@@ -1016,7 +986,8 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
         break;
     case AV1_OBU_FRAME:
         {
-            err = cbs_av1_read_frame_obu(ctx, &gbc, &obu->obu.frame);
+            err = cbs_av1_read_frame_obu(ctx, &gbc, &obu->obu.frame,
+                                         unit->data_ref);
             if (err < 0)
                 return err;
 
@@ -1047,6 +1018,12 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
         }
         break;
     case AV1_OBU_PADDING:
+        {
+            err = cbs_av1_read_padding_obu(ctx, &gbc, &obu->obu.padding);
+            if (err < 0)
+                return err;
+        }
+        break;
     default:
         return AVERROR(ENOSYS);
     }
@@ -1056,9 +1033,14 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
 
     if (obu->obu_size > 0 &&
         obu->header.obu_type != AV1_OBU_TILE_GROUP &&
+        obu->header.obu_type != AV1_OBU_TILE_LIST &&
         obu->header.obu_type != AV1_OBU_FRAME) {
-        err = cbs_av1_read_trailing_bits(ctx, &gbc,
-                                         obu->obu_size * 8 + start_pos - end_pos);
+        int nb_bits = obu->obu_size * 8 + start_pos - end_pos;
+
+        if (nb_bits <= 0)
+            return AVERROR_INVALIDDATA;
+
+        err = cbs_av1_read_trailing_bits(ctx, &gbc, nb_bits);
         if (err < 0)
             return err;
     }
@@ -1107,6 +1089,10 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
             av_buffer_unref(&priv->sequence_header_ref);
             priv->sequence_header = NULL;
 
+            err = ff_cbs_make_unit_refcounted(ctx, unit);
+            if (err < 0)
+                return err;
+
             priv->sequence_header_ref = av_buffer_ref(unit->content_ref);
             if (!priv->sequence_header_ref)
                 return AVERROR(ENOMEM);
@@ -1124,7 +1110,10 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
     case AV1_OBU_REDUNDANT_FRAME_HEADER:
         {
             err = cbs_av1_write_frame_header_obu(ctx, pbc,
-                                                 &obu->obu.frame_header);
+                                                 &obu->obu.frame_header,
+                                                 obu->header.obu_type ==
+                                                 AV1_OBU_REDUNDANT_FRAME_HEADER,
+                                                 NULL);
             if (err < 0)
                 return err;
         }
@@ -1141,7 +1130,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
         break;
     case AV1_OBU_FRAME:
         {
-            err = cbs_av1_write_frame_obu(ctx, pbc, &obu->obu.frame);
+            err = cbs_av1_write_frame_obu(ctx, pbc, &obu->obu.frame, NULL);
             if (err < 0)
                 return err;
 
@@ -1165,6 +1154,12 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
         }
         break;
     case AV1_OBU_PADDING:
+        {
+            err = cbs_av1_write_padding_obu(ctx, pbc, &obu->obu.padding);
+            if (err < 0)
+                return err;
+        }
+        break;
     default:
         return AVERROR(ENOSYS);
     }
@@ -1179,7 +1174,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
         if (err < 0)
             return err;
         end_pos = put_bits_count(pbc);
-        obu->obu_size = (end_pos - start_pos + 7) / 8;
+        obu->obu_size = header_size = (end_pos - start_pos + 7) / 8;
     } else {
         // Empty OBU.
         obu->obu_size = 0;
@@ -1205,66 +1200,19 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
         return AVERROR(ENOSPC);
 
     if (obu->obu_size > 0) {
-        memmove(priv->write_buffer + data_pos,
-                priv->write_buffer + start_pos, header_size);
+        memmove(pbc->buf + data_pos,
+                pbc->buf + start_pos, header_size);
         skip_put_bytes(pbc, header_size);
 
         if (td) {
-            memcpy(priv->write_buffer + data_pos + header_size,
+            memcpy(pbc->buf + data_pos + header_size,
                    td->data, td->data_size);
             skip_put_bytes(pbc, td->data_size);
         }
     }
 
-    return 0;
-}
-
-static int cbs_av1_write_unit(CodedBitstreamContext *ctx,
-                              CodedBitstreamUnit *unit)
-{
-    CodedBitstreamAV1Context *priv = ctx->priv_data;
-    PutBitContext pbc;
-    int err;
-
-    if (!priv->write_buffer) {
-        // Initial write buffer size is 1MB.
-        priv->write_buffer_size = 1024 * 1024;
-
-    reallocate_and_try_again:
-        err = av_reallocp(&priv->write_buffer, priv->write_buffer_size);
-        if (err < 0) {
-            av_log(ctx->log_ctx, AV_LOG_ERROR, "Unable to allocate a "
-                   "sufficiently large write buffer (last attempt "
-                   "%zu bytes).\n", priv->write_buffer_size);
-            return err;
-        }
-    }
-
-    init_put_bits(&pbc, priv->write_buffer, priv->write_buffer_size);
-
-    err = cbs_av1_write_obu(ctx, unit, &pbc);
-    if (err == AVERROR(ENOSPC)) {
-        // Overflow.
-        priv->write_buffer_size *= 2;
-        goto reallocate_and_try_again;
-    }
-    if (err < 0)
-        return err;
-
-    // Overflow but we didn't notice.
-    av_assert0(put_bits_count(&pbc) <= 8 * priv->write_buffer_size);
-
     // OBU data must be byte-aligned.
-    av_assert0(put_bits_count(&pbc) % 8 == 0);
-
-    unit->data_size = put_bits_count(&pbc) / 8;
-    flush_put_bits(&pbc);
-
-    err = ff_cbs_alloc_unit_data(ctx, unit, unit->data_size);
-    if (err < 0)
-        return err;
-
-    memcpy(unit->data, priv->write_buffer, unit->data_size);
+    av_assert0(put_bits_count(pbc) % 8 == 0);
 
     return 0;
 }
@@ -1297,24 +1245,92 @@ static int cbs_av1_assemble_fragment(CodedBitstreamContext *ctx,
     return 0;
 }
 
+static void cbs_av1_flush(CodedBitstreamContext *ctx)
+{
+    CodedBitstreamAV1Context *priv = ctx->priv_data;
+
+    av_buffer_unref(&priv->frame_header_ref);
+    priv->sequence_header = NULL;
+    priv->frame_header = NULL;
+
+    memset(priv->ref, 0, sizeof(priv->ref));
+    priv->operating_point_idc = 0;
+    priv->seen_frame_header = 0;
+    priv->tile_num = 0;
+}
+
 static void cbs_av1_close(CodedBitstreamContext *ctx)
 {
     CodedBitstreamAV1Context *priv = ctx->priv_data;
 
     av_buffer_unref(&priv->sequence_header_ref);
+    av_buffer_unref(&priv->frame_header_ref);
+}
+
+static void cbs_av1_free_metadata(void *unit, uint8_t *content)
+{
+    AV1RawOBU *obu = (AV1RawOBU*)content;
+    AV1RawMetadata *md;
+
+    av_assert0(obu->header.obu_type == AV1_OBU_METADATA);
+    md = &obu->obu.metadata;
 
-    av_freep(&priv->write_buffer);
+    switch (md->metadata_type) {
+    case AV1_METADATA_TYPE_ITUT_T35:
+        av_buffer_unref(&md->metadata.itut_t35.payload_ref);
+        break;
+    }
+    av_free(content);
 }
 
+static const CodedBitstreamUnitTypeDescriptor cbs_av1_unit_types[] = {
+    CBS_UNIT_TYPE_POD(AV1_OBU_SEQUENCE_HEADER,        AV1RawOBU),
+    CBS_UNIT_TYPE_POD(AV1_OBU_TEMPORAL_DELIMITER,     AV1RawOBU),
+    CBS_UNIT_TYPE_POD(AV1_OBU_FRAME_HEADER,           AV1RawOBU),
+    CBS_UNIT_TYPE_POD(AV1_OBU_REDUNDANT_FRAME_HEADER, AV1RawOBU),
+
+    CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_TILE_GROUP, AV1RawOBU,
+                               obu.tile_group.tile_data.data),
+    CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_FRAME,      AV1RawOBU,
+                               obu.frame.tile_group.tile_data.data),
+    CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_TILE_LIST,  AV1RawOBU,
+                               obu.tile_list.tile_data.data),
+    CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_PADDING,    AV1RawOBU,
+                               obu.padding.payload),
+
+    CBS_UNIT_TYPE_COMPLEX(AV1_OBU_METADATA, AV1RawOBU,
+                          &cbs_av1_free_metadata),
+
+    CBS_UNIT_TYPE_END_OF_LIST
+};
+
+#define OFFSET(x) offsetof(CodedBitstreamAV1Context, x)
+static const AVOption cbs_av1_options[] = {
+    { "operating_point",  "Set operating point to select layers to parse from a scalable bitstream",
+                          OFFSET(operating_point), AV_OPT_TYPE_INT, { .i64 = -1 }, -1, AV1_MAX_OPERATING_POINTS - 1, 0 },
+    { NULL }
+};
+
+static const AVClass cbs_av1_class = {
+    .class_name = "cbs_av1",
+    .item_name  = av_default_item_name,
+    .option     = cbs_av1_options,
+    .version    = LIBAVUTIL_VERSION_INT,
+};
+
 const CodedBitstreamType ff_cbs_type_av1 = {
     .codec_id          = AV_CODEC_ID_AV1,
 
+    .priv_class        = &cbs_av1_class,
     .priv_data_size    = sizeof(CodedBitstreamAV1Context),
 
+    .unit_types        = cbs_av1_unit_types,
+
     .split_fragment    = &cbs_av1_split_fragment,
     .read_unit         = &cbs_av1_read_unit,
-    .write_unit        = &cbs_av1_write_unit,
+    .write_unit        = &cbs_av1_write_obu,
     .assemble_fragment = &cbs_av1_assemble_fragment,
 
+    .flush             = &cbs_av1_flush,
     .close             = &cbs_av1_close,
 };