]> git.sesse.net Git - ffmpeg/blob - libavcodec/wavpack.c
wavpack: fully support stream parameter changes
[ffmpeg] / libavcodec / wavpack.c
1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  * Copyright (c) 2020 David Bryant
5  *
6  * This file is part of FFmpeg.
7  *
8  * FFmpeg is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * FFmpeg is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with FFmpeg; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21  */
22
23 #include "libavutil/buffer.h"
24 #include "libavutil/channel_layout.h"
25
26 #define BITSTREAM_READER_LE
27 #include "avcodec.h"
28 #include "bytestream.h"
29 #include "get_bits.h"
30 #include "internal.h"
31 #include "thread.h"
32 #include "unary.h"
33 #include "wavpack.h"
34 #include "dsd.h"
35
36 /**
37  * @file
38  * WavPack lossless audio decoder
39  */
40
41 #define DSD_BYTE_READY(low,high) (!(((low) ^ (high)) & 0xff000000))
42
43 #define PTABLE_BITS 8
44 #define PTABLE_BINS (1<<PTABLE_BITS)
45 #define PTABLE_MASK (PTABLE_BINS-1)
46
47 #define UP   0x010000fe
48 #define DOWN 0x00010000
49 #define DECAY 8
50
51 #define PRECISION 20
52 #define VALUE_ONE (1 << PRECISION)
53 #define PRECISION_USE 12
54
55 #define RATE_S 20
56
57 #define MAX_HISTORY_BITS    5
58 #define MAX_HISTORY_BINS    (1 << MAX_HISTORY_BITS)
59 #define MAX_BIN_BYTES       1280    // for value_lookup, per bin (2k - 512 - 256)
60
61 typedef enum {
62     MODULATION_PCM,     // pulse code modulation
63     MODULATION_DSD      // pulse density modulation (aka DSD)
64 } Modulation;
65
66 typedef struct WavpackFrameContext {
67     AVCodecContext *avctx;
68     int frame_flags;
69     int stereo, stereo_in;
70     int joint;
71     uint32_t CRC;
72     GetBitContext gb;
73     int got_extra_bits;
74     uint32_t crc_extra_bits;
75     GetBitContext gb_extra_bits;
76     int samples;
77     int terms;
78     Decorr decorr[MAX_TERMS];
79     int zero, one, zeroes;
80     int extra_bits;
81     int and, or, shift;
82     int post_shift;
83     int hybrid, hybrid_bitrate;
84     int hybrid_maxclip, hybrid_minclip;
85     int float_flag;
86     int float_shift;
87     int float_max_exp;
88     WvChannel ch[2];
89
90     GetByteContext gbyte;
91     int ptable [PTABLE_BINS];
92     uint8_t value_lookup_buffer[MAX_HISTORY_BINS*MAX_BIN_BYTES];
93     uint16_t summed_probabilities[MAX_HISTORY_BINS][256];
94     uint8_t probabilities[MAX_HISTORY_BINS][256];
95     uint8_t *value_lookup[MAX_HISTORY_BINS];
96 } WavpackFrameContext;
97
98 #define WV_MAX_FRAME_DECODERS 14
99
100 typedef struct WavpackContext {
101     AVCodecContext *avctx;
102
103     WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
104     int fdec_num;
105
106     int block;
107     int samples;
108     int ch_offset;
109
110     AVFrame *frame;
111     ThreadFrame curr_frame, prev_frame;
112     Modulation modulation;
113
114     AVBufferRef *dsd_ref;
115     DSDContext *dsdctx;
116     int dsd_channels;
117 } WavpackContext;
118
119 #define LEVEL_DECAY(a)  (((a) + 0x80) >> 8)
120
121 static av_always_inline unsigned get_tail(GetBitContext *gb, int k)
122 {
123     int p, e, res;
124
125     if (k < 1)
126         return 0;
127     p   = av_log2(k);
128     e   = (1 << (p + 1)) - k - 1;
129     res = get_bitsz(gb, p);
130     if (res >= e)
131         res = (res << 1) - e + get_bits1(gb);
132     return res;
133 }
134
135 static int update_error_limit(WavpackFrameContext *ctx)
136 {
137     int i, br[2], sl[2];
138
139     for (i = 0; i <= ctx->stereo_in; i++) {
140         if (ctx->ch[i].bitrate_acc > UINT_MAX - ctx->ch[i].bitrate_delta)
141             return AVERROR_INVALIDDATA;
142         ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
143         br[i]                   = ctx->ch[i].bitrate_acc >> 16;
144         sl[i]                   = LEVEL_DECAY(ctx->ch[i].slow_level);
145     }
146     if (ctx->stereo_in && ctx->hybrid_bitrate) {
147         int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
148         if (balance > br[0]) {
149             br[1] = br[0] * 2;
150             br[0] = 0;
151         } else if (-balance > br[0]) {
152             br[0]  *= 2;
153             br[1]   = 0;
154         } else {
155             br[1] = br[0] + balance;
156             br[0] = br[0] - balance;
157         }
158     }
159     for (i = 0; i <= ctx->stereo_in; i++) {
160         if (ctx->hybrid_bitrate) {
161             if (sl[i] - br[i] > -0x100)
162                 ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
163             else
164                 ctx->ch[i].error_limit = 0;
165         } else {
166             ctx->ch[i].error_limit = wp_exp2(br[i]);
167         }
168     }
169
170     return 0;
171 }
172
173 static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
174                         int channel, int *last)
175 {
176     int t, t2;
177     int sign, base, add, ret;
178     WvChannel *c = &ctx->ch[channel];
179
180     *last = 0;
181
182     if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
183         !ctx->zero && !ctx->one) {
184         if (ctx->zeroes) {
185             ctx->zeroes--;
186             if (ctx->zeroes) {
187                 c->slow_level -= LEVEL_DECAY(c->slow_level);
188                 return 0;
189             }
190         } else {
191             t = get_unary_0_33(gb);
192             if (t >= 2) {
193                 if (t >= 32 || get_bits_left(gb) < t - 1)
194                     goto error;
195                 t = get_bits_long(gb, t - 1) | (1 << (t - 1));
196             } else {
197                 if (get_bits_left(gb) < 0)
198                     goto error;
199             }
200             ctx->zeroes = t;
201             if (ctx->zeroes) {
202                 memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
203                 memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
204                 c->slow_level -= LEVEL_DECAY(c->slow_level);
205                 return 0;
206             }
207         }
208     }
209
210     if (ctx->zero) {
211         t         = 0;
212         ctx->zero = 0;
213     } else {
214         t = get_unary_0_33(gb);
215         if (get_bits_left(gb) < 0)
216             goto error;
217         if (t == 16) {
218             t2 = get_unary_0_33(gb);
219             if (t2 < 2) {
220                 if (get_bits_left(gb) < 0)
221                     goto error;
222                 t += t2;
223             } else {
224                 if (t2 >= 32 || get_bits_left(gb) < t2 - 1)
225                     goto error;
226                 t += get_bits_long(gb, t2 - 1) | (1 << (t2 - 1));
227             }
228         }
229
230         if (ctx->one) {
231             ctx->one = t & 1;
232             t        = (t >> 1) + 1;
233         } else {
234             ctx->one = t & 1;
235             t      >>= 1;
236         }
237         ctx->zero = !ctx->one;
238     }
239
240     if (ctx->hybrid && !channel) {
241         if (update_error_limit(ctx) < 0)
242             goto error;
243     }
244
245     if (!t) {
246         base = 0;
247         add  = GET_MED(0) - 1;
248         DEC_MED(0);
249     } else if (t == 1) {
250         base = GET_MED(0);
251         add  = GET_MED(1) - 1;
252         INC_MED(0);
253         DEC_MED(1);
254     } else if (t == 2) {
255         base = GET_MED(0) + GET_MED(1);
256         add  = GET_MED(2) - 1;
257         INC_MED(0);
258         INC_MED(1);
259         DEC_MED(2);
260     } else {
261         base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2U);
262         add  = GET_MED(2) - 1;
263         INC_MED(0);
264         INC_MED(1);
265         INC_MED(2);
266     }
267     if (!c->error_limit) {
268         if (add >= 0x2000000U) {
269             av_log(ctx->avctx, AV_LOG_ERROR, "k %d is too large\n", add);
270             goto error;
271         }
272         ret = base + get_tail(gb, add);
273         if (get_bits_left(gb) <= 0)
274             goto error;
275     } else {
276         int mid = (base * 2U + add + 1) >> 1;
277         while (add > c->error_limit) {
278             if (get_bits_left(gb) <= 0)
279                 goto error;
280             if (get_bits1(gb)) {
281                 add -= (mid - (unsigned)base);
282                 base = mid;
283             } else
284                 add = mid - (unsigned)base - 1;
285             mid = (base * 2U + add + 1) >> 1;
286         }
287         ret = mid;
288     }
289     sign = get_bits1(gb);
290     if (ctx->hybrid_bitrate)
291         c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
292     return sign ? ~ret : ret;
293
294 error:
295     ret = get_bits_left(gb);
296     if (ret <= 0) {
297         av_log(ctx->avctx, AV_LOG_ERROR, "Too few bits (%d) left\n", ret);
298     }
299     *last = 1;
300     return 0;
301 }
302
303 static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
304                                        unsigned S)
305 {
306     unsigned bit;
307
308     if (s->extra_bits) {
309         S *= 1 << s->extra_bits;
310
311         if (s->got_extra_bits &&
312             get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
313             S   |= get_bits_long(&s->gb_extra_bits, s->extra_bits);
314             *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
315         }
316     }
317
318     bit = (S & s->and) | s->or;
319     bit = ((S + bit) << s->shift) - bit;
320
321     if (s->hybrid)
322         bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
323
324     return bit << s->post_shift;
325 }
326
327 static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
328 {
329     union {
330         float    f;
331         uint32_t u;
332     } value;
333
334     unsigned int sign;
335     int exp = s->float_max_exp;
336
337     if (s->got_extra_bits) {
338         const int max_bits  = 1 + 23 + 8 + 1;
339         const int left_bits = get_bits_left(&s->gb_extra_bits);
340
341         if (left_bits + 8 * AV_INPUT_BUFFER_PADDING_SIZE < max_bits)
342             return 0.0;
343     }
344
345     if (S) {
346         S  *= 1U << s->float_shift;
347         sign = S < 0;
348         if (sign)
349             S = -(unsigned)S;
350         if (S >= 0x1000000U) {
351             if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
352                 S = get_bits(&s->gb_extra_bits, 23);
353             else
354                 S = 0;
355             exp = 255;
356         } else if (exp) {
357             int shift = 23 - av_log2(S);
358             exp = s->float_max_exp;
359             if (exp <= shift)
360                 shift = --exp;
361             exp -= shift;
362
363             if (shift) {
364                 S <<= shift;
365                 if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
366                     (s->got_extra_bits &&
367                      (s->float_flag & WV_FLT_SHIFT_SAME) &&
368                      get_bits1(&s->gb_extra_bits))) {
369                     S |= (1 << shift) - 1;
370                 } else if (s->got_extra_bits &&
371                            (s->float_flag & WV_FLT_SHIFT_SENT)) {
372                     S |= get_bits(&s->gb_extra_bits, shift);
373                 }
374             }
375         } else {
376             exp = s->float_max_exp;
377         }
378         S &= 0x7fffff;
379     } else {
380         sign = 0;
381         exp  = 0;
382         if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
383             if (get_bits1(&s->gb_extra_bits)) {
384                 S = get_bits(&s->gb_extra_bits, 23);
385                 if (s->float_max_exp >= 25)
386                     exp = get_bits(&s->gb_extra_bits, 8);
387                 sign = get_bits1(&s->gb_extra_bits);
388             } else {
389                 if (s->float_flag & WV_FLT_ZERO_SIGN)
390                     sign = get_bits1(&s->gb_extra_bits);
391             }
392         }
393     }
394
395     *crc = *crc * 27 + S * 9 + exp * 3 + sign;
396
397     value.u = (sign << 31) | (exp << 23) | S;
398     return value.f;
399 }
400
401 static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
402                                uint32_t crc_extra_bits)
403 {
404     if (crc != s->CRC) {
405         av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
406         return AVERROR_INVALIDDATA;
407     }
408     if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
409         av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
410         return AVERROR_INVALIDDATA;
411     }
412
413     return 0;
414 }
415
416 static void init_ptable(int *table, int rate_i, int rate_s)
417 {
418     int value = 0x808000, rate = rate_i << 8;
419
420     for (int c = (rate + 128) >> 8; c--;)
421         value += (DOWN - value) >> DECAY;
422
423     for (int i = 0; i < PTABLE_BINS/2; i++) {
424         table[i] = value;
425         table[PTABLE_BINS-1-i] = 0x100ffff - value;
426
427         if (value > 0x010000) {
428             rate += (rate * rate_s + 128) >> 8;
429
430             for (int c = (rate + 64) >> 7; c--;)
431                 value += (DOWN - value) >> DECAY;
432         }
433     }
434 }
435
436 typedef struct {
437     int32_t value, fltr0, fltr1, fltr2, fltr3, fltr4, fltr5, fltr6, factor;
438     unsigned int byte;
439 } DSDfilters;
440
441 static int wv_unpack_dsd_high(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
442 {
443     uint32_t checksum = 0xFFFFFFFF;
444     uint8_t *dst_l = dst_left, *dst_r = dst_right;
445     int total_samples = s->samples, stereo = dst_r ? 1 : 0;
446     DSDfilters filters[2], *sp = filters;
447     int rate_i, rate_s;
448     uint32_t low, high, value;
449
450     if (bytestream2_get_bytes_left(&s->gbyte) < (stereo ? 20 : 13))
451         return AVERROR_INVALIDDATA;
452
453     rate_i = bytestream2_get_byte(&s->gbyte);
454     rate_s = bytestream2_get_byte(&s->gbyte);
455
456     if (rate_s != RATE_S)
457         return AVERROR_INVALIDDATA;
458
459     init_ptable(s->ptable, rate_i, rate_s);
460
461     for (int channel = 0; channel < stereo + 1; channel++) {
462         DSDfilters *sp = filters + channel;
463
464         sp->fltr1 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
465         sp->fltr2 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
466         sp->fltr3 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
467         sp->fltr4 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
468         sp->fltr5 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
469         sp->fltr6 = 0;
470         sp->factor = bytestream2_get_byte(&s->gbyte) & 0xff;
471         sp->factor |= (bytestream2_get_byte(&s->gbyte) << 8) & 0xff00;
472         sp->factor = (int32_t)((uint32_t)sp->factor << 16) >> 16;
473     }
474
475     value = bytestream2_get_be32(&s->gbyte);
476     high = 0xffffffff;
477     low = 0x0;
478
479     while (total_samples--) {
480         int bitcount = 8;
481
482         sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
483
484         if (stereo)
485             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
486
487         while (bitcount--) {
488             int32_t *pp = s->ptable + ((sp[0].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
489             uint32_t split = low + ((high - low) >> 8) * (*pp >> 16);
490
491             if (value <= split) {
492                 high = split;
493                 *pp += (UP - *pp) >> DECAY;
494                 sp[0].fltr0 = -1;
495             } else {
496                 low = split + 1;
497                 *pp += (DOWN - *pp) >> DECAY;
498                 sp[0].fltr0 = 0;
499             }
500
501             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
502                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
503                 high = (high << 8) | 0xff;
504                 low <<= 8;
505             }
506
507             sp[0].value += sp[0].fltr6 * 8;
508             sp[0].byte = (sp[0].byte << 1) | (sp[0].fltr0 & 1);
509             sp[0].factor += (((sp[0].value ^ sp[0].fltr0) >> 31) | 1) &
510                 ((sp[0].value ^ (sp[0].value - (sp[0].fltr6 * 16))) >> 31);
511             sp[0].fltr1 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr1) >> 6;
512             sp[0].fltr2 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr2) >> 4;
513             sp[0].fltr3 += (sp[0].fltr2 - sp[0].fltr3) >> 4;
514             sp[0].fltr4 += (sp[0].fltr3 - sp[0].fltr4) >> 4;
515             sp[0].value = (sp[0].fltr4 - sp[0].fltr5) >> 4;
516             sp[0].fltr5 += sp[0].value;
517             sp[0].fltr6 += (sp[0].value - sp[0].fltr6) >> 3;
518             sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
519
520             if (!stereo)
521                 continue;
522
523             pp = s->ptable + ((sp[1].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
524             split = low + ((high - low) >> 8) * (*pp >> 16);
525
526             if (value <= split) {
527                 high = split;
528                 *pp += (UP - *pp) >> DECAY;
529                 sp[1].fltr0 = -1;
530             } else {
531                 low = split + 1;
532                 *pp += (DOWN - *pp) >> DECAY;
533                 sp[1].fltr0 = 0;
534             }
535
536             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
537                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
538                 high = (high << 8) | 0xff;
539                 low <<= 8;
540             }
541
542             sp[1].value += sp[1].fltr6 * 8;
543             sp[1].byte = (sp[1].byte << 1) | (sp[1].fltr0 & 1);
544             sp[1].factor += (((sp[1].value ^ sp[1].fltr0) >> 31) | 1) &
545                 ((sp[1].value ^ (sp[1].value - (sp[1].fltr6 * 16))) >> 31);
546             sp[1].fltr1 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr1) >> 6;
547             sp[1].fltr2 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr2) >> 4;
548             sp[1].fltr3 += (sp[1].fltr2 - sp[1].fltr3) >> 4;
549             sp[1].fltr4 += (sp[1].fltr3 - sp[1].fltr4) >> 4;
550             sp[1].value = (sp[1].fltr4 - sp[1].fltr5) >> 4;
551             sp[1].fltr5 += sp[1].value;
552             sp[1].fltr6 += (sp[1].value - sp[1].fltr6) >> 3;
553             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
554         }
555
556         checksum += (checksum << 1) + (*dst_l = sp[0].byte & 0xff);
557         sp[0].factor -= (sp[0].factor + 512) >> 10;
558         dst_l += 4;
559
560         if (stereo) {
561             checksum += (checksum << 1) + (*dst_r = filters[1].byte & 0xff);
562             filters[1].factor -= (filters[1].factor + 512) >> 10;
563             dst_r += 4;
564         }
565     }
566
567     if (wv_check_crc(s, checksum, 0)) {
568         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
569             return AVERROR_INVALIDDATA;
570
571         memset(dst_left, 0x69, s->samples * 4);
572
573         if (dst_r)
574             memset(dst_right, 0x69, s->samples * 4);
575     }
576
577     return 0;
578 }
579
580 static int wv_unpack_dsd_fast(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
581 {
582     uint8_t *dst_l = dst_left, *dst_r = dst_right;
583     uint8_t history_bits, max_probability;
584     int total_summed_probabilities  = 0;
585     int total_samples               = s->samples;
586     uint8_t *vlb                    = s->value_lookup_buffer;
587     int history_bins, p0, p1, chan;
588     uint32_t checksum               = 0xFFFFFFFF;
589     uint32_t low, high, value;
590
591     if (!bytestream2_get_bytes_left(&s->gbyte))
592         return AVERROR_INVALIDDATA;
593
594     history_bits = bytestream2_get_byte(&s->gbyte);
595
596     if (!bytestream2_get_bytes_left(&s->gbyte) || history_bits > MAX_HISTORY_BITS)
597         return AVERROR_INVALIDDATA;
598
599     history_bins = 1 << history_bits;
600     max_probability = bytestream2_get_byte(&s->gbyte);
601
602     if (max_probability < 0xff) {
603         uint8_t *outptr = (uint8_t *)s->probabilities;
604         uint8_t *outend = outptr + sizeof(*s->probabilities) * history_bins;
605
606         while (outptr < outend && bytestream2_get_bytes_left(&s->gbyte)) {
607             int code = bytestream2_get_byte(&s->gbyte);
608
609             if (code > max_probability) {
610                 int zcount = code - max_probability;
611
612                 while (outptr < outend && zcount--)
613                     *outptr++ = 0;
614             } else if (code) {
615                 *outptr++ = code;
616             }
617             else {
618                 break;
619             }
620         }
621
622         if (outptr < outend ||
623             (bytestream2_get_bytes_left(&s->gbyte) && bytestream2_get_byte(&s->gbyte)))
624                 return AVERROR_INVALIDDATA;
625     } else if (bytestream2_get_bytes_left(&s->gbyte) > (int)sizeof(*s->probabilities) * history_bins) {
626         bytestream2_get_buffer(&s->gbyte, (uint8_t *)s->probabilities,
627             sizeof(*s->probabilities) * history_bins);
628     } else {
629         return AVERROR_INVALIDDATA;
630     }
631
632     for (p0 = 0; p0 < history_bins; p0++) {
633         int32_t sum_values = 0;
634
635         for (int i = 0; i < 256; i++)
636             s->summed_probabilities[p0][i] = sum_values += s->probabilities[p0][i];
637
638         if (sum_values) {
639             total_summed_probabilities += sum_values;
640
641             if (total_summed_probabilities > history_bins * MAX_BIN_BYTES)
642                 return AVERROR_INVALIDDATA;
643
644             s->value_lookup[p0] = vlb;
645
646             for (int i = 0; i < 256; i++) {
647                 int c = s->probabilities[p0][i];
648
649                 while (c--)
650                     *vlb++ = i;
651             }
652         }
653     }
654
655     if (bytestream2_get_bytes_left(&s->gbyte) < 4)
656         return AVERROR_INVALIDDATA;
657
658     chan = p0 = p1 = 0;
659     low = 0; high = 0xffffffff;
660     value = bytestream2_get_be32(&s->gbyte);
661
662     if (dst_r)
663         total_samples *= 2;
664
665     while (total_samples--) {
666         unsigned int mult, index, code;
667
668         if (!s->summed_probabilities[p0][255])
669             return AVERROR_INVALIDDATA;
670
671         mult = (high - low) / s->summed_probabilities[p0][255];
672
673         if (!mult) {
674             if (bytestream2_get_bytes_left(&s->gbyte) >= 4)
675                 value = bytestream2_get_be32(&s->gbyte);
676
677             low = 0;
678             high = 0xffffffff;
679             mult = high / s->summed_probabilities[p0][255];
680
681             if (!mult)
682                 return AVERROR_INVALIDDATA;
683         }
684
685         index = (value - low) / mult;
686
687         if (index >= s->summed_probabilities[p0][255])
688             return AVERROR_INVALIDDATA;
689
690         if (!dst_r) {
691             if ((*dst_l = code = s->value_lookup[p0][index]))
692                 low += s->summed_probabilities[p0][code-1] * mult;
693
694             dst_l += 4;
695         } else {
696             if ((code = s->value_lookup[p0][index]))
697                 low += s->summed_probabilities[p0][code-1] * mult;
698
699             if (chan) {
700                 *dst_r = code;
701                 dst_r += 4;
702             }
703             else {
704                 *dst_l = code;
705                 dst_l += 4;
706             }
707
708             chan ^= 1;
709         }
710
711         high = low + s->probabilities[p0][code] * mult - 1;
712         checksum += (checksum << 1) + code;
713
714         if (!dst_r) {
715             p0 = code & (history_bins-1);
716         } else {
717             p0 = p1;
718             p1 = code & (history_bins-1);
719         }
720
721         while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
722             value = (value << 8) | bytestream2_get_byte(&s->gbyte);
723             high = (high << 8) | 0xff;
724             low <<= 8;
725         }
726     }
727
728     if (wv_check_crc(s, checksum, 0)) {
729         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
730             return AVERROR_INVALIDDATA;
731
732         memset(dst_left, 0x69, s->samples * 4);
733
734         if (dst_r)
735             memset(dst_right, 0x69, s->samples * 4);
736     }
737
738     return 0;
739 }
740
741 static int wv_unpack_dsd_copy(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
742 {
743     uint8_t *dst_l = dst_left, *dst_r = dst_right;
744     int total_samples           = s->samples;
745     uint32_t checksum           = 0xFFFFFFFF;
746
747     if (bytestream2_get_bytes_left(&s->gbyte) != total_samples * (dst_r ? 2 : 1))
748         return AVERROR_INVALIDDATA;
749
750     while (total_samples--) {
751         checksum += (checksum << 1) + (*dst_l = bytestream2_get_byte(&s->gbyte));
752         dst_l += 4;
753
754         if (dst_r) {
755             checksum += (checksum << 1) + (*dst_r = bytestream2_get_byte(&s->gbyte));
756             dst_r += 4;
757         }
758     }
759
760     if (wv_check_crc(s, checksum, 0)) {
761         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
762             return AVERROR_INVALIDDATA;
763
764         memset(dst_left, 0x69, s->samples * 4);
765
766         if (dst_r)
767             memset(dst_right, 0x69, s->samples * 4);
768     }
769
770     return 0;
771 }
772
773 static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
774                                    void *dst_l, void *dst_r, const int type)
775 {
776     int i, j, count = 0;
777     int last, t;
778     int A, B, L, L2, R, R2;
779     int pos                 = 0;
780     uint32_t crc            = 0xFFFFFFFF;
781     uint32_t crc_extra_bits = 0xFFFFFFFF;
782     int16_t *dst16_l        = dst_l;
783     int16_t *dst16_r        = dst_r;
784     int32_t *dst32_l        = dst_l;
785     int32_t *dst32_r        = dst_r;
786     float *dstfl_l          = dst_l;
787     float *dstfl_r          = dst_r;
788
789     s->one = s->zero = s->zeroes = 0;
790     do {
791         L = wv_get_value(s, gb, 0, &last);
792         if (last)
793             break;
794         R = wv_get_value(s, gb, 1, &last);
795         if (last)
796             break;
797         for (i = 0; i < s->terms; i++) {
798             t = s->decorr[i].value;
799             if (t > 0) {
800                 if (t > 8) {
801                     if (t & 1) {
802                         A = 2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
803                         B = 2U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
804                     } else {
805                         A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
806                         B = (int)(3U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
807                     }
808                     s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
809                     s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
810                     j                        = 0;
811                 } else {
812                     A = s->decorr[i].samplesA[pos];
813                     B = s->decorr[i].samplesB[pos];
814                     j = (pos + t) & 7;
815                 }
816                 if (type != AV_SAMPLE_FMT_S16P) {
817                     L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
818                     R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
819                 } else {
820                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
821                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)B + 512) >> 10);
822                 }
823                 if (A && L)
824                     s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
825                 if (B && R)
826                     s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
827                 s->decorr[i].samplesA[j] = L = L2;
828                 s->decorr[i].samplesB[j] = R = R2;
829             } else if (t == -1) {
830                 if (type != AV_SAMPLE_FMT_S16P)
831                     L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
832                 else
833                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)s->decorr[i].samplesA[0] + 512) >> 10);
834                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
835                 L = L2;
836                 if (type != AV_SAMPLE_FMT_S16P)
837                     R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
838                 else
839                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)L2 + 512) >> 10);
840                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
841                 R                        = R2;
842                 s->decorr[i].samplesA[0] = R;
843             } else {
844                 if (type != AV_SAMPLE_FMT_S16P)
845                     R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
846                 else
847                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)s->decorr[i].samplesB[0] + 512) >> 10);
848                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
849                 R = R2;
850
851                 if (t == -3) {
852                     R2                       = s->decorr[i].samplesA[0];
853                     s->decorr[i].samplesA[0] = R;
854                 }
855
856                 if (type != AV_SAMPLE_FMT_S16P)
857                     L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
858                 else
859                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)R2 + 512) >> 10);
860                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
861                 L                        = L2;
862                 s->decorr[i].samplesB[0] = L;
863             }
864         }
865
866         if (type == AV_SAMPLE_FMT_S16P) {
867             if (FFABS((int64_t)L) + FFABS((int64_t)R) > (1<<19)) {
868                 av_log(s->avctx, AV_LOG_ERROR, "sample %d %d too large\n", L, R);
869                 return AVERROR_INVALIDDATA;
870             }
871         }
872
873         pos = (pos + 1) & 7;
874         if (s->joint)
875             L += (unsigned)(R -= (unsigned)(L >> 1));
876         crc = (crc * 3 + L) * 3 + R;
877
878         if (type == AV_SAMPLE_FMT_FLTP) {
879             *dstfl_l++ = wv_get_value_float(s, &crc_extra_bits, L);
880             *dstfl_r++ = wv_get_value_float(s, &crc_extra_bits, R);
881         } else if (type == AV_SAMPLE_FMT_S32P) {
882             *dst32_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
883             *dst32_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
884         } else {
885             *dst16_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
886             *dst16_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
887         }
888         count++;
889     } while (!last && count < s->samples);
890
891     if (last && count < s->samples) {
892         int size = av_get_bytes_per_sample(type);
893         memset((uint8_t*)dst_l + count*size, 0, (s->samples-count)*size);
894         memset((uint8_t*)dst_r + count*size, 0, (s->samples-count)*size);
895     }
896
897     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
898         wv_check_crc(s, crc, crc_extra_bits))
899         return AVERROR_INVALIDDATA;
900
901     return 0;
902 }
903
904 static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
905                                  void *dst, const int type)
906 {
907     int i, j, count = 0;
908     int last, t;
909     int A, S, T;
910     int pos                  = 0;
911     uint32_t crc             = 0xFFFFFFFF;
912     uint32_t crc_extra_bits  = 0xFFFFFFFF;
913     int16_t *dst16           = dst;
914     int32_t *dst32           = dst;
915     float *dstfl             = dst;
916
917     s->one = s->zero = s->zeroes = 0;
918     do {
919         T = wv_get_value(s, gb, 0, &last);
920         S = 0;
921         if (last)
922             break;
923         for (i = 0; i < s->terms; i++) {
924             t = s->decorr[i].value;
925             if (t > 8) {
926                 if (t & 1)
927                     A =  2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
928                 else
929                     A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
930                 s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
931                 j                        = 0;
932             } else {
933                 A = s->decorr[i].samplesA[pos];
934                 j = (pos + t) & 7;
935             }
936             if (type != AV_SAMPLE_FMT_S16P)
937                 S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
938             else
939                 S = T + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
940             if (A && T)
941                 s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
942             s->decorr[i].samplesA[j] = T = S;
943         }
944         pos = (pos + 1) & 7;
945         crc = crc * 3 + S;
946
947         if (type == AV_SAMPLE_FMT_FLTP) {
948             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, S);
949         } else if (type == AV_SAMPLE_FMT_S32P) {
950             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, S);
951         } else {
952             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, S);
953         }
954         count++;
955     } while (!last && count < s->samples);
956
957     if (last && count < s->samples) {
958         int size = av_get_bytes_per_sample(type);
959         memset((uint8_t*)dst + count*size, 0, (s->samples-count)*size);
960     }
961
962     if (s->avctx->err_recognition & AV_EF_CRCCHECK) {
963         int ret = wv_check_crc(s, crc, crc_extra_bits);
964         if (ret < 0 && s->avctx->err_recognition & AV_EF_EXPLODE)
965             return ret;
966     }
967
968     return 0;
969 }
970
971 static av_cold int wv_alloc_frame_context(WavpackContext *c)
972 {
973     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
974         return -1;
975
976     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
977     if (!c->fdec[c->fdec_num])
978         return -1;
979     c->fdec_num++;
980     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
981
982     return 0;
983 }
984
985 static int wv_dsd_reset(WavpackContext *s, int channels)
986 {
987     int i;
988
989     s->dsdctx = NULL;
990     s->dsd_channels = 0;
991     av_buffer_unref(&s->dsd_ref);
992
993     if (!channels)
994         return 0;
995
996     if (channels > INT_MAX / sizeof(*s->dsdctx))
997         return AVERROR(EINVAL);
998
999     s->dsd_ref = av_buffer_allocz(channels * sizeof(*s->dsdctx));
1000     if (!s->dsd_ref)
1001         return AVERROR(ENOMEM);
1002     s->dsdctx = (DSDContext*)s->dsd_ref->data;
1003     s->dsd_channels = channels;
1004
1005     for (i = 0; i < channels; i++)
1006         memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
1007
1008     return 0;
1009 }
1010
1011 #if HAVE_THREADS
1012 static int init_thread_copy(AVCodecContext *avctx)
1013 {
1014     WavpackContext *s = avctx->priv_data;
1015     s->avctx = avctx;
1016
1017     s->curr_frame.f = av_frame_alloc();
1018     s->prev_frame.f = av_frame_alloc();
1019
1020     if (!s->curr_frame.f || !s->prev_frame.f)
1021         return AVERROR(ENOMEM);
1022
1023     return 0;
1024 }
1025
1026 static int update_thread_context(AVCodecContext *dst, const AVCodecContext *src)
1027 {
1028     WavpackContext *fsrc = src->priv_data;
1029     WavpackContext *fdst = dst->priv_data;
1030     int ret;
1031
1032     if (dst == src)
1033         return 0;
1034
1035     ff_thread_release_buffer(dst, &fdst->curr_frame);
1036     if (fsrc->curr_frame.f->data[0]) {
1037         if ((ret = ff_thread_ref_frame(&fdst->curr_frame, &fsrc->curr_frame)) < 0)
1038             return ret;
1039     }
1040
1041     av_buffer_unref(&fdst->dsd_ref);
1042     fdst->dsdctx = NULL;
1043     fdst->dsd_channels = 0;
1044     if (fsrc->dsd_ref) {
1045         fdst->dsd_ref = av_buffer_ref(fsrc->dsd_ref);
1046         if (!fdst->dsd_ref)
1047             return AVERROR(ENOMEM);
1048         fdst->dsdctx = (DSDContext*)fdst->dsd_ref->data;
1049         fdst->dsd_channels = fsrc->dsd_channels;
1050     }
1051
1052     return 0;
1053 }
1054 #endif
1055
1056 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
1057 {
1058     WavpackContext *s = avctx->priv_data;
1059
1060     s->avctx = avctx;
1061
1062     s->fdec_num = 0;
1063
1064     avctx->internal->allocate_progress = 1;
1065
1066     s->curr_frame.f = av_frame_alloc();
1067     s->prev_frame.f = av_frame_alloc();
1068
1069     if (!s->curr_frame.f || !s->prev_frame.f)
1070         return AVERROR(ENOMEM);
1071
1072     ff_init_dsd_data();
1073
1074     return 0;
1075 }
1076
1077 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
1078 {
1079     WavpackContext *s = avctx->priv_data;
1080
1081     for (int i = 0; i < s->fdec_num; i++)
1082         av_freep(&s->fdec[i]);
1083     s->fdec_num = 0;
1084
1085     ff_thread_release_buffer(avctx, &s->curr_frame);
1086     av_frame_free(&s->curr_frame.f);
1087
1088     ff_thread_release_buffer(avctx, &s->prev_frame);
1089     av_frame_free(&s->prev_frame.f);
1090
1091     av_buffer_unref(&s->dsd_ref);
1092
1093     return 0;
1094 }
1095
1096 static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
1097                                 const uint8_t *buf, int buf_size)
1098 {
1099     WavpackContext *wc = avctx->priv_data;
1100     WavpackFrameContext *s;
1101     GetByteContext gb;
1102     enum AVSampleFormat sample_fmt;
1103     void *samples_l = NULL, *samples_r = NULL;
1104     int ret;
1105     int got_terms   = 0, got_weights = 0, got_samples = 0,
1106         got_entropy = 0, got_pcm     = 0, got_float   = 0, got_hybrid = 0;
1107     int got_dsd = 0;
1108     int i, j, id, size, ssize, weights, t;
1109     int bpp, chan = 0, orig_bpp, sample_rate = 0, rate_x = 1, dsd_mode = 0;
1110     int multiblock;
1111     uint64_t chmask = 0;
1112
1113     if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
1114         av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
1115         return AVERROR_INVALIDDATA;
1116     }
1117
1118     s = wc->fdec[block_no];
1119     if (!s) {
1120         av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
1121                block_no);
1122         return AVERROR_INVALIDDATA;
1123     }
1124
1125     memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
1126     memset(s->ch, 0, sizeof(s->ch));
1127     s->extra_bits     = 0;
1128     s->and            = s->or = s->shift = 0;
1129     s->got_extra_bits = 0;
1130
1131     bytestream2_init(&gb, buf, buf_size);
1132
1133     s->samples = bytestream2_get_le32(&gb);
1134     if (s->samples != wc->samples) {
1135         av_log(avctx, AV_LOG_ERROR, "Mismatching number of samples in "
1136                "a sequence: %d and %d\n", wc->samples, s->samples);
1137         return AVERROR_INVALIDDATA;
1138     }
1139     s->frame_flags = bytestream2_get_le32(&gb);
1140
1141     if (s->frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA))
1142         sample_fmt = AV_SAMPLE_FMT_FLTP;
1143     else if ((s->frame_flags & 0x03) <= 1)
1144         sample_fmt = AV_SAMPLE_FMT_S16P;
1145     else
1146         sample_fmt          = AV_SAMPLE_FMT_S32P;
1147
1148     bpp            = av_get_bytes_per_sample(sample_fmt);
1149     orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
1150     multiblock     = (s->frame_flags & WV_SINGLE_BLOCK) != WV_SINGLE_BLOCK;
1151
1152     s->stereo         = !(s->frame_flags & WV_MONO);
1153     s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
1154     s->joint          =   s->frame_flags & WV_JOINT_STEREO;
1155     s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
1156     s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
1157     s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
1158     if (s->post_shift < 0 || s->post_shift > 31) {
1159         return AVERROR_INVALIDDATA;
1160     }
1161     s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
1162     s->hybrid_minclip = ((-1UL << (orig_bpp - 1)));
1163     s->CRC            = bytestream2_get_le32(&gb);
1164
1165     // parse metadata blocks
1166     while (bytestream2_get_bytes_left(&gb)) {
1167         id   = bytestream2_get_byte(&gb);
1168         size = bytestream2_get_byte(&gb);
1169         if (id & WP_IDF_LONG)
1170             size |= (bytestream2_get_le16u(&gb)) << 8;
1171         size <<= 1; // size is specified in words
1172         ssize  = size;
1173         if (id & WP_IDF_ODD)
1174             size--;
1175         if (size < 0) {
1176             av_log(avctx, AV_LOG_ERROR,
1177                    "Got incorrect block %02X with size %i\n", id, size);
1178             break;
1179         }
1180         if (bytestream2_get_bytes_left(&gb) < ssize) {
1181             av_log(avctx, AV_LOG_ERROR,
1182                    "Block size %i is out of bounds\n", size);
1183             break;
1184         }
1185         switch (id & WP_IDF_MASK) {
1186         case WP_ID_DECTERMS:
1187             if (size > MAX_TERMS) {
1188                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
1189                 s->terms = 0;
1190                 bytestream2_skip(&gb, ssize);
1191                 continue;
1192             }
1193             s->terms = size;
1194             for (i = 0; i < s->terms; i++) {
1195                 uint8_t val = bytestream2_get_byte(&gb);
1196                 s->decorr[s->terms - i - 1].value = (val & 0x1F) - 5;
1197                 s->decorr[s->terms - i - 1].delta =  val >> 5;
1198             }
1199             got_terms = 1;
1200             break;
1201         case WP_ID_DECWEIGHTS:
1202             if (!got_terms) {
1203                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
1204                 continue;
1205             }
1206             weights = size >> s->stereo_in;
1207             if (weights > MAX_TERMS || weights > s->terms) {
1208                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
1209                 bytestream2_skip(&gb, ssize);
1210                 continue;
1211             }
1212             for (i = 0; i < weights; i++) {
1213                 t = (int8_t)bytestream2_get_byte(&gb);
1214                 s->decorr[s->terms - i - 1].weightA = t * (1 << 3);
1215                 if (s->decorr[s->terms - i - 1].weightA > 0)
1216                     s->decorr[s->terms - i - 1].weightA +=
1217                         (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
1218                 if (s->stereo_in) {
1219                     t = (int8_t)bytestream2_get_byte(&gb);
1220                     s->decorr[s->terms - i - 1].weightB = t * (1 << 3);
1221                     if (s->decorr[s->terms - i - 1].weightB > 0)
1222                         s->decorr[s->terms - i - 1].weightB +=
1223                             (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
1224                 }
1225             }
1226             got_weights = 1;
1227             break;
1228         case WP_ID_DECSAMPLES:
1229             if (!got_terms) {
1230                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
1231                 continue;
1232             }
1233             t = 0;
1234             for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
1235                 if (s->decorr[i].value > 8) {
1236                     s->decorr[i].samplesA[0] =
1237                         wp_exp2(bytestream2_get_le16(&gb));
1238                     s->decorr[i].samplesA[1] =
1239                         wp_exp2(bytestream2_get_le16(&gb));
1240
1241                     if (s->stereo_in) {
1242                         s->decorr[i].samplesB[0] =
1243                             wp_exp2(bytestream2_get_le16(&gb));
1244                         s->decorr[i].samplesB[1] =
1245                             wp_exp2(bytestream2_get_le16(&gb));
1246                         t                       += 4;
1247                     }
1248                     t += 4;
1249                 } else if (s->decorr[i].value < 0) {
1250                     s->decorr[i].samplesA[0] =
1251                         wp_exp2(bytestream2_get_le16(&gb));
1252                     s->decorr[i].samplesB[0] =
1253                         wp_exp2(bytestream2_get_le16(&gb));
1254                     t                       += 4;
1255                 } else {
1256                     for (j = 0; j < s->decorr[i].value; j++) {
1257                         s->decorr[i].samplesA[j] =
1258                             wp_exp2(bytestream2_get_le16(&gb));
1259                         if (s->stereo_in) {
1260                             s->decorr[i].samplesB[j] =
1261                                 wp_exp2(bytestream2_get_le16(&gb));
1262                         }
1263                     }
1264                     t += s->decorr[i].value * 2 * (s->stereo_in + 1);
1265                 }
1266             }
1267             got_samples = 1;
1268             break;
1269         case WP_ID_ENTROPY:
1270             if (size != 6 * (s->stereo_in + 1)) {
1271                 av_log(avctx, AV_LOG_ERROR,
1272                        "Entropy vars size should be %i, got %i.\n",
1273                        6 * (s->stereo_in + 1), size);
1274                 bytestream2_skip(&gb, ssize);
1275                 continue;
1276             }
1277             for (j = 0; j <= s->stereo_in; j++)
1278                 for (i = 0; i < 3; i++) {
1279                     s->ch[j].median[i] = wp_exp2(bytestream2_get_le16(&gb));
1280                 }
1281             got_entropy = 1;
1282             break;
1283         case WP_ID_HYBRID:
1284             if (s->hybrid_bitrate) {
1285                 for (i = 0; i <= s->stereo_in; i++) {
1286                     s->ch[i].slow_level = wp_exp2(bytestream2_get_le16(&gb));
1287                     size               -= 2;
1288                 }
1289             }
1290             for (i = 0; i < (s->stereo_in + 1); i++) {
1291                 s->ch[i].bitrate_acc = bytestream2_get_le16(&gb) << 16;
1292                 size                -= 2;
1293             }
1294             if (size > 0) {
1295                 for (i = 0; i < (s->stereo_in + 1); i++) {
1296                     s->ch[i].bitrate_delta =
1297                         wp_exp2((int16_t)bytestream2_get_le16(&gb));
1298                 }
1299             } else {
1300                 for (i = 0; i < (s->stereo_in + 1); i++)
1301                     s->ch[i].bitrate_delta = 0;
1302             }
1303             got_hybrid = 1;
1304             break;
1305         case WP_ID_INT32INFO: {
1306             uint8_t val[4];
1307             if (size != 4) {
1308                 av_log(avctx, AV_LOG_ERROR,
1309                        "Invalid INT32INFO, size = %i\n",
1310                        size);
1311                 bytestream2_skip(&gb, ssize - 4);
1312                 continue;
1313             }
1314             bytestream2_get_buffer(&gb, val, 4);
1315             if (val[0] > 30) {
1316                 av_log(avctx, AV_LOG_ERROR,
1317                        "Invalid INT32INFO, extra_bits = %d (> 30)\n", val[0]);
1318                 continue;
1319             } else if (val[0]) {
1320                 s->extra_bits = val[0];
1321             } else if (val[1]) {
1322                 s->shift = val[1];
1323             } else if (val[2]) {
1324                 s->and   = s->or = 1;
1325                 s->shift = val[2];
1326             } else if (val[3]) {
1327                 s->and   = 1;
1328                 s->shift = val[3];
1329             }
1330             if (s->shift > 31) {
1331                 av_log(avctx, AV_LOG_ERROR,
1332                        "Invalid INT32INFO, shift = %d (> 31)\n", s->shift);
1333                 s->and = s->or = s->shift = 0;
1334                 continue;
1335             }
1336             /* original WavPack decoder forces 32-bit lossy sound to be treated
1337              * as 24-bit one in order to have proper clipping */
1338             if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
1339                 s->post_shift      += 8;
1340                 s->shift           -= 8;
1341                 s->hybrid_maxclip >>= 8;
1342                 s->hybrid_minclip >>= 8;
1343             }
1344             break;
1345         }
1346         case WP_ID_FLOATINFO:
1347             if (size != 4) {
1348                 av_log(avctx, AV_LOG_ERROR,
1349                        "Invalid FLOATINFO, size = %i\n", size);
1350                 bytestream2_skip(&gb, ssize);
1351                 continue;
1352             }
1353             s->float_flag    = bytestream2_get_byte(&gb);
1354             s->float_shift   = bytestream2_get_byte(&gb);
1355             s->float_max_exp = bytestream2_get_byte(&gb);
1356             if (s->float_shift > 31) {
1357                 av_log(avctx, AV_LOG_ERROR,
1358                        "Invalid FLOATINFO, shift = %d (> 31)\n", s->float_shift);
1359                 s->float_shift = 0;
1360                 continue;
1361             }
1362             got_float        = 1;
1363             bytestream2_skip(&gb, 1);
1364             break;
1365         case WP_ID_DATA:
1366             if ((ret = init_get_bits8(&s->gb, gb.buffer, size)) < 0)
1367                 return ret;
1368             bytestream2_skip(&gb, size);
1369             got_pcm      = 1;
1370             break;
1371         case WP_ID_DSD_DATA:
1372             if (size < 2) {
1373                 av_log(avctx, AV_LOG_ERROR, "Invalid DSD_DATA, size = %i\n",
1374                        size);
1375                 bytestream2_skip(&gb, ssize);
1376                 continue;
1377             }
1378             rate_x = 1 << bytestream2_get_byte(&gb);
1379             dsd_mode = bytestream2_get_byte(&gb);
1380             if (dsd_mode && dsd_mode != 1 && dsd_mode != 3) {
1381                 av_log(avctx, AV_LOG_ERROR, "Invalid DSD encoding mode: %d\n",
1382                     dsd_mode);
1383                 return AVERROR_INVALIDDATA;
1384             }
1385             bytestream2_init(&s->gbyte, gb.buffer, size-2);
1386             bytestream2_skip(&gb, size-2);
1387             got_dsd      = 1;
1388             break;
1389         case WP_ID_EXTRABITS:
1390             if (size <= 4) {
1391                 av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
1392                        size);
1393                 bytestream2_skip(&gb, size);
1394                 continue;
1395             }
1396             if ((ret = init_get_bits8(&s->gb_extra_bits, gb.buffer, size)) < 0)
1397                 return ret;
1398             s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
1399             bytestream2_skip(&gb, size);
1400             s->got_extra_bits  = 1;
1401             break;
1402         case WP_ID_CHANINFO:
1403             if (size <= 1) {
1404                 av_log(avctx, AV_LOG_ERROR,
1405                        "Insufficient channel information\n");
1406                 return AVERROR_INVALIDDATA;
1407             }
1408             chan = bytestream2_get_byte(&gb);
1409             switch (size - 2) {
1410             case 0:
1411                 chmask = bytestream2_get_byte(&gb);
1412                 break;
1413             case 1:
1414                 chmask = bytestream2_get_le16(&gb);
1415                 break;
1416             case 2:
1417                 chmask = bytestream2_get_le24(&gb);
1418                 break;
1419             case 3:
1420                 chmask = bytestream2_get_le32(&gb);
1421                 break;
1422             case 4:
1423                 size = bytestream2_get_byte(&gb);
1424                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1425                 chan  += 1;
1426                 if (avctx->channels != chan)
1427                     av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
1428                            " instead of %i.\n", chan, avctx->channels);
1429                 chmask = bytestream2_get_le24(&gb);
1430                 break;
1431             case 5:
1432                 size = bytestream2_get_byte(&gb);
1433                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1434                 chan  += 1;
1435                 if (avctx->channels != chan)
1436                     av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
1437                            " instead of %i.\n", chan, avctx->channels);
1438                 chmask = bytestream2_get_le32(&gb);
1439                 break;
1440             default:
1441                 av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
1442                        size);
1443                 chan   = avctx->channels;
1444                 chmask = avctx->channel_layout;
1445             }
1446             break;
1447         case WP_ID_SAMPLE_RATE:
1448             if (size != 3) {
1449                 av_log(avctx, AV_LOG_ERROR, "Invalid custom sample rate.\n");
1450                 return AVERROR_INVALIDDATA;
1451             }
1452             sample_rate = bytestream2_get_le24(&gb);
1453             break;
1454         default:
1455             bytestream2_skip(&gb, size);
1456         }
1457         if (id & WP_IDF_ODD)
1458             bytestream2_skip(&gb, 1);
1459     }
1460
1461     if (got_pcm) {
1462         if (!got_terms) {
1463             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
1464             return AVERROR_INVALIDDATA;
1465         }
1466         if (!got_weights) {
1467             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
1468             return AVERROR_INVALIDDATA;
1469         }
1470         if (!got_samples) {
1471             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
1472             return AVERROR_INVALIDDATA;
1473         }
1474         if (!got_entropy) {
1475             av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
1476             return AVERROR_INVALIDDATA;
1477         }
1478         if (s->hybrid && !got_hybrid) {
1479             av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
1480             return AVERROR_INVALIDDATA;
1481         }
1482         if (!got_float && sample_fmt == AV_SAMPLE_FMT_FLTP) {
1483             av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
1484             return AVERROR_INVALIDDATA;
1485         }
1486         if (s->got_extra_bits && sample_fmt != AV_SAMPLE_FMT_FLTP) {
1487             const int size   = get_bits_left(&s->gb_extra_bits);
1488             const int wanted = s->samples * s->extra_bits << s->stereo_in;
1489             if (size < wanted) {
1490                 av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
1491                 s->got_extra_bits = 0;
1492             }
1493         }
1494     }
1495
1496     if (!got_pcm && !got_dsd) {
1497         av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
1498         return AVERROR_INVALIDDATA;
1499     }
1500
1501     if ((got_pcm && wc->modulation != MODULATION_PCM) ||
1502         (got_dsd && wc->modulation != MODULATION_DSD)) {
1503             av_log(avctx, AV_LOG_ERROR, "Invalid PCM/DSD mix encountered\n");
1504             return AVERROR_INVALIDDATA;
1505     }
1506
1507     if (!wc->ch_offset) {
1508         int      new_channels = avctx->channels;
1509         uint64_t new_chmask   = avctx->channel_layout;
1510         int new_samplerate;
1511         int sr = (s->frame_flags >> 23) & 0xf;
1512         if (sr == 0xf) {
1513             if (!sample_rate) {
1514                 av_log(avctx, AV_LOG_ERROR, "Custom sample rate missing.\n");
1515                 return AVERROR_INVALIDDATA;
1516             }
1517             new_samplerate = sample_rate * rate_x;
1518         } else
1519             new_samplerate = wv_rates[sr] * rate_x;
1520
1521         if (multiblock) {
1522             if (chan)
1523                 new_channels = chan;
1524             if (chmask)
1525                 new_chmask = chmask;
1526         } else {
1527             new_channels = s->stereo ? 2 : 1;
1528             new_chmask   = s->stereo ? AV_CH_LAYOUT_STEREO :
1529                                        AV_CH_LAYOUT_MONO;
1530         }
1531
1532         if (new_chmask &&
1533             av_get_channel_layout_nb_channels(new_chmask) != new_channels) {
1534             av_log(avctx, AV_LOG_ERROR, "Channel mask does not match the channel count\n");
1535             return AVERROR_INVALIDDATA;
1536         }
1537
1538         /* clear DSD state if stream properties change */
1539         if (new_channels   != wc->dsd_channels      ||
1540             new_chmask     != avctx->channel_layout ||
1541             new_samplerate != avctx->sample_rate    ||
1542             !!got_dsd      != !!wc->dsdctx) {
1543             ret = wv_dsd_reset(wc, got_dsd ? new_channels : 0);
1544             if (ret < 0) {
1545                 av_log(avctx, AV_LOG_ERROR, "Error reinitializing the DSD context\n");
1546                 return ret;
1547             }
1548             ff_thread_release_buffer(avctx, &wc->curr_frame);
1549         }
1550         avctx->channels            = new_channels;
1551         avctx->channel_layout      = new_chmask;
1552         avctx->sample_rate         = new_samplerate;
1553         avctx->sample_fmt          = sample_fmt;
1554         avctx->bits_per_raw_sample = orig_bpp;
1555
1556         ff_thread_release_buffer(avctx, &wc->prev_frame);
1557         FFSWAP(ThreadFrame, wc->curr_frame, wc->prev_frame);
1558
1559         /* get output buffer */
1560         wc->curr_frame.f->nb_samples = s->samples;
1561         if ((ret = ff_thread_get_buffer(avctx, &wc->curr_frame, AV_GET_BUFFER_FLAG_REF)) < 0)
1562             return ret;
1563
1564         wc->frame = wc->curr_frame.f;
1565         ff_thread_finish_setup(avctx);
1566     }
1567
1568     if (wc->ch_offset + s->stereo >= avctx->channels) {
1569         av_log(avctx, AV_LOG_WARNING, "Too many channels coded in a packet.\n");
1570         return ((avctx->err_recognition & AV_EF_EXPLODE) || !wc->ch_offset) ? AVERROR_INVALIDDATA : 0;
1571     }
1572
1573     samples_l = wc->frame->extended_data[wc->ch_offset];
1574     if (s->stereo)
1575         samples_r = wc->frame->extended_data[wc->ch_offset + 1];
1576
1577     wc->ch_offset += 1 + s->stereo;
1578
1579     if (s->stereo_in) {
1580         if (got_dsd) {
1581             if (dsd_mode == 3) {
1582                 ret = wv_unpack_dsd_high(s, samples_l, samples_r);
1583             } else if (dsd_mode == 1) {
1584                 ret = wv_unpack_dsd_fast(s, samples_l, samples_r);
1585             } else {
1586                 ret = wv_unpack_dsd_copy(s, samples_l, samples_r);
1587             }
1588         } else {
1589             ret = wv_unpack_stereo(s, &s->gb, samples_l, samples_r, avctx->sample_fmt);
1590         }
1591         if (ret < 0)
1592             return ret;
1593     } else {
1594         if (got_dsd) {
1595             if (dsd_mode == 3) {
1596                 ret = wv_unpack_dsd_high(s, samples_l, NULL);
1597             } else if (dsd_mode == 1) {
1598                 ret = wv_unpack_dsd_fast(s, samples_l, NULL);
1599             } else {
1600                 ret = wv_unpack_dsd_copy(s, samples_l, NULL);
1601             }
1602         } else {
1603             ret = wv_unpack_mono(s, &s->gb, samples_l, avctx->sample_fmt);
1604         }
1605         if (ret < 0)
1606             return ret;
1607
1608         if (s->stereo)
1609             memcpy(samples_r, samples_l, bpp * s->samples);
1610     }
1611
1612     return 0;
1613 }
1614
1615 static void wavpack_decode_flush(AVCodecContext *avctx)
1616 {
1617     WavpackContext *s = avctx->priv_data;
1618
1619     wv_dsd_reset(s, 0);
1620 }
1621
1622 static int dsd_channel(AVCodecContext *avctx, void *frmptr, int jobnr, int threadnr)
1623 {
1624     WavpackContext *s  = avctx->priv_data;
1625     AVFrame *frame = frmptr;
1626
1627     ff_dsd2pcm_translate (&s->dsdctx [jobnr], s->samples, 0,
1628         (uint8_t *)frame->extended_data[jobnr], 4,
1629         (float *)frame->extended_data[jobnr], 1);
1630
1631     return 0;
1632 }
1633
1634 static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
1635                                 int *got_frame_ptr, AVPacket *avpkt)
1636 {
1637     WavpackContext *s  = avctx->priv_data;
1638     const uint8_t *buf = avpkt->data;
1639     int buf_size       = avpkt->size;
1640     int frame_size, ret, frame_flags;
1641
1642     if (avpkt->size <= WV_HEADER_SIZE)
1643         return AVERROR_INVALIDDATA;
1644
1645     s->frame     = NULL;
1646     s->block     = 0;
1647     s->ch_offset = 0;
1648
1649     /* determine number of samples */
1650     s->samples  = AV_RL32(buf + 20);
1651     frame_flags = AV_RL32(buf + 24);
1652     if (s->samples <= 0 || s->samples > WV_MAX_SAMPLES) {
1653         av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1654                s->samples);
1655         return AVERROR_INVALIDDATA;
1656     }
1657
1658     s->modulation = (frame_flags & WV_DSD_DATA) ? MODULATION_DSD : MODULATION_PCM;
1659
1660     while (buf_size > WV_HEADER_SIZE) {
1661         frame_size = AV_RL32(buf + 4) - 12;
1662         buf       += 20;
1663         buf_size  -= 20;
1664         if (frame_size <= 0 || frame_size > buf_size) {
1665             av_log(avctx, AV_LOG_ERROR,
1666                    "Block %d has invalid size (size %d vs. %d bytes left)\n",
1667                    s->block, frame_size, buf_size);
1668             ret = AVERROR_INVALIDDATA;
1669             goto error;
1670         }
1671         if ((ret = wavpack_decode_block(avctx, s->block, buf, frame_size)) < 0)
1672             goto error;
1673         s->block++;
1674         buf      += frame_size;
1675         buf_size -= frame_size;
1676     }
1677
1678     if (s->ch_offset != avctx->channels) {
1679         av_log(avctx, AV_LOG_ERROR, "Not enough channels coded in a packet.\n");
1680         ret = AVERROR_INVALIDDATA;
1681         goto error;
1682     }
1683
1684     ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1685     ff_thread_release_buffer(avctx, &s->prev_frame);
1686
1687     if (s->modulation == MODULATION_DSD)
1688         avctx->execute2(avctx, dsd_channel, s->frame, NULL, avctx->channels);
1689
1690     ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1691
1692     if ((ret = av_frame_ref(data, s->frame)) < 0)
1693         return ret;
1694
1695     *got_frame_ptr = 1;
1696
1697     return avpkt->size;
1698
1699 error:
1700     if (s->frame) {
1701         ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1702         ff_thread_release_buffer(avctx, &s->prev_frame);
1703         ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1704     }
1705
1706     return ret;
1707 }
1708
1709 AVCodec ff_wavpack_decoder = {
1710     .name           = "wavpack",
1711     .long_name      = NULL_IF_CONFIG_SMALL("WavPack"),
1712     .type           = AVMEDIA_TYPE_AUDIO,
1713     .id             = AV_CODEC_ID_WAVPACK,
1714     .priv_data_size = sizeof(WavpackContext),
1715     .init           = wavpack_decode_init,
1716     .close          = wavpack_decode_end,
1717     .decode         = wavpack_decode_frame,
1718     .flush          = wavpack_decode_flush,
1719     .init_thread_copy = ONLY_IF_THREADS_ENABLED(init_thread_copy),
1720     .update_thread_context = ONLY_IF_THREADS_ENABLED(update_thread_context),
1721     .capabilities   = AV_CODEC_CAP_DR1 | AV_CODEC_CAP_FRAME_THREADS |
1722                       AV_CODEC_CAP_SLICE_THREADS
1723 };