]> git.sesse.net Git - nageru/commitdiff
Add support for 10-bit AV1 encoding.
authorSteinar H. Gunderson <sgunderson@bigfoot.com>
Thu, 21 Jul 2022 21:02:05 +0000 (23:02 +0200)
committerSteinar H. Gunderson <sgunderson@bigfoot.com>
Thu, 21 Jul 2022 21:02:05 +0000 (23:02 +0200)
nageru/av1_encoder.cpp
nageru/flags.cpp
nageru/flags.h
shared/memcpy_interleaved.cpp
shared/memcpy_interleaved.h

index a0dbdd1a8e17a4e3297a42fa32c67af2bf312ec9..568cd5f9acc3673c425f393fd145416b6e8782e8 100644 (file)
@@ -63,7 +63,7 @@ AV1Encoder::AV1Encoder(const AVOutputFormat *oformat)
                        av1_latency_histogram.init("av1");
                });
 
-       const size_t bytes_per_pixel = 1;  // TODO: 10-bit support.
+       const size_t bytes_per_pixel = global_flags.av1_bit_depth > 8 ? 2 : 1;
        frame_pool.reset(new uint8_t[global_flags.width * global_flags.height * 2 * bytes_per_pixel * AV1_QUEUE_LENGTH]);
        for (unsigned i = 0; i < AV1_QUEUE_LENGTH; ++i) {
                free_frames.push(frame_pool.get() + i * (global_flags.width * global_flags.height * 2 * bytes_per_pixel));
@@ -102,8 +102,9 @@ void AV1Encoder::add_frame(int64_t pts, int64_t duration, YCbCrLumaCoefficients
 
        // Since we're copying anyway, we can unpack from NV12 to fully planar on the fly.
        // SVT-AV1 makes its own copy, though, and it would have been nice to avoid the
-       // double-copy.
-       size_t bytes_per_pixel = 1;  // TODO: 10-bit support.
+       // double-copy (and also perhaps let the GPU do the 10-bit compression SVT-AV1
+       // wants, instead of doing it on the CPU).
+       const size_t bytes_per_pixel = global_flags.av1_bit_depth > 8 ? 2 : 1;
        size_t frame_size = global_flags.width * global_flags.height * bytes_per_pixel;
        assert(global_flags.width % 2 == 0);
        assert(global_flags.height % 2 == 0);
@@ -111,7 +112,14 @@ void AV1Encoder::add_frame(int64_t pts, int64_t duration, YCbCrLumaCoefficients
        uint8_t *cb = y + frame_size;
        uint8_t *cr = cb + frame_size / 4;
        memcpy(y, data, frame_size);
-       memcpy_interleaved(cb, cr, data + frame_size, frame_size / 2);
+       if (global_flags.av1_bit_depth == 8) {
+               memcpy_interleaved(cb, cr, data + frame_size, frame_size / 2);
+       } else {
+               const uint16_t *src = reinterpret_cast<const uint16_t *>(data + frame_size);
+               uint16_t *cb16 = reinterpret_cast<uint16_t *>(cb);
+               uint16_t *cr16 = reinterpret_cast<uint16_t *>(cr);
+               memcpy_interleaved_word(cb16, cr16, src, frame_size / 4);
+       }
 
        {
                lock_guard<mutex> lock(mu);
@@ -136,7 +144,7 @@ void AV1Encoder::init_av1()
        config.source_height = global_flags.height;
        config.frame_rate_numerator = global_flags.av1_fps_num;
        config.frame_rate_denominator = global_flags.av1_fps_den;
-       config.encoder_bit_depth = 8;  // TODO: 10-bit support.
+       config.encoder_bit_depth = global_flags.av1_bit_depth;
        config.rate_control_mode = 2;  // CBR.
        config.pred_structure = 1;  // PRED_LOW_DELAY_B (needed for CBR).
        config.target_bit_rate = global_flags.av1_bitrate * 1000;
@@ -273,23 +281,25 @@ void AV1Encoder::encoder_thread_func()
 void AV1Encoder::encode_frame(AV1Encoder::QueuedFrame qf)
 {
        if (qf.data) {
+               const size_t bytes_per_pixel = global_flags.av1_bit_depth > 8 ? 2 : 1;
+
                EbSvtIOFormat pic;
                pic.luma = qf.data;     
-               pic.cb = pic.luma + global_flags.width * global_flags.height;
-               pic.cr = pic.cb + global_flags.width * global_flags.height / 4;
-               pic.y_stride = global_flags.width;
-               pic.cb_stride = global_flags.width / 2;
-               pic.cr_stride = global_flags.width / 2;
+               pic.cb = pic.luma + global_flags.width * global_flags.height * bytes_per_pixel;
+               pic.cr = pic.cb + (global_flags.width * global_flags.height / 4) * bytes_per_pixel;
+               pic.y_stride = global_flags.width;  // In pixels, so no bytes_per_pixel.
+               pic.cb_stride = global_flags.width / 2;  // Likewise.
+               pic.cr_stride = global_flags.width / 2;  // Likewise.
                pic.width = global_flags.width;
                pic.height = global_flags.height;
                pic.origin_x = 0;
                pic.origin_y = 0;
                pic.color_fmt = EB_YUV420;
-               pic.bit_depth = EB_EIGHT_BIT;  // TODO: 10-bit.
+               pic.bit_depth = global_flags.av1_bit_depth > 8 ? EB_TEN_BIT : EB_EIGHT_BIT;
 
                EbBufferHeaderType hdr;
                hdr.p_buffer      = reinterpret_cast<uint8_t *>(&pic);
-               hdr.n_alloc_len   = global_flags.width * global_flags.height * 3 / 2;  // TODO: 10-bit.
+               hdr.n_alloc_len   = (global_flags.width * global_flags.height * 3 / 2) * bytes_per_pixel;
                hdr.n_filled_len  = hdr.n_alloc_len;
                hdr.n_tick_count  = 0;
                hdr.p_app_private = reinterpret_cast<void *>(intptr_t(qf.duration));
index 1878038a46fb58ed2bc39b9e2a0994169c0bd2c0..7cc52b7f7fe1a055a893b9816489e9925ce4c1bb 100644 (file)
@@ -618,6 +618,7 @@ void parse_flags(Program program, int argc, char * const argv[])
                case OPTION_10_BIT_OUTPUT:
                        global_flags.ten_bit_output = true;
                        global_flags.x264_bit_depth = 10;
+                       global_flags.av1_bit_depth = 10;
                        break;
                case OPTION_INPUT_YCBCR_INTERPRETATION: {
                        char *ptr = strchr(optarg, ',');
@@ -698,10 +699,7 @@ void parse_flags(Program program, int argc, char * const argv[])
        }
        if (global_flags.ten_bit_output) {
                global_flags.x264_video_to_disk = true;  // No 10-bit Quick Sync support.
-               if (global_flags.av1_video_to_http) {
-                       fprintf(stderr, "ERROR: 10-bit AV1 output is not supported yet\n");
-                       exit(1);
-               } else {
+               if (!global_flags.av1_video_to_http) {
                        global_flags.x264_video_to_http = true;
                }
        }
index 96a4160f80515c4a2c815f0084f60f08d5b21f50..2c02848b2ff997a4e882f41daba75b6e032cc153 100644 (file)
@@ -79,12 +79,13 @@ struct Flags {
        bool display_timecode_on_stdout = false;
        bool enable_quick_cut_keys = false;
        bool ten_bit_input = false;
-       bool ten_bit_output = false;  // Implies x264_video_to_disk == true and x264_bit_depth == 10.
+       bool ten_bit_output = false;  // Implies x264_video_to_disk == true and {x264,av1}_bit_depth == 10.
        YCbCrInterpretation ycbcr_interpretation[MAX_VIDEO_CARDS];
        bool transcode_video = true;  // Kaeru only.
        bool transcode_audio = true;  // Kaeru only.
        bool enable_audio = true;  // Kaeru only. If false, then transcode_audio is also false.
        int x264_bit_depth = 8;  // Not user-settable.
+       int av1_bit_depth = 8;  // Not user-settable.
        bool use_zerocopy = false;  // Not user-settable.
        bool fullscreen = false;
        std::map<unsigned, unsigned> card_to_mjpeg_stream_export;  // If a card is not in the map, it is not exported.
index 2de1ecec3b725498ebc6d92d298e33b332723bd3..dd7a59656608098ed451998838fcb7a713c5a6af 100644 (file)
@@ -24,8 +24,22 @@ void memcpy_interleaved_slow(uint8_t *dest1, uint8_t *dest2, const uint8_t *src,
        }
 }
 
+void memcpy_interleaved_word_slow(uint16_t *dest1, uint16_t *dest2, const uint16_t *src, size_t n)
+{
+       assert(n % 2 == 0);
+       uint16_t *dptr1 = dest1;
+       uint16_t *dptr2 = dest2;
+
+       for (size_t i = 0; i < n; i += 2) {
+               *dptr1++ = *src++;
+               *dptr2++ = *src++;
+       }
+}
+
 #if HAS_MULTIVERSIONING
 
+// uint8_t version.
+
 __attribute__((target("default")))
 size_t memcpy_interleaved_fastpath_core(uint8_t *dest1, uint8_t *dest2, const uint8_t *src, const uint8_t *limit);
 
@@ -112,6 +126,100 @@ size_t memcpy_interleaved_fastpath(uint8_t *dest1, uint8_t *dest2, const uint8_t
        return consumed + memcpy_interleaved_fastpath_core(dest1, dest2, src, limit);
 }
 
+// uint16_t version.
+
+__attribute__((target("default")))
+size_t memcpy_interleaved_word_fastpath_core(uint16_t *dest1, uint16_t *dest2, const uint16_t *src, const uint16_t *limit);
+
+__attribute__((target("avx2")))
+size_t memcpy_interleaved_word_fastpath_core(uint16_t *dest1, uint16_t *dest2, const uint16_t *src, const uint16_t *limit);
+
+__attribute__((target("default")))
+size_t memcpy_interleaved_word_fastpath_core(uint16_t *dest1, uint16_t *dest2, const uint16_t *src, const uint16_t *limit)
+{
+       // No fast path supported unless we have AVX2.
+       return 0;
+}
+
+__attribute__((target("avx2")))
+size_t memcpy_interleaved_word_fastpath_core(uint16_t *dest1, uint16_t *dest2, const uint16_t *src, const uint16_t *limit)
+{
+       size_t consumed = 0;
+       const __m256i *__restrict in = (const __m256i *)src;
+       __m256i *__restrict out1 = (__m256i *)dest1;
+       __m256i *__restrict out2 = (__m256i *)dest2;
+
+       __m256i shuffle_cw = _mm256_set_epi8(
+               15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0,
+               15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
+       while (in < (const __m256i *)limit) {
+               // Note: Each element in these comments is 16 bits long (lanes are 2x128 bits).
+               __m256i data1 = _mm256_stream_load_si256(in);         // AaBbCcDd EeFfGgHh
+               __m256i data2 = _mm256_stream_load_si256(in + 1);     // IiJjKkLl MmNnOoPp
+
+               data1 = _mm256_shuffle_epi8(data1, shuffle_cw);       // ABCDabcd EFGHefgh
+               data2 = _mm256_shuffle_epi8(data2, shuffle_cw);       // IJKLijkl MNOPmnop
+
+               data1 = _mm256_permute4x64_epi64(data1, 0b11011000);  // ABCDEFGH abcdefgh
+               data2 = _mm256_permute4x64_epi64(data2, 0b11011000);  // IJKLMNOP ijklmnop
+
+               __m256i lo = _mm256_permute2x128_si256(data1, data2, 0b00100000);
+               __m256i hi = _mm256_permute2x128_si256(data1, data2, 0b00110001);
+
+               _mm256_storeu_si256(out1, lo);
+               _mm256_storeu_si256(out2, hi);
+
+               in += 2;
+               ++out1;
+               ++out2;
+               consumed += 32;
+       }
+
+       return consumed;
+}
+
+// Returns the number of bytes consumed.
+size_t memcpy_interleaved_word_fastpath(uint16_t *dest1, uint16_t *dest2, const uint16_t *src, size_t n)
+{
+       // We assume this to generally be the case, but just to be sure,
+       // drop down to the slow path.
+       if (intptr_t(dest1) % 2 != 0 || intptr_t(dest2) % 2 != 0 || intptr_t(src) % 2 != 0) {
+               return 0;
+       }
+
+       const uint16_t *limit = src + n;
+       size_t consumed = 0;
+
+       // Align end to 32 bytes.
+       limit = (const uint16_t *)(intptr_t(limit) & ~31);
+
+       if (src >= limit) {
+               return 0;
+       }
+
+       // Process [0,15] words, such that start gets aligned to 32 bytes (16 words).
+       const uint16_t *aligned_src = (const uint16_t *)(intptr_t(src + 31) & ~31);
+       if (aligned_src != src) {
+               size_t n2 = aligned_src - src;
+               memcpy_interleaved_word_slow(dest1, dest2, src, n2);
+               dest1 += n2 / 2;
+               dest2 += n2 / 2;
+               if (n2 % 2) {
+                       swap(dest1, dest2);
+               }
+               src = aligned_src;
+               consumed += n2;
+       }
+
+       // Make the length a multiple of 32 words (64 bytes).
+       if (((limit - src) % 32) != 0) {
+               limit -= 16;
+       }
+       assert(((limit - src) % 32) == 0);
+
+       return consumed + memcpy_interleaved_word_fastpath_core(dest1, dest2, src, limit);
+}
+
 #endif  // defined(HAS_MULTIVERSIONING)
 
 void memcpy_interleaved(uint8_t *dest1, uint8_t *dest2, const uint8_t *src, size_t n)
@@ -131,3 +239,21 @@ void memcpy_interleaved(uint8_t *dest1, uint8_t *dest2, const uint8_t *src, size
                memcpy_interleaved_slow(dest1, dest2, src, n);
        }
 }
+
+void memcpy_interleaved_word(uint16_t *dest1, uint16_t *dest2, const uint16_t *src, size_t n)
+{
+#if HAS_MULTIVERSIONING
+       size_t consumed = memcpy_interleaved_word_fastpath(dest1, dest2, src, n);
+       src += consumed;
+       dest1 += consumed / 2;
+       dest2 += consumed / 2;
+       if (consumed % 2) {
+               swap(dest1, dest2);
+       }
+       n -= consumed;
+#endif
+
+       if (n > 0) {
+               memcpy_interleaved_word_slow(dest1, dest2, src, n);
+       }
+}
index a7f8994fdff8ce2dd996775095f17675b3b0c0cd..7965c13c79db0643755875ee63114e86ac33a11d 100644 (file)
@@ -8,4 +8,8 @@
 // TODO: Support stride.
 void memcpy_interleaved(uint8_t *dest1, uint8_t *dest2, const uint8_t *src, size_t n);
 
+// Same, but every other word instead of every other byte.
+// n is number of words, not number of bytes.
+void memcpy_interleaved_word(uint16_t *dest1, uint16_t *dest2, const uint16_t *src, size_t n);
+
 #endif  // !defined(_MEMCPY_INTERLEAVED_H)