]> git.sesse.net Git - ffmpeg/blob - libavcodec/wavpack.c
avcodec/wavpack: fix some syle issues
[ffmpeg] / libavcodec / wavpack.c
1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  * Copyright (c) 2020 David Bryant
5  *
6  * This file is part of FFmpeg.
7  *
8  * FFmpeg is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * FFmpeg is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with FFmpeg; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21  */
22
23 #include "libavutil/channel_layout.h"
24
25 #define BITSTREAM_READER_LE
26 #include "avcodec.h"
27 #include "bytestream.h"
28 #include "get_bits.h"
29 #include "internal.h"
30 #include "thread.h"
31 #include "unary.h"
32 #include "wavpack.h"
33 #include "dsd.h"
34
35 /**
36  * @file
37  * WavPack lossless audio decoder
38  */
39
40 #define DSD_BYTE_READY(low,high) (!(((low) ^ (high)) & 0xff000000))
41
42 #define PTABLE_BITS 8
43 #define PTABLE_BINS (1<<PTABLE_BITS)
44 #define PTABLE_MASK (PTABLE_BINS-1)
45
46 #define UP   0x010000fe
47 #define DOWN 0x00010000
48 #define DECAY 8
49
50 #define PRECISION 20
51 #define VALUE_ONE (1 << PRECISION)
52 #define PRECISION_USE 12
53
54 #define RATE_S 20
55
56 #define MAX_HISTORY_BITS    5
57 #define MAX_HISTORY_BINS    (1 << MAX_HISTORY_BITS)
58 #define MAX_BIN_BYTES       1280    // for value_lookup, per bin (2k - 512 - 256)
59
60 typedef enum {
61     MODULATION_PCM,     // pulse code modulation
62     MODULATION_DSD      // pulse density modulation (aka DSD)
63 } Modulation;
64
65 typedef struct WavpackFrameContext {
66     AVCodecContext *avctx;
67     int frame_flags;
68     int stereo, stereo_in;
69     int joint;
70     uint32_t CRC;
71     GetBitContext gb;
72     int got_extra_bits;
73     uint32_t crc_extra_bits;
74     GetBitContext gb_extra_bits;
75     int samples;
76     int terms;
77     Decorr decorr[MAX_TERMS];
78     int zero, one, zeroes;
79     int extra_bits;
80     int and, or, shift;
81     int post_shift;
82     int hybrid, hybrid_bitrate;
83     int hybrid_maxclip, hybrid_minclip;
84     int float_flag;
85     int float_shift;
86     int float_max_exp;
87     WvChannel ch[2];
88
89     GetByteContext gbyte;
90     int ptable [PTABLE_BINS];
91     uint8_t value_lookup_buffer[MAX_HISTORY_BINS*MAX_BIN_BYTES];
92     uint16_t summed_probabilities[MAX_HISTORY_BINS][256];
93     uint8_t probabilities[MAX_HISTORY_BINS][256];
94     uint8_t *value_lookup[MAX_HISTORY_BINS];
95 } WavpackFrameContext;
96
97 #define WV_MAX_FRAME_DECODERS 14
98
99 typedef struct WavpackContext {
100     AVCodecContext *avctx;
101
102     WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
103     int fdec_num;
104
105     int block;
106     int samples;
107     int ch_offset;
108
109     AVFrame *frame;
110     ThreadFrame curr_frame, prev_frame;
111     Modulation modulation;
112     DSDContext *dsdctx;
113 } WavpackContext;
114
115 #define LEVEL_DECAY(a)  (((a) + 0x80) >> 8)
116
117 static av_always_inline unsigned get_tail(GetBitContext *gb, int k)
118 {
119     int p, e, res;
120
121     if (k < 1)
122         return 0;
123     p   = av_log2(k);
124     e   = (1 << (p + 1)) - k - 1;
125     res = get_bitsz(gb, p);
126     if (res >= e)
127         res = (res << 1) - e + get_bits1(gb);
128     return res;
129 }
130
131 static int update_error_limit(WavpackFrameContext *ctx)
132 {
133     int i, br[2], sl[2];
134
135     for (i = 0; i <= ctx->stereo_in; i++) {
136         if (ctx->ch[i].bitrate_acc > UINT_MAX - ctx->ch[i].bitrate_delta)
137             return AVERROR_INVALIDDATA;
138         ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
139         br[i]                   = ctx->ch[i].bitrate_acc >> 16;
140         sl[i]                   = LEVEL_DECAY(ctx->ch[i].slow_level);
141     }
142     if (ctx->stereo_in && ctx->hybrid_bitrate) {
143         int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
144         if (balance > br[0]) {
145             br[1] = br[0] * 2;
146             br[0] = 0;
147         } else if (-balance > br[0]) {
148             br[0]  *= 2;
149             br[1]   = 0;
150         } else {
151             br[1] = br[0] + balance;
152             br[0] = br[0] - balance;
153         }
154     }
155     for (i = 0; i <= ctx->stereo_in; i++) {
156         if (ctx->hybrid_bitrate) {
157             if (sl[i] - br[i] > -0x100)
158                 ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
159             else
160                 ctx->ch[i].error_limit = 0;
161         } else {
162             ctx->ch[i].error_limit = wp_exp2(br[i]);
163         }
164     }
165
166     return 0;
167 }
168
169 static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
170                         int channel, int *last)
171 {
172     int t, t2;
173     int sign, base, add, ret;
174     WvChannel *c = &ctx->ch[channel];
175
176     *last = 0;
177
178     if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
179         !ctx->zero && !ctx->one) {
180         if (ctx->zeroes) {
181             ctx->zeroes--;
182             if (ctx->zeroes) {
183                 c->slow_level -= LEVEL_DECAY(c->slow_level);
184                 return 0;
185             }
186         } else {
187             t = get_unary_0_33(gb);
188             if (t >= 2) {
189                 if (t >= 32 || get_bits_left(gb) < t - 1)
190                     goto error;
191                 t = get_bits_long(gb, t - 1) | (1 << (t - 1));
192             } else {
193                 if (get_bits_left(gb) < 0)
194                     goto error;
195             }
196             ctx->zeroes = t;
197             if (ctx->zeroes) {
198                 memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
199                 memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
200                 c->slow_level -= LEVEL_DECAY(c->slow_level);
201                 return 0;
202             }
203         }
204     }
205
206     if (ctx->zero) {
207         t         = 0;
208         ctx->zero = 0;
209     } else {
210         t = get_unary_0_33(gb);
211         if (get_bits_left(gb) < 0)
212             goto error;
213         if (t == 16) {
214             t2 = get_unary_0_33(gb);
215             if (t2 < 2) {
216                 if (get_bits_left(gb) < 0)
217                     goto error;
218                 t += t2;
219             } else {
220                 if (t2 >= 32 || get_bits_left(gb) < t2 - 1)
221                     goto error;
222                 t += get_bits_long(gb, t2 - 1) | (1 << (t2 - 1));
223             }
224         }
225
226         if (ctx->one) {
227             ctx->one = t & 1;
228             t        = (t >> 1) + 1;
229         } else {
230             ctx->one = t & 1;
231             t      >>= 1;
232         }
233         ctx->zero = !ctx->one;
234     }
235
236     if (ctx->hybrid && !channel) {
237         if (update_error_limit(ctx) < 0)
238             goto error;
239     }
240
241     if (!t) {
242         base = 0;
243         add  = GET_MED(0) - 1;
244         DEC_MED(0);
245     } else if (t == 1) {
246         base = GET_MED(0);
247         add  = GET_MED(1) - 1;
248         INC_MED(0);
249         DEC_MED(1);
250     } else if (t == 2) {
251         base = GET_MED(0) + GET_MED(1);
252         add  = GET_MED(2) - 1;
253         INC_MED(0);
254         INC_MED(1);
255         DEC_MED(2);
256     } else {
257         base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2U);
258         add  = GET_MED(2) - 1;
259         INC_MED(0);
260         INC_MED(1);
261         INC_MED(2);
262     }
263     if (!c->error_limit) {
264         if (add >= 0x2000000U) {
265             av_log(ctx->avctx, AV_LOG_ERROR, "k %d is too large\n", add);
266             goto error;
267         }
268         ret = base + get_tail(gb, add);
269         if (get_bits_left(gb) <= 0)
270             goto error;
271     } else {
272         int mid = (base * 2U + add + 1) >> 1;
273         while (add > c->error_limit) {
274             if (get_bits_left(gb) <= 0)
275                 goto error;
276             if (get_bits1(gb)) {
277                 add -= (mid - (unsigned)base);
278                 base = mid;
279             } else
280                 add = mid - (unsigned)base - 1;
281             mid = (base * 2U + add + 1) >> 1;
282         }
283         ret = mid;
284     }
285     sign = get_bits1(gb);
286     if (ctx->hybrid_bitrate)
287         c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
288     return sign ? ~ret : ret;
289
290 error:
291     ret = get_bits_left(gb);
292     if (ret <= 0) {
293         av_log(ctx->avctx, AV_LOG_ERROR, "Too few bits (%d) left\n", ret);
294     }
295     *last = 1;
296     return 0;
297 }
298
299 static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
300                                        unsigned S)
301 {
302     unsigned bit;
303
304     if (s->extra_bits) {
305         S *= 1 << s->extra_bits;
306
307         if (s->got_extra_bits &&
308             get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
309             S   |= get_bits_long(&s->gb_extra_bits, s->extra_bits);
310             *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
311         }
312     }
313
314     bit = (S & s->and) | s->or;
315     bit = ((S + bit) << s->shift) - bit;
316
317     if (s->hybrid)
318         bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
319
320     return bit << s->post_shift;
321 }
322
323 static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
324 {
325     union {
326         float    f;
327         uint32_t u;
328     } value;
329
330     unsigned int sign;
331     int exp = s->float_max_exp;
332
333     if (s->got_extra_bits) {
334         const int max_bits  = 1 + 23 + 8 + 1;
335         const int left_bits = get_bits_left(&s->gb_extra_bits);
336
337         if (left_bits + 8 * AV_INPUT_BUFFER_PADDING_SIZE < max_bits)
338             return 0.0;
339     }
340
341     if (S) {
342         S  *= 1U << s->float_shift;
343         sign = S < 0;
344         if (sign)
345             S = -(unsigned)S;
346         if (S >= 0x1000000U) {
347             if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
348                 S = get_bits(&s->gb_extra_bits, 23);
349             else
350                 S = 0;
351             exp = 255;
352         } else if (exp) {
353             int shift = 23 - av_log2(S);
354             exp = s->float_max_exp;
355             if (exp <= shift)
356                 shift = --exp;
357             exp -= shift;
358
359             if (shift) {
360                 S <<= shift;
361                 if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
362                     (s->got_extra_bits &&
363                      (s->float_flag & WV_FLT_SHIFT_SAME) &&
364                      get_bits1(&s->gb_extra_bits))) {
365                     S |= (1 << shift) - 1;
366                 } else if (s->got_extra_bits &&
367                            (s->float_flag & WV_FLT_SHIFT_SENT)) {
368                     S |= get_bits(&s->gb_extra_bits, shift);
369                 }
370             }
371         } else {
372             exp = s->float_max_exp;
373         }
374         S &= 0x7fffff;
375     } else {
376         sign = 0;
377         exp  = 0;
378         if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
379             if (get_bits1(&s->gb_extra_bits)) {
380                 S = get_bits(&s->gb_extra_bits, 23);
381                 if (s->float_max_exp >= 25)
382                     exp = get_bits(&s->gb_extra_bits, 8);
383                 sign = get_bits1(&s->gb_extra_bits);
384             } else {
385                 if (s->float_flag & WV_FLT_ZERO_SIGN)
386                     sign = get_bits1(&s->gb_extra_bits);
387             }
388         }
389     }
390
391     *crc = *crc * 27 + S * 9 + exp * 3 + sign;
392
393     value.u = (sign << 31) | (exp << 23) | S;
394     return value.f;
395 }
396
397 static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
398                                uint32_t crc_extra_bits)
399 {
400     if (crc != s->CRC) {
401         av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
402         return AVERROR_INVALIDDATA;
403     }
404     if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
405         av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
406         return AVERROR_INVALIDDATA;
407     }
408
409     return 0;
410 }
411
412 static void init_ptable(int *table, int rate_i, int rate_s)
413 {
414     int value = 0x808000, rate = rate_i << 8;
415
416     for (int c = (rate + 128) >> 8; c--;)
417         value += (DOWN - value) >> DECAY;
418
419     for (int i = 0; i < PTABLE_BINS/2; i++) {
420         table[i] = value;
421         table[PTABLE_BINS-1-i] = 0x100ffff - value;
422
423         if (value > 0x010000) {
424             rate += (rate * rate_s + 128) >> 8;
425
426             for (int c = (rate + 64) >> 7; c--;)
427                 value += (DOWN - value) >> DECAY;
428         }
429     }
430 }
431
432 typedef struct {
433     int32_t value, fltr0, fltr1, fltr2, fltr3, fltr4, fltr5, fltr6, factor;
434     unsigned int byte;
435 } DSDfilters;
436
437 static int wv_unpack_dsd_high(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
438 {
439     uint32_t checksum = 0xFFFFFFFF;
440     uint8_t *dst_l = dst_left, *dst_r = dst_right;
441     int total_samples = s->samples, stereo = dst_r ? 1 : 0;
442     DSDfilters filters[2], *sp = filters;
443     int rate_i, rate_s;
444     uint32_t low, high, value;
445
446     if (bytestream2_get_bytes_left(&s->gbyte) < (stereo ? 20 : 13))
447         return AVERROR_INVALIDDATA;
448
449     rate_i = bytestream2_get_byte(&s->gbyte);
450     rate_s = bytestream2_get_byte(&s->gbyte);
451
452     if (rate_s != RATE_S)
453         return AVERROR_INVALIDDATA;
454
455     init_ptable(s->ptable, rate_i, rate_s);
456
457     for (int channel = 0; channel < stereo + 1; channel++) {
458         DSDfilters *sp = filters + channel;
459
460         sp->fltr1 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
461         sp->fltr2 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
462         sp->fltr3 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
463         sp->fltr4 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
464         sp->fltr5 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
465         sp->fltr6 = 0;
466         sp->factor = bytestream2_get_byte(&s->gbyte) & 0xff;
467         sp->factor |= (bytestream2_get_byte(&s->gbyte) << 8) & 0xff00;
468         sp->factor = (int32_t)((uint32_t)sp->factor << 16) >> 16;
469     }
470
471     value = bytestream2_get_be32(&s->gbyte);
472     high = 0xffffffff;
473     low = 0x0;
474
475     while (total_samples--) {
476         int bitcount = 8;
477
478         sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
479
480         if (stereo)
481             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
482
483         while (bitcount--) {
484             int32_t *pp = s->ptable + ((sp[0].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
485             uint32_t split = low + ((high - low) >> 8) * (*pp >> 16);
486
487             if (value <= split) {
488                 high = split;
489                 *pp += (UP - *pp) >> DECAY;
490                 sp[0].fltr0 = -1;
491             } else {
492                 low = split + 1;
493                 *pp += (DOWN - *pp) >> DECAY;
494                 sp[0].fltr0 = 0;
495             }
496
497             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
498                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
499                 high = (high << 8) | 0xff;
500                 low <<= 8;
501             }
502
503             sp[0].value += sp[0].fltr6 * 8;
504             sp[0].byte = (sp[0].byte << 1) | (sp[0].fltr0 & 1);
505             sp[0].factor += (((sp[0].value ^ sp[0].fltr0) >> 31) | 1) &
506                 ((sp[0].value ^ (sp[0].value - (sp[0].fltr6 * 16))) >> 31);
507             sp[0].fltr1 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr1) >> 6;
508             sp[0].fltr2 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr2) >> 4;
509             sp[0].fltr3 += (sp[0].fltr2 - sp[0].fltr3) >> 4;
510             sp[0].fltr4 += (sp[0].fltr3 - sp[0].fltr4) >> 4;
511             sp[0].value = (sp[0].fltr4 - sp[0].fltr5) >> 4;
512             sp[0].fltr5 += sp[0].value;
513             sp[0].fltr6 += (sp[0].value - sp[0].fltr6) >> 3;
514             sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
515
516             if (!stereo)
517                 continue;
518
519             pp = s->ptable + ((sp[1].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
520             split = low + ((high - low) >> 8) * (*pp >> 16);
521
522             if (value <= split) {
523                 high = split;
524                 *pp += (UP - *pp) >> DECAY;
525                 sp[1].fltr0 = -1;
526             } else {
527                 low = split + 1;
528                 *pp += (DOWN - *pp) >> DECAY;
529                 sp[1].fltr0 = 0;
530             }
531
532             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
533                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
534                 high = (high << 8) | 0xff;
535                 low <<= 8;
536             }
537
538             sp[1].value += sp[1].fltr6 * 8;
539             sp[1].byte = (sp[1].byte << 1) | (sp[1].fltr0 & 1);
540             sp[1].factor += (((sp[1].value ^ sp[1].fltr0) >> 31) | 1) &
541                 ((sp[1].value ^ (sp[1].value - (sp[1].fltr6 * 16))) >> 31);
542             sp[1].fltr1 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr1) >> 6;
543             sp[1].fltr2 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr2) >> 4;
544             sp[1].fltr3 += (sp[1].fltr2 - sp[1].fltr3) >> 4;
545             sp[1].fltr4 += (sp[1].fltr3 - sp[1].fltr4) >> 4;
546             sp[1].value = (sp[1].fltr4 - sp[1].fltr5) >> 4;
547             sp[1].fltr5 += sp[1].value;
548             sp[1].fltr6 += (sp[1].value - sp[1].fltr6) >> 3;
549             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
550         }
551
552         checksum += (checksum << 1) + (*dst_l = sp[0].byte & 0xff);
553         sp[0].factor -= (sp[0].factor + 512) >> 10;
554         dst_l += 4;
555
556         if (stereo) {
557             checksum += (checksum << 1) + (*dst_r = filters[1].byte & 0xff);
558             filters[1].factor -= (filters[1].factor + 512) >> 10;
559             dst_r += 4;
560         }
561     }
562
563     if (wv_check_crc(s, checksum, 0)) {
564         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
565             return AVERROR_INVALIDDATA;
566
567         memset(dst_left, 0x69, s->samples * 4);
568
569         if (dst_r)
570             memset(dst_right, 0x69, s->samples * 4);
571     }
572
573     return 0;
574 }
575
576 static int wv_unpack_dsd_fast(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
577 {
578     uint8_t *dst_l = dst_left, *dst_r = dst_right;
579     uint8_t history_bits, max_probability;
580     int total_summed_probabilities  = 0;
581     int total_samples               = s->samples;
582     uint8_t *vlb                    = s->value_lookup_buffer;
583     int history_bins, p0, p1, chan;
584     uint32_t checksum               = 0xFFFFFFFF;
585     uint32_t low, high, value;
586
587     if (!bytestream2_get_bytes_left(&s->gbyte))
588         return AVERROR_INVALIDDATA;
589
590     history_bits = bytestream2_get_byte(&s->gbyte);
591
592     if (!bytestream2_get_bytes_left(&s->gbyte) || history_bits > MAX_HISTORY_BITS)
593         return AVERROR_INVALIDDATA;
594
595     history_bins = 1 << history_bits;
596     max_probability = bytestream2_get_byte(&s->gbyte);
597
598     if (max_probability < 0xff) {
599         uint8_t *outptr = (uint8_t *)s->probabilities;
600         uint8_t *outend = outptr + sizeof(*s->probabilities) * history_bins;
601
602         while (outptr < outend && bytestream2_get_bytes_left(&s->gbyte)) {
603             int code = bytestream2_get_byte(&s->gbyte);
604
605             if (code > max_probability) {
606                 int zcount = code - max_probability;
607
608                 while (outptr < outend && zcount--)
609                     *outptr++ = 0;
610             } else if (code) {
611                 *outptr++ = code;
612             }
613             else {
614                 break;
615             }
616         }
617
618         if (outptr < outend ||
619             (bytestream2_get_bytes_left(&s->gbyte) && bytestream2_get_byte(&s->gbyte)))
620                 return AVERROR_INVALIDDATA;
621     } else if (bytestream2_get_bytes_left(&s->gbyte) > (int)sizeof(*s->probabilities) * history_bins) {
622         bytestream2_get_buffer(&s->gbyte, (uint8_t *)s->probabilities,
623             sizeof(*s->probabilities) * history_bins);
624     } else {
625         return AVERROR_INVALIDDATA;
626     }
627
628     for (p0 = 0; p0 < history_bins; p0++) {
629         int32_t sum_values = 0;
630
631         for (int i = 0; i < 256; i++)
632             s->summed_probabilities[p0][i] = sum_values += s->probabilities[p0][i];
633
634         if (sum_values) {
635             total_summed_probabilities += sum_values;
636
637             if (total_summed_probabilities > history_bins * MAX_BIN_BYTES)
638                 return AVERROR_INVALIDDATA;
639
640             s->value_lookup[p0] = vlb;
641
642             for (int i = 0; i < 256; i++) {
643                 int c = s->probabilities[p0][i];
644
645                 while (c--)
646                     *vlb++ = i;
647             }
648         }
649     }
650
651     if (bytestream2_get_bytes_left(&s->gbyte) < 4)
652         return AVERROR_INVALIDDATA;
653
654     chan = p0 = p1 = 0;
655     low = 0; high = 0xffffffff;
656     value = bytestream2_get_be32(&s->gbyte);
657
658     if (dst_r)
659         total_samples *= 2;
660
661     while (total_samples--) {
662         unsigned int mult, index, code;
663
664         if (!s->summed_probabilities[p0][255])
665             return AVERROR_INVALIDDATA;
666
667         mult = (high - low) / s->summed_probabilities[p0][255];
668
669         if (!mult) {
670             if (bytestream2_get_bytes_left(&s->gbyte) >= 4)
671                 value = bytestream2_get_be32(&s->gbyte);
672
673             low = 0;
674             high = 0xffffffff;
675             mult = high / s->summed_probabilities[p0][255];
676
677             if (!mult)
678                 return AVERROR_INVALIDDATA;
679         }
680
681         index = (value - low) / mult;
682
683         if (index >= s->summed_probabilities[p0][255])
684             return AVERROR_INVALIDDATA;
685
686         if (!dst_r) {
687             if ((*dst_l = code = s->value_lookup[p0][index]))
688                 low += s->summed_probabilities[p0][code-1] * mult;
689
690             dst_l += 4;
691         } else {
692             if ((code = s->value_lookup[p0][index]))
693                 low += s->summed_probabilities[p0][code-1] * mult;
694
695             if (chan) {
696                 *dst_r = code;
697                 dst_r += 4;
698             }
699             else {
700                 *dst_l = code;
701                 dst_l += 4;
702             }
703
704             chan ^= 1;
705         }
706
707         high = low + s->probabilities[p0][code] * mult - 1;
708         checksum += (checksum << 1) + code;
709
710         if (!dst_r) {
711             p0 = code & (history_bins-1);
712         } else {
713             p0 = p1;
714             p1 = code & (history_bins-1);
715         }
716
717         while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
718             value = (value << 8) | bytestream2_get_byte(&s->gbyte);
719             high = (high << 8) | 0xff;
720             low <<= 8;
721         }
722     }
723
724     if (wv_check_crc(s, checksum, 0)) {
725         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
726             return AVERROR_INVALIDDATA;
727
728         memset(dst_left, 0x69, s->samples * 4);
729
730         if (dst_r)
731             memset(dst_right, 0x69, s->samples * 4);
732     }
733
734     return 0;
735 }
736
737 static int wv_unpack_dsd_copy(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
738 {
739     uint8_t *dst_l = dst_left, *dst_r = dst_right;
740     int total_samples           = s->samples;
741     uint32_t checksum           = 0xFFFFFFFF;
742
743     if (bytestream2_get_bytes_left(&s->gbyte) != total_samples * (dst_r ? 2 : 1))
744         return AVERROR_INVALIDDATA;
745
746     while (total_samples--) {
747         checksum += (checksum << 1) + (*dst_l = bytestream2_get_byte(&s->gbyte));
748         dst_l += 4;
749
750         if (dst_r) {
751             checksum += (checksum << 1) + (*dst_r = bytestream2_get_byte(&s->gbyte));
752             dst_r += 4;
753         }
754     }
755
756     if (wv_check_crc(s, checksum, 0)) {
757         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
758             return AVERROR_INVALIDDATA;
759
760         memset(dst_left, 0x69, s->samples * 4);
761
762         if (dst_r)
763             memset(dst_right, 0x69, s->samples * 4);
764     }
765
766     return 0;
767 }
768
769 static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
770                                    void *dst_l, void *dst_r, const int type)
771 {
772     int i, j, count = 0;
773     int last, t;
774     int A, B, L, L2, R, R2;
775     int pos                 = 0;
776     uint32_t crc            = 0xFFFFFFFF;
777     uint32_t crc_extra_bits = 0xFFFFFFFF;
778     int16_t *dst16_l        = dst_l;
779     int16_t *dst16_r        = dst_r;
780     int32_t *dst32_l        = dst_l;
781     int32_t *dst32_r        = dst_r;
782     float *dstfl_l          = dst_l;
783     float *dstfl_r          = dst_r;
784
785     s->one = s->zero = s->zeroes = 0;
786     do {
787         L = wv_get_value(s, gb, 0, &last);
788         if (last)
789             break;
790         R = wv_get_value(s, gb, 1, &last);
791         if (last)
792             break;
793         for (i = 0; i < s->terms; i++) {
794             t = s->decorr[i].value;
795             if (t > 0) {
796                 if (t > 8) {
797                     if (t & 1) {
798                         A = 2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
799                         B = 2U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
800                     } else {
801                         A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
802                         B = (int)(3U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
803                     }
804                     s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
805                     s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
806                     j                        = 0;
807                 } else {
808                     A = s->decorr[i].samplesA[pos];
809                     B = s->decorr[i].samplesB[pos];
810                     j = (pos + t) & 7;
811                 }
812                 if (type != AV_SAMPLE_FMT_S16P) {
813                     L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
814                     R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
815                 } else {
816                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
817                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)B + 512) >> 10);
818                 }
819                 if (A && L)
820                     s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
821                 if (B && R)
822                     s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
823                 s->decorr[i].samplesA[j] = L = L2;
824                 s->decorr[i].samplesB[j] = R = R2;
825             } else if (t == -1) {
826                 if (type != AV_SAMPLE_FMT_S16P)
827                     L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
828                 else
829                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)s->decorr[i].samplesA[0] + 512) >> 10);
830                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
831                 L = L2;
832                 if (type != AV_SAMPLE_FMT_S16P)
833                     R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
834                 else
835                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)L2 + 512) >> 10);
836                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
837                 R                        = R2;
838                 s->decorr[i].samplesA[0] = R;
839             } else {
840                 if (type != AV_SAMPLE_FMT_S16P)
841                     R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
842                 else
843                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)s->decorr[i].samplesB[0] + 512) >> 10);
844                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
845                 R = R2;
846
847                 if (t == -3) {
848                     R2                       = s->decorr[i].samplesA[0];
849                     s->decorr[i].samplesA[0] = R;
850                 }
851
852                 if (type != AV_SAMPLE_FMT_S16P)
853                     L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
854                 else
855                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)R2 + 512) >> 10);
856                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
857                 L                        = L2;
858                 s->decorr[i].samplesB[0] = L;
859             }
860         }
861
862         if (type == AV_SAMPLE_FMT_S16P) {
863             if (FFABS((int64_t)L) + FFABS((int64_t)R) > (1<<19)) {
864                 av_log(s->avctx, AV_LOG_ERROR, "sample %d %d too large\n", L, R);
865                 return AVERROR_INVALIDDATA;
866             }
867         }
868
869         pos = (pos + 1) & 7;
870         if (s->joint)
871             L += (unsigned)(R -= (unsigned)(L >> 1));
872         crc = (crc * 3 + L) * 3 + R;
873
874         if (type == AV_SAMPLE_FMT_FLTP) {
875             *dstfl_l++ = wv_get_value_float(s, &crc_extra_bits, L);
876             *dstfl_r++ = wv_get_value_float(s, &crc_extra_bits, R);
877         } else if (type == AV_SAMPLE_FMT_S32P) {
878             *dst32_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
879             *dst32_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
880         } else {
881             *dst16_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
882             *dst16_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
883         }
884         count++;
885     } while (!last && count < s->samples);
886
887     if (last && count < s->samples) {
888         int size = av_get_bytes_per_sample(type);
889         memset((uint8_t*)dst_l + count*size, 0, (s->samples-count)*size);
890         memset((uint8_t*)dst_r + count*size, 0, (s->samples-count)*size);
891     }
892
893     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
894         wv_check_crc(s, crc, crc_extra_bits))
895         return AVERROR_INVALIDDATA;
896
897     return 0;
898 }
899
900 static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
901                                  void *dst, const int type)
902 {
903     int i, j, count = 0;
904     int last, t;
905     int A, S, T;
906     int pos                  = 0;
907     uint32_t crc             = 0xFFFFFFFF;
908     uint32_t crc_extra_bits  = 0xFFFFFFFF;
909     int16_t *dst16           = dst;
910     int32_t *dst32           = dst;
911     float *dstfl             = dst;
912
913     s->one = s->zero = s->zeroes = 0;
914     do {
915         T = wv_get_value(s, gb, 0, &last);
916         S = 0;
917         if (last)
918             break;
919         for (i = 0; i < s->terms; i++) {
920             t = s->decorr[i].value;
921             if (t > 8) {
922                 if (t & 1)
923                     A =  2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
924                 else
925                     A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
926                 s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
927                 j                        = 0;
928             } else {
929                 A = s->decorr[i].samplesA[pos];
930                 j = (pos + t) & 7;
931             }
932             if (type != AV_SAMPLE_FMT_S16P)
933                 S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
934             else
935                 S = T + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
936             if (A && T)
937                 s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
938             s->decorr[i].samplesA[j] = T = S;
939         }
940         pos = (pos + 1) & 7;
941         crc = crc * 3 + S;
942
943         if (type == AV_SAMPLE_FMT_FLTP) {
944             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, S);
945         } else if (type == AV_SAMPLE_FMT_S32P) {
946             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, S);
947         } else {
948             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, S);
949         }
950         count++;
951     } while (!last && count < s->samples);
952
953     if (last && count < s->samples) {
954         int size = av_get_bytes_per_sample(type);
955         memset((uint8_t*)dst + count*size, 0, (s->samples-count)*size);
956     }
957
958     if (s->avctx->err_recognition & AV_EF_CRCCHECK) {
959         int ret = wv_check_crc(s, crc, crc_extra_bits);
960         if (ret < 0 && s->avctx->err_recognition & AV_EF_EXPLODE)
961             return ret;
962     }
963
964     return 0;
965 }
966
967 static av_cold int wv_alloc_frame_context(WavpackContext *c)
968 {
969     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
970         return -1;
971
972     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
973     if (!c->fdec[c->fdec_num])
974         return -1;
975     c->fdec_num++;
976     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
977
978     return 0;
979 }
980
981 #if HAVE_THREADS
982 static int init_thread_copy(AVCodecContext *avctx)
983 {
984     WavpackContext *s = avctx->priv_data;
985     s->avctx = avctx;
986
987     s->curr_frame.f = av_frame_alloc();
988     s->prev_frame.f = av_frame_alloc();
989
990     return 0;
991 }
992
993 static int update_thread_context(AVCodecContext *dst, const AVCodecContext *src)
994 {
995     WavpackContext *fsrc = src->priv_data;
996     WavpackContext *fdst = dst->priv_data;
997     int ret;
998
999     if (dst == src)
1000         return 0;
1001
1002     ff_thread_release_buffer(dst, &fdst->curr_frame);
1003     if (fsrc->curr_frame.f->data[0]) {
1004         if ((ret = ff_thread_ref_frame(&fdst->curr_frame, &fsrc->curr_frame)) < 0)
1005             return ret;
1006     }
1007
1008     return 0;
1009 }
1010 #endif
1011
1012 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
1013 {
1014     WavpackContext *s = avctx->priv_data;
1015
1016     s->avctx = avctx;
1017
1018     s->fdec_num = 0;
1019
1020     avctx->internal->allocate_progress = 1;
1021
1022     s->curr_frame.f = av_frame_alloc();
1023     s->prev_frame.f = av_frame_alloc();
1024
1025     // the DSD to PCM context is shared (and used serially) between all decoding threads
1026     s->dsdctx = av_calloc(avctx->channels, sizeof(DSDContext));
1027
1028     if (!s->curr_frame.f || !s->prev_frame.f || !s->dsdctx)
1029         return AVERROR(ENOMEM);
1030
1031     for (int i = 0; i < avctx->channels; i++)
1032         memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
1033
1034     ff_init_dsd_data();
1035
1036     return 0;
1037 }
1038
1039 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
1040 {
1041     WavpackContext *s = avctx->priv_data;
1042
1043     for (int i = 0; i < s->fdec_num; i++)
1044         av_freep(&s->fdec[i]);
1045     s->fdec_num = 0;
1046
1047     ff_thread_release_buffer(avctx, &s->curr_frame);
1048     av_frame_free(&s->curr_frame.f);
1049
1050     ff_thread_release_buffer(avctx, &s->prev_frame);
1051     av_frame_free(&s->prev_frame.f);
1052
1053     if (!avctx->internal->is_copy)
1054         av_freep(&s->dsdctx);
1055
1056     return 0;
1057 }
1058
1059 static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
1060                                 const uint8_t *buf, int buf_size)
1061 {
1062     WavpackContext *wc = avctx->priv_data;
1063     WavpackFrameContext *s;
1064     GetByteContext gb;
1065     void *samples_l = NULL, *samples_r = NULL;
1066     int ret;
1067     int got_terms   = 0, got_weights = 0, got_samples = 0,
1068         got_entropy = 0, got_pcm     = 0, got_float   = 0, got_hybrid = 0;
1069     int got_dsd = 0;
1070     int i, j, id, size, ssize, weights, t;
1071     int bpp, chan = 0, orig_bpp, sample_rate = 0, rate_x = 1, dsd_mode = 0;
1072     int multiblock;
1073     uint64_t chmask = 0;
1074
1075     if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
1076         av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
1077         return AVERROR_INVALIDDATA;
1078     }
1079
1080     s = wc->fdec[block_no];
1081     if (!s) {
1082         av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
1083                block_no);
1084         return AVERROR_INVALIDDATA;
1085     }
1086
1087     memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
1088     memset(s->ch, 0, sizeof(s->ch));
1089     s->extra_bits     = 0;
1090     s->and            = s->or = s->shift = 0;
1091     s->got_extra_bits = 0;
1092
1093     bytestream2_init(&gb, buf, buf_size);
1094
1095     s->samples = bytestream2_get_le32(&gb);
1096     if (s->samples != wc->samples) {
1097         av_log(avctx, AV_LOG_ERROR, "Mismatching number of samples in "
1098                "a sequence: %d and %d\n", wc->samples, s->samples);
1099         return AVERROR_INVALIDDATA;
1100     }
1101     s->frame_flags = bytestream2_get_le32(&gb);
1102     bpp            = av_get_bytes_per_sample(avctx->sample_fmt);
1103     orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
1104     multiblock     = (s->frame_flags & WV_SINGLE_BLOCK) != WV_SINGLE_BLOCK;
1105
1106     s->stereo         = !(s->frame_flags & WV_MONO);
1107     s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
1108     s->joint          =   s->frame_flags & WV_JOINT_STEREO;
1109     s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
1110     s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
1111     s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
1112     if (s->post_shift < 0 || s->post_shift > 31) {
1113         return AVERROR_INVALIDDATA;
1114     }
1115     s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
1116     s->hybrid_minclip = ((-1UL << (orig_bpp - 1)));
1117     s->CRC            = bytestream2_get_le32(&gb);
1118
1119     // parse metadata blocks
1120     while (bytestream2_get_bytes_left(&gb)) {
1121         id   = bytestream2_get_byte(&gb);
1122         size = bytestream2_get_byte(&gb);
1123         if (id & WP_IDF_LONG)
1124             size |= (bytestream2_get_le16u(&gb)) << 8;
1125         size <<= 1; // size is specified in words
1126         ssize  = size;
1127         if (id & WP_IDF_ODD)
1128             size--;
1129         if (size < 0) {
1130             av_log(avctx, AV_LOG_ERROR,
1131                    "Got incorrect block %02X with size %i\n", id, size);
1132             break;
1133         }
1134         if (bytestream2_get_bytes_left(&gb) < ssize) {
1135             av_log(avctx, AV_LOG_ERROR,
1136                    "Block size %i is out of bounds\n", size);
1137             break;
1138         }
1139         switch (id & WP_IDF_MASK) {
1140         case WP_ID_DECTERMS:
1141             if (size > MAX_TERMS) {
1142                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
1143                 s->terms = 0;
1144                 bytestream2_skip(&gb, ssize);
1145                 continue;
1146             }
1147             s->terms = size;
1148             for (i = 0; i < s->terms; i++) {
1149                 uint8_t val = bytestream2_get_byte(&gb);
1150                 s->decorr[s->terms - i - 1].value = (val & 0x1F) - 5;
1151                 s->decorr[s->terms - i - 1].delta =  val >> 5;
1152             }
1153             got_terms = 1;
1154             break;
1155         case WP_ID_DECWEIGHTS:
1156             if (!got_terms) {
1157                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
1158                 continue;
1159             }
1160             weights = size >> s->stereo_in;
1161             if (weights > MAX_TERMS || weights > s->terms) {
1162                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
1163                 bytestream2_skip(&gb, ssize);
1164                 continue;
1165             }
1166             for (i = 0; i < weights; i++) {
1167                 t = (int8_t)bytestream2_get_byte(&gb);
1168                 s->decorr[s->terms - i - 1].weightA = t * (1 << 3);
1169                 if (s->decorr[s->terms - i - 1].weightA > 0)
1170                     s->decorr[s->terms - i - 1].weightA +=
1171                         (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
1172                 if (s->stereo_in) {
1173                     t = (int8_t)bytestream2_get_byte(&gb);
1174                     s->decorr[s->terms - i - 1].weightB = t * (1 << 3);
1175                     if (s->decorr[s->terms - i - 1].weightB > 0)
1176                         s->decorr[s->terms - i - 1].weightB +=
1177                             (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
1178                 }
1179             }
1180             got_weights = 1;
1181             break;
1182         case WP_ID_DECSAMPLES:
1183             if (!got_terms) {
1184                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
1185                 continue;
1186             }
1187             t = 0;
1188             for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
1189                 if (s->decorr[i].value > 8) {
1190                     s->decorr[i].samplesA[0] =
1191                         wp_exp2(bytestream2_get_le16(&gb));
1192                     s->decorr[i].samplesA[1] =
1193                         wp_exp2(bytestream2_get_le16(&gb));
1194
1195                     if (s->stereo_in) {
1196                         s->decorr[i].samplesB[0] =
1197                             wp_exp2(bytestream2_get_le16(&gb));
1198                         s->decorr[i].samplesB[1] =
1199                             wp_exp2(bytestream2_get_le16(&gb));
1200                         t                       += 4;
1201                     }
1202                     t += 4;
1203                 } else if (s->decorr[i].value < 0) {
1204                     s->decorr[i].samplesA[0] =
1205                         wp_exp2(bytestream2_get_le16(&gb));
1206                     s->decorr[i].samplesB[0] =
1207                         wp_exp2(bytestream2_get_le16(&gb));
1208                     t                       += 4;
1209                 } else {
1210                     for (j = 0; j < s->decorr[i].value; j++) {
1211                         s->decorr[i].samplesA[j] =
1212                             wp_exp2(bytestream2_get_le16(&gb));
1213                         if (s->stereo_in) {
1214                             s->decorr[i].samplesB[j] =
1215                                 wp_exp2(bytestream2_get_le16(&gb));
1216                         }
1217                     }
1218                     t += s->decorr[i].value * 2 * (s->stereo_in + 1);
1219                 }
1220             }
1221             got_samples = 1;
1222             break;
1223         case WP_ID_ENTROPY:
1224             if (size != 6 * (s->stereo_in + 1)) {
1225                 av_log(avctx, AV_LOG_ERROR,
1226                        "Entropy vars size should be %i, got %i.\n",
1227                        6 * (s->stereo_in + 1), size);
1228                 bytestream2_skip(&gb, ssize);
1229                 continue;
1230             }
1231             for (j = 0; j <= s->stereo_in; j++)
1232                 for (i = 0; i < 3; i++) {
1233                     s->ch[j].median[i] = wp_exp2(bytestream2_get_le16(&gb));
1234                 }
1235             got_entropy = 1;
1236             break;
1237         case WP_ID_HYBRID:
1238             if (s->hybrid_bitrate) {
1239                 for (i = 0; i <= s->stereo_in; i++) {
1240                     s->ch[i].slow_level = wp_exp2(bytestream2_get_le16(&gb));
1241                     size               -= 2;
1242                 }
1243             }
1244             for (i = 0; i < (s->stereo_in + 1); i++) {
1245                 s->ch[i].bitrate_acc = bytestream2_get_le16(&gb) << 16;
1246                 size                -= 2;
1247             }
1248             if (size > 0) {
1249                 for (i = 0; i < (s->stereo_in + 1); i++) {
1250                     s->ch[i].bitrate_delta =
1251                         wp_exp2((int16_t)bytestream2_get_le16(&gb));
1252                 }
1253             } else {
1254                 for (i = 0; i < (s->stereo_in + 1); i++)
1255                     s->ch[i].bitrate_delta = 0;
1256             }
1257             got_hybrid = 1;
1258             break;
1259         case WP_ID_INT32INFO: {
1260             uint8_t val[4];
1261             if (size != 4) {
1262                 av_log(avctx, AV_LOG_ERROR,
1263                        "Invalid INT32INFO, size = %i\n",
1264                        size);
1265                 bytestream2_skip(&gb, ssize - 4);
1266                 continue;
1267             }
1268             bytestream2_get_buffer(&gb, val, 4);
1269             if (val[0] > 30) {
1270                 av_log(avctx, AV_LOG_ERROR,
1271                        "Invalid INT32INFO, extra_bits = %d (> 30)\n", val[0]);
1272                 continue;
1273             } else if (val[0]) {
1274                 s->extra_bits = val[0];
1275             } else if (val[1]) {
1276                 s->shift = val[1];
1277             } else if (val[2]) {
1278                 s->and   = s->or = 1;
1279                 s->shift = val[2];
1280             } else if (val[3]) {
1281                 s->and   = 1;
1282                 s->shift = val[3];
1283             }
1284             if (s->shift > 31) {
1285                 av_log(avctx, AV_LOG_ERROR,
1286                        "Invalid INT32INFO, shift = %d (> 31)\n", s->shift);
1287                 s->and = s->or = s->shift = 0;
1288                 continue;
1289             }
1290             /* original WavPack decoder forces 32-bit lossy sound to be treated
1291              * as 24-bit one in order to have proper clipping */
1292             if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
1293                 s->post_shift      += 8;
1294                 s->shift           -= 8;
1295                 s->hybrid_maxclip >>= 8;
1296                 s->hybrid_minclip >>= 8;
1297             }
1298             break;
1299         }
1300         case WP_ID_FLOATINFO:
1301             if (size != 4) {
1302                 av_log(avctx, AV_LOG_ERROR,
1303                        "Invalid FLOATINFO, size = %i\n", size);
1304                 bytestream2_skip(&gb, ssize);
1305                 continue;
1306             }
1307             s->float_flag    = bytestream2_get_byte(&gb);
1308             s->float_shift   = bytestream2_get_byte(&gb);
1309             s->float_max_exp = bytestream2_get_byte(&gb);
1310             if (s->float_shift > 31) {
1311                 av_log(avctx, AV_LOG_ERROR,
1312                        "Invalid FLOATINFO, shift = %d (> 31)\n", s->float_shift);
1313                 s->float_shift = 0;
1314                 continue;
1315             }
1316             got_float        = 1;
1317             bytestream2_skip(&gb, 1);
1318             break;
1319         case WP_ID_DATA:
1320             if ((ret = init_get_bits8(&s->gb, gb.buffer, size)) < 0)
1321                 return ret;
1322             bytestream2_skip(&gb, size);
1323             got_pcm      = 1;
1324             break;
1325         case WP_ID_DSD_DATA:
1326             if (size < 2) {
1327                 av_log(avctx, AV_LOG_ERROR, "Invalid DSD_DATA, size = %i\n",
1328                        size);
1329                 bytestream2_skip(&gb, ssize);
1330                 continue;
1331             }
1332             rate_x = 1 << bytestream2_get_byte(&gb);
1333             dsd_mode = bytestream2_get_byte(&gb);
1334             if (dsd_mode && dsd_mode != 1 && dsd_mode != 3) {
1335                 av_log(avctx, AV_LOG_ERROR, "Invalid DSD encoding mode: %d\n",
1336                     dsd_mode);
1337                 return AVERROR_INVALIDDATA;
1338             }
1339             bytestream2_init(&s->gbyte, gb.buffer, size-2);
1340             bytestream2_skip(&gb, size-2);
1341             got_dsd      = 1;
1342             break;
1343         case WP_ID_EXTRABITS:
1344             if (size <= 4) {
1345                 av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
1346                        size);
1347                 bytestream2_skip(&gb, size);
1348                 continue;
1349             }
1350             if ((ret = init_get_bits8(&s->gb_extra_bits, gb.buffer, size)) < 0)
1351                 return ret;
1352             s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
1353             bytestream2_skip(&gb, size);
1354             s->got_extra_bits  = 1;
1355             break;
1356         case WP_ID_CHANINFO:
1357             if (size <= 1) {
1358                 av_log(avctx, AV_LOG_ERROR,
1359                        "Insufficient channel information\n");
1360                 return AVERROR_INVALIDDATA;
1361             }
1362             chan = bytestream2_get_byte(&gb);
1363             switch (size - 2) {
1364             case 0:
1365                 chmask = bytestream2_get_byte(&gb);
1366                 break;
1367             case 1:
1368                 chmask = bytestream2_get_le16(&gb);
1369                 break;
1370             case 2:
1371                 chmask = bytestream2_get_le24(&gb);
1372                 break;
1373             case 3:
1374                 chmask = bytestream2_get_le32(&gb);
1375                 break;
1376             case 4:
1377                 size = bytestream2_get_byte(&gb);
1378                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1379                 chan  += 1;
1380                 if (avctx->channels != chan)
1381                     av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
1382                            " instead of %i.\n", chan, avctx->channels);
1383                 chmask = bytestream2_get_le24(&gb);
1384                 break;
1385             case 5:
1386                 size = bytestream2_get_byte(&gb);
1387                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1388                 chan  += 1;
1389                 if (avctx->channels != chan)
1390                     av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
1391                            " instead of %i.\n", chan, avctx->channels);
1392                 chmask = bytestream2_get_le32(&gb);
1393                 break;
1394             default:
1395                 av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
1396                        size);
1397                 chan   = avctx->channels;
1398                 chmask = avctx->channel_layout;
1399             }
1400             break;
1401         case WP_ID_SAMPLE_RATE:
1402             if (size != 3) {
1403                 av_log(avctx, AV_LOG_ERROR, "Invalid custom sample rate.\n");
1404                 return AVERROR_INVALIDDATA;
1405             }
1406             sample_rate = bytestream2_get_le24(&gb);
1407             break;
1408         default:
1409             bytestream2_skip(&gb, size);
1410         }
1411         if (id & WP_IDF_ODD)
1412             bytestream2_skip(&gb, 1);
1413     }
1414
1415     if (got_pcm) {
1416         if (!got_terms) {
1417             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
1418             return AVERROR_INVALIDDATA;
1419         }
1420         if (!got_weights) {
1421             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
1422             return AVERROR_INVALIDDATA;
1423         }
1424         if (!got_samples) {
1425             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
1426             return AVERROR_INVALIDDATA;
1427         }
1428         if (!got_entropy) {
1429             av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
1430             return AVERROR_INVALIDDATA;
1431         }
1432         if (s->hybrid && !got_hybrid) {
1433             av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
1434             return AVERROR_INVALIDDATA;
1435         }
1436         if (!got_float && avctx->sample_fmt == AV_SAMPLE_FMT_FLTP) {
1437             av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
1438             return AVERROR_INVALIDDATA;
1439         }
1440         if (s->got_extra_bits && avctx->sample_fmt != AV_SAMPLE_FMT_FLTP) {
1441             const int size   = get_bits_left(&s->gb_extra_bits);
1442             const int wanted = s->samples * s->extra_bits << s->stereo_in;
1443             if (size < wanted) {
1444                 av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
1445                 s->got_extra_bits = 0;
1446             }
1447         }
1448     }
1449
1450     if (!got_pcm && !got_dsd) {
1451         av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
1452         return AVERROR_INVALIDDATA;
1453     }
1454
1455     if ((got_pcm && wc->modulation != MODULATION_PCM) ||
1456         (got_dsd && wc->modulation != MODULATION_DSD)) {
1457             av_log(avctx, AV_LOG_ERROR, "Invalid PCM/DSD mix encountered\n");
1458             return AVERROR_INVALIDDATA;
1459     }
1460
1461     if (!wc->ch_offset) {
1462         int sr = (s->frame_flags >> 23) & 0xf;
1463         if (sr == 0xf) {
1464             if (!sample_rate) {
1465                 av_log(avctx, AV_LOG_ERROR, "Custom sample rate missing.\n");
1466                 return AVERROR_INVALIDDATA;
1467             }
1468             avctx->sample_rate = sample_rate * rate_x;
1469         } else
1470             avctx->sample_rate = wv_rates[sr] * rate_x;
1471
1472         if (multiblock) {
1473             if (chan)
1474                 avctx->channels = chan;
1475             if (chmask)
1476                 avctx->channel_layout = chmask;
1477         } else {
1478             avctx->channels       = s->stereo ? 2 : 1;
1479             avctx->channel_layout = s->stereo ? AV_CH_LAYOUT_STEREO :
1480                                                 AV_CH_LAYOUT_MONO;
1481         }
1482
1483         ff_thread_release_buffer(avctx, &wc->prev_frame);
1484         FFSWAP(ThreadFrame, wc->curr_frame, wc->prev_frame);
1485
1486         /* get output buffer */
1487         wc->curr_frame.f->nb_samples = s->samples;
1488         if ((ret = ff_thread_get_buffer(avctx, &wc->curr_frame, AV_GET_BUFFER_FLAG_REF)) < 0)
1489             return ret;
1490
1491         wc->frame = wc->curr_frame.f;
1492         ff_thread_finish_setup(avctx);
1493     }
1494
1495     if (wc->ch_offset + s->stereo >= avctx->channels) {
1496         av_log(avctx, AV_LOG_WARNING, "Too many channels coded in a packet.\n");
1497         return ((avctx->err_recognition & AV_EF_EXPLODE) || !wc->ch_offset) ? AVERROR_INVALIDDATA : 0;
1498     }
1499
1500     samples_l = wc->frame->extended_data[wc->ch_offset];
1501     if (s->stereo)
1502         samples_r = wc->frame->extended_data[wc->ch_offset + 1];
1503
1504     wc->ch_offset += 1 + s->stereo;
1505
1506     if (s->stereo_in) {
1507         if (got_dsd) {
1508             if (dsd_mode == 3) {
1509                 ret = wv_unpack_dsd_high(s, samples_l, samples_r);
1510             } else if (dsd_mode == 1) {
1511                 ret = wv_unpack_dsd_fast(s, samples_l, samples_r);
1512             } else {
1513                 ret = wv_unpack_dsd_copy(s, samples_l, samples_r);
1514             }
1515         } else {
1516             ret = wv_unpack_stereo(s, &s->gb, samples_l, samples_r, avctx->sample_fmt);
1517         }
1518         if (ret < 0)
1519             return ret;
1520     } else {
1521         if (got_dsd) {
1522             if (dsd_mode == 3) {
1523                 ret = wv_unpack_dsd_high(s, samples_l, NULL);
1524             } else if (dsd_mode == 1) {
1525                 ret = wv_unpack_dsd_fast(s, samples_l, NULL);
1526             } else {
1527                 ret = wv_unpack_dsd_copy(s, samples_l, NULL);
1528             }
1529         } else {
1530             ret = wv_unpack_mono(s, &s->gb, samples_l, avctx->sample_fmt);
1531         }
1532         if (ret < 0)
1533             return ret;
1534
1535         if (s->stereo)
1536             memcpy(samples_r, samples_l, bpp * s->samples);
1537     }
1538
1539     return 0;
1540 }
1541
1542 static void wavpack_decode_flush(AVCodecContext *avctx)
1543 {
1544     WavpackContext *s = avctx->priv_data;
1545
1546     if (!avctx->internal->is_copy) {
1547         for (int i = 0; i < avctx->channels; i++)
1548             memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
1549     }
1550 }
1551
1552 static int dsd_channel(AVCodecContext *avctx, void *frmptr, int jobnr, int threadnr)
1553 {
1554     WavpackContext *s  = avctx->priv_data;
1555     AVFrame *frame = frmptr;
1556
1557     ff_dsd2pcm_translate (&s->dsdctx [jobnr], s->samples, 0,
1558         (uint8_t *)frame->extended_data[jobnr], 4,
1559         (float *)frame->extended_data[jobnr], 1);
1560
1561     return 0;
1562 }
1563
1564 static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
1565                                 int *got_frame_ptr, AVPacket *avpkt)
1566 {
1567     WavpackContext *s  = avctx->priv_data;
1568     const uint8_t *buf = avpkt->data;
1569     int buf_size       = avpkt->size;
1570     int frame_size, ret, frame_flags;
1571
1572     if (avpkt->size <= WV_HEADER_SIZE)
1573         return AVERROR_INVALIDDATA;
1574
1575     s->frame     = NULL;
1576     s->block     = 0;
1577     s->ch_offset = 0;
1578
1579     /* determine number of samples */
1580     s->samples  = AV_RL32(buf + 20);
1581     frame_flags = AV_RL32(buf + 24);
1582     if (s->samples <= 0 || s->samples > WV_MAX_SAMPLES) {
1583         av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1584                s->samples);
1585         return AVERROR_INVALIDDATA;
1586     }
1587
1588     s->modulation = (frame_flags & WV_DSD_DATA) ? MODULATION_DSD : MODULATION_PCM;
1589
1590     if (frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA)) {
1591         avctx->sample_fmt = AV_SAMPLE_FMT_FLTP;
1592     } else if ((frame_flags & 0x03) <= 1) {
1593         avctx->sample_fmt = AV_SAMPLE_FMT_S16P;
1594     } else {
1595         avctx->sample_fmt          = AV_SAMPLE_FMT_S32P;
1596         avctx->bits_per_raw_sample = ((frame_flags & 0x03) + 1) << 3;
1597     }
1598
1599     while (buf_size > WV_HEADER_SIZE) {
1600         frame_size = AV_RL32(buf + 4) - 12;
1601         buf       += 20;
1602         buf_size  -= 20;
1603         if (frame_size <= 0 || frame_size > buf_size) {
1604             av_log(avctx, AV_LOG_ERROR,
1605                    "Block %d has invalid size (size %d vs. %d bytes left)\n",
1606                    s->block, frame_size, buf_size);
1607             ret = AVERROR_INVALIDDATA;
1608             goto error;
1609         }
1610         if ((ret = wavpack_decode_block(avctx, s->block, buf, frame_size)) < 0)
1611             goto error;
1612         s->block++;
1613         buf      += frame_size;
1614         buf_size -= frame_size;
1615     }
1616
1617     if (s->ch_offset != avctx->channels) {
1618         av_log(avctx, AV_LOG_ERROR, "Not enough channels coded in a packet.\n");
1619         ret = AVERROR_INVALIDDATA;
1620         goto error;
1621     }
1622
1623     ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1624     ff_thread_release_buffer(avctx, &s->prev_frame);
1625
1626     if (s->modulation == MODULATION_DSD)
1627         avctx->execute2(avctx, dsd_channel, s->frame, NULL, avctx->channels);
1628
1629     ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1630
1631     if ((ret = av_frame_ref(data, s->frame)) < 0)
1632         return ret;
1633
1634     *got_frame_ptr = 1;
1635
1636     return avpkt->size;
1637
1638 error:
1639     if (s->frame) {
1640         ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1641         ff_thread_release_buffer(avctx, &s->prev_frame);
1642         ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1643     }
1644
1645     return ret;
1646 }
1647
1648 AVCodec ff_wavpack_decoder = {
1649     .name           = "wavpack",
1650     .long_name      = NULL_IF_CONFIG_SMALL("WavPack"),
1651     .type           = AVMEDIA_TYPE_AUDIO,
1652     .id             = AV_CODEC_ID_WAVPACK,
1653     .priv_data_size = sizeof(WavpackContext),
1654     .init           = wavpack_decode_init,
1655     .close          = wavpack_decode_end,
1656     .decode         = wavpack_decode_frame,
1657     .flush          = wavpack_decode_flush,
1658     .init_thread_copy = ONLY_IF_THREADS_ENABLED(init_thread_copy),
1659     .update_thread_context = ONLY_IF_THREADS_ENABLED(update_thread_context),
1660     .capabilities   = AV_CODEC_CAP_DR1 | AV_CODEC_CAP_FRAME_THREADS |
1661                       AV_CODEC_CAP_SLICE_THREADS
1662 };