]> git.sesse.net Git - ffmpeg/blob - libavcodec/wavpack.c
avfilter/vf_identity: fix typo
[ffmpeg] / libavcodec / wavpack.c
1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  * Copyright (c) 2020 David Bryant
5  *
6  * This file is part of FFmpeg.
7  *
8  * FFmpeg is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * FFmpeg is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with FFmpeg; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21  */
22
23 #include "libavutil/buffer.h"
24 #include "libavutil/channel_layout.h"
25
26 #define BITSTREAM_READER_LE
27 #include "avcodec.h"
28 #include "bytestream.h"
29 #include "get_bits.h"
30 #include "internal.h"
31 #include "thread.h"
32 #include "unary.h"
33 #include "wavpack.h"
34 #include "dsd.h"
35
36 /**
37  * @file
38  * WavPack lossless audio decoder
39  */
40
41 #define DSD_BYTE_READY(low,high) (!(((low) ^ (high)) & 0xff000000))
42
43 #define PTABLE_BITS 8
44 #define PTABLE_BINS (1<<PTABLE_BITS)
45 #define PTABLE_MASK (PTABLE_BINS-1)
46
47 #define UP   0x010000fe
48 #define DOWN 0x00010000
49 #define DECAY 8
50
51 #define PRECISION 20
52 #define VALUE_ONE (1 << PRECISION)
53 #define PRECISION_USE 12
54
55 #define RATE_S 20
56
57 #define MAX_HISTORY_BITS    5
58 #define MAX_HISTORY_BINS    (1 << MAX_HISTORY_BITS)
59 #define MAX_BIN_BYTES       1280    // for value_lookup, per bin (2k - 512 - 256)
60
61 typedef enum {
62     MODULATION_PCM,     // pulse code modulation
63     MODULATION_DSD      // pulse density modulation (aka DSD)
64 } Modulation;
65
66 typedef struct WavpackFrameContext {
67     AVCodecContext *avctx;
68     int frame_flags;
69     int stereo, stereo_in;
70     int joint;
71     uint32_t CRC;
72     GetBitContext gb;
73     int got_extra_bits;
74     uint32_t crc_extra_bits;
75     GetBitContext gb_extra_bits;
76     int samples;
77     int terms;
78     Decorr decorr[MAX_TERMS];
79     int zero, one, zeroes;
80     int extra_bits;
81     int and, or, shift;
82     int post_shift;
83     int hybrid, hybrid_bitrate;
84     int hybrid_maxclip, hybrid_minclip;
85     int float_flag;
86     int float_shift;
87     int float_max_exp;
88     WvChannel ch[2];
89
90     GetByteContext gbyte;
91     int ptable [PTABLE_BINS];
92     uint8_t value_lookup_buffer[MAX_HISTORY_BINS*MAX_BIN_BYTES];
93     uint16_t summed_probabilities[MAX_HISTORY_BINS][256];
94     uint8_t probabilities[MAX_HISTORY_BINS][256];
95     uint8_t *value_lookup[MAX_HISTORY_BINS];
96 } WavpackFrameContext;
97
98 #define WV_MAX_FRAME_DECODERS 14
99
100 typedef struct WavpackContext {
101     AVCodecContext *avctx;
102
103     WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
104     int fdec_num;
105
106     int block;
107     int samples;
108     int ch_offset;
109
110     AVFrame *frame;
111     ThreadFrame curr_frame, prev_frame;
112     Modulation modulation;
113
114     AVBufferRef *dsd_ref;
115     DSDContext *dsdctx;
116     int dsd_channels;
117 } WavpackContext;
118
119 #define LEVEL_DECAY(a)  (((a) + 0x80) >> 8)
120
121 static av_always_inline unsigned get_tail(GetBitContext *gb, int k)
122 {
123     int p, e, res;
124
125     if (k < 1)
126         return 0;
127     p   = av_log2(k);
128     e   = (1 << (p + 1)) - k - 1;
129     res = get_bitsz(gb, p);
130     if (res >= e)
131         res = (res << 1) - e + get_bits1(gb);
132     return res;
133 }
134
135 static int update_error_limit(WavpackFrameContext *ctx)
136 {
137     int i, br[2], sl[2];
138
139     for (i = 0; i <= ctx->stereo_in; i++) {
140         if (ctx->ch[i].bitrate_acc > UINT_MAX - ctx->ch[i].bitrate_delta)
141             return AVERROR_INVALIDDATA;
142         ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
143         br[i]                   = ctx->ch[i].bitrate_acc >> 16;
144         sl[i]                   = LEVEL_DECAY(ctx->ch[i].slow_level);
145     }
146     if (ctx->stereo_in && ctx->hybrid_bitrate) {
147         int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
148         if (balance > br[0]) {
149             br[1] = br[0] * 2;
150             br[0] = 0;
151         } else if (-balance > br[0]) {
152             br[0]  *= 2;
153             br[1]   = 0;
154         } else {
155             br[1] = br[0] + balance;
156             br[0] = br[0] - balance;
157         }
158     }
159     for (i = 0; i <= ctx->stereo_in; i++) {
160         if (ctx->hybrid_bitrate) {
161             if (sl[i] - br[i] > -0x100)
162                 ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
163             else
164                 ctx->ch[i].error_limit = 0;
165         } else {
166             ctx->ch[i].error_limit = wp_exp2(br[i]);
167         }
168     }
169
170     return 0;
171 }
172
173 static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
174                         int channel, int *last)
175 {
176     int t, t2;
177     int sign, base, add, ret;
178     WvChannel *c = &ctx->ch[channel];
179
180     *last = 0;
181
182     if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
183         !ctx->zero && !ctx->one) {
184         if (ctx->zeroes) {
185             ctx->zeroes--;
186             if (ctx->zeroes) {
187                 c->slow_level -= LEVEL_DECAY(c->slow_level);
188                 return 0;
189             }
190         } else {
191             t = get_unary_0_33(gb);
192             if (t >= 2) {
193                 if (t >= 32 || get_bits_left(gb) < t - 1)
194                     goto error;
195                 t = get_bits_long(gb, t - 1) | (1 << (t - 1));
196             } else {
197                 if (get_bits_left(gb) < 0)
198                     goto error;
199             }
200             ctx->zeroes = t;
201             if (ctx->zeroes) {
202                 memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
203                 memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
204                 c->slow_level -= LEVEL_DECAY(c->slow_level);
205                 return 0;
206             }
207         }
208     }
209
210     if (ctx->zero) {
211         t         = 0;
212         ctx->zero = 0;
213     } else {
214         t = get_unary_0_33(gb);
215         if (get_bits_left(gb) < 0)
216             goto error;
217         if (t == 16) {
218             t2 = get_unary_0_33(gb);
219             if (t2 < 2) {
220                 if (get_bits_left(gb) < 0)
221                     goto error;
222                 t += t2;
223             } else {
224                 if (t2 >= 32 || get_bits_left(gb) < t2 - 1)
225                     goto error;
226                 t += get_bits_long(gb, t2 - 1) | (1 << (t2 - 1));
227             }
228         }
229
230         if (ctx->one) {
231             ctx->one = t & 1;
232             t        = (t >> 1) + 1;
233         } else {
234             ctx->one = t & 1;
235             t      >>= 1;
236         }
237         ctx->zero = !ctx->one;
238     }
239
240     if (ctx->hybrid && !channel) {
241         if (update_error_limit(ctx) < 0)
242             goto error;
243     }
244
245     if (!t) {
246         base = 0;
247         add  = GET_MED(0) - 1;
248         DEC_MED(0);
249     } else if (t == 1) {
250         base = GET_MED(0);
251         add  = GET_MED(1) - 1;
252         INC_MED(0);
253         DEC_MED(1);
254     } else if (t == 2) {
255         base = GET_MED(0) + GET_MED(1);
256         add  = GET_MED(2) - 1;
257         INC_MED(0);
258         INC_MED(1);
259         DEC_MED(2);
260     } else {
261         base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2U);
262         add  = GET_MED(2) - 1;
263         INC_MED(0);
264         INC_MED(1);
265         INC_MED(2);
266     }
267     if (!c->error_limit) {
268         if (add >= 0x2000000U) {
269             av_log(ctx->avctx, AV_LOG_ERROR, "k %d is too large\n", add);
270             goto error;
271         }
272         ret = base + get_tail(gb, add);
273         if (get_bits_left(gb) <= 0)
274             goto error;
275     } else {
276         int mid = (base * 2U + add + 1) >> 1;
277         while (add > c->error_limit) {
278             if (get_bits_left(gb) <= 0)
279                 goto error;
280             if (get_bits1(gb)) {
281                 add -= (mid - (unsigned)base);
282                 base = mid;
283             } else
284                 add = mid - (unsigned)base - 1;
285             mid = (base * 2U + add + 1) >> 1;
286         }
287         ret = mid;
288     }
289     sign = get_bits1(gb);
290     if (ctx->hybrid_bitrate)
291         c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
292     return sign ? ~ret : ret;
293
294 error:
295     ret = get_bits_left(gb);
296     if (ret <= 0) {
297         av_log(ctx->avctx, AV_LOG_ERROR, "Too few bits (%d) left\n", ret);
298     }
299     *last = 1;
300     return 0;
301 }
302
303 static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
304                                        unsigned S)
305 {
306     unsigned bit;
307
308     if (s->extra_bits) {
309         S *= 1 << s->extra_bits;
310
311         if (s->got_extra_bits &&
312             get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
313             S   |= get_bits_long(&s->gb_extra_bits, s->extra_bits);
314             *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
315         }
316     }
317
318     bit = (S & s->and) | s->or;
319     bit = ((S + bit) << s->shift) - bit;
320
321     if (s->hybrid)
322         bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
323
324     return bit << s->post_shift;
325 }
326
327 static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
328 {
329     union {
330         float    f;
331         uint32_t u;
332     } value;
333
334     unsigned int sign;
335     int exp = s->float_max_exp;
336
337     if (s->got_extra_bits) {
338         const int max_bits  = 1 + 23 + 8 + 1;
339         const int left_bits = get_bits_left(&s->gb_extra_bits);
340
341         if (left_bits + 8 * AV_INPUT_BUFFER_PADDING_SIZE < max_bits)
342             return 0.0;
343     }
344
345     if (S) {
346         S  *= 1U << s->float_shift;
347         sign = S < 0;
348         if (sign)
349             S = -(unsigned)S;
350         if (S >= 0x1000000U) {
351             if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
352                 S = get_bits(&s->gb_extra_bits, 23);
353             else
354                 S = 0;
355             exp = 255;
356         } else if (exp) {
357             int shift = 23 - av_log2(S);
358             exp = s->float_max_exp;
359             if (exp <= shift)
360                 shift = --exp;
361             exp -= shift;
362
363             if (shift) {
364                 S <<= shift;
365                 if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
366                     (s->got_extra_bits &&
367                      (s->float_flag & WV_FLT_SHIFT_SAME) &&
368                      get_bits1(&s->gb_extra_bits))) {
369                     S |= (1 << shift) - 1;
370                 } else if (s->got_extra_bits &&
371                            (s->float_flag & WV_FLT_SHIFT_SENT)) {
372                     S |= get_bits(&s->gb_extra_bits, shift);
373                 }
374             }
375         } else {
376             exp = s->float_max_exp;
377         }
378         S &= 0x7fffff;
379     } else {
380         sign = 0;
381         exp  = 0;
382         if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
383             if (get_bits1(&s->gb_extra_bits)) {
384                 S = get_bits(&s->gb_extra_bits, 23);
385                 if (s->float_max_exp >= 25)
386                     exp = get_bits(&s->gb_extra_bits, 8);
387                 sign = get_bits1(&s->gb_extra_bits);
388             } else {
389                 if (s->float_flag & WV_FLT_ZERO_SIGN)
390                     sign = get_bits1(&s->gb_extra_bits);
391             }
392         }
393     }
394
395     *crc = *crc * 27 + S * 9 + exp * 3 + sign;
396
397     value.u = (sign << 31) | (exp << 23) | S;
398     return value.f;
399 }
400
401 static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
402                                uint32_t crc_extra_bits)
403 {
404     if (crc != s->CRC) {
405         av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
406         return AVERROR_INVALIDDATA;
407     }
408     if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
409         av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
410         return AVERROR_INVALIDDATA;
411     }
412
413     return 0;
414 }
415
416 static void init_ptable(int *table, int rate_i, int rate_s)
417 {
418     int value = 0x808000, rate = rate_i << 8;
419
420     for (int c = (rate + 128) >> 8; c--;)
421         value += (DOWN - value) >> DECAY;
422
423     for (int i = 0; i < PTABLE_BINS/2; i++) {
424         table[i] = value;
425         table[PTABLE_BINS-1-i] = 0x100ffff - value;
426
427         if (value > 0x010000) {
428             rate += (rate * rate_s + 128) >> 8;
429
430             for (int c = (rate + 64) >> 7; c--;)
431                 value += (DOWN - value) >> DECAY;
432         }
433     }
434 }
435
436 typedef struct {
437     int32_t value, fltr0, fltr1, fltr2, fltr3, fltr4, fltr5, fltr6, factor;
438     unsigned int byte;
439 } DSDfilters;
440
441 static int wv_unpack_dsd_high(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
442 {
443     uint32_t checksum = 0xFFFFFFFF;
444     uint8_t *dst_l = dst_left, *dst_r = dst_right;
445     int total_samples = s->samples, stereo = dst_r ? 1 : 0;
446     DSDfilters filters[2], *sp = filters;
447     int rate_i, rate_s;
448     uint32_t low, high, value;
449
450     if (bytestream2_get_bytes_left(&s->gbyte) < (stereo ? 20 : 13))
451         return AVERROR_INVALIDDATA;
452
453     rate_i = bytestream2_get_byte(&s->gbyte);
454     rate_s = bytestream2_get_byte(&s->gbyte);
455
456     if (rate_s != RATE_S)
457         return AVERROR_INVALIDDATA;
458
459     init_ptable(s->ptable, rate_i, rate_s);
460
461     for (int channel = 0; channel < stereo + 1; channel++) {
462         DSDfilters *sp = filters + channel;
463
464         sp->fltr1 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
465         sp->fltr2 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
466         sp->fltr3 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
467         sp->fltr4 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
468         sp->fltr5 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
469         sp->fltr6 = 0;
470         sp->factor = bytestream2_get_byte(&s->gbyte) & 0xff;
471         sp->factor |= (bytestream2_get_byte(&s->gbyte) << 8) & 0xff00;
472         sp->factor = (int32_t)((uint32_t)sp->factor << 16) >> 16;
473     }
474
475     value = bytestream2_get_be32(&s->gbyte);
476     high = 0xffffffff;
477     low = 0x0;
478
479     while (total_samples--) {
480         int bitcount = 8;
481
482         sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
483
484         if (stereo)
485             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
486
487         while (bitcount--) {
488             int32_t *pp = s->ptable + ((sp[0].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
489             uint32_t split = low + ((high - low) >> 8) * (*pp >> 16);
490
491             if (value <= split) {
492                 high = split;
493                 *pp += (UP - *pp) >> DECAY;
494                 sp[0].fltr0 = -1;
495             } else {
496                 low = split + 1;
497                 *pp += (DOWN - *pp) >> DECAY;
498                 sp[0].fltr0 = 0;
499             }
500
501             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
502                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
503                 high = (high << 8) | 0xff;
504                 low <<= 8;
505             }
506
507             sp[0].value += sp[0].fltr6 * 8;
508             sp[0].byte = (sp[0].byte << 1) | (sp[0].fltr0 & 1);
509             sp[0].factor += (((sp[0].value ^ sp[0].fltr0) >> 31) | 1) &
510                 ((sp[0].value ^ (sp[0].value - (sp[0].fltr6 * 16))) >> 31);
511             sp[0].fltr1 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr1) >> 6;
512             sp[0].fltr2 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr2) >> 4;
513             sp[0].fltr3 += (sp[0].fltr2 - sp[0].fltr3) >> 4;
514             sp[0].fltr4 += (sp[0].fltr3 - sp[0].fltr4) >> 4;
515             sp[0].value = (sp[0].fltr4 - sp[0].fltr5) >> 4;
516             sp[0].fltr5 += sp[0].value;
517             sp[0].fltr6 += (sp[0].value - sp[0].fltr6) >> 3;
518             sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
519
520             if (!stereo)
521                 continue;
522
523             pp = s->ptable + ((sp[1].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
524             split = low + ((high - low) >> 8) * (*pp >> 16);
525
526             if (value <= split) {
527                 high = split;
528                 *pp += (UP - *pp) >> DECAY;
529                 sp[1].fltr0 = -1;
530             } else {
531                 low = split + 1;
532                 *pp += (DOWN - *pp) >> DECAY;
533                 sp[1].fltr0 = 0;
534             }
535
536             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
537                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
538                 high = (high << 8) | 0xff;
539                 low <<= 8;
540             }
541
542             sp[1].value += sp[1].fltr6 * 8;
543             sp[1].byte = (sp[1].byte << 1) | (sp[1].fltr0 & 1);
544             sp[1].factor += (((sp[1].value ^ sp[1].fltr0) >> 31) | 1) &
545                 ((sp[1].value ^ (sp[1].value - (sp[1].fltr6 * 16))) >> 31);
546             sp[1].fltr1 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr1) >> 6;
547             sp[1].fltr2 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr2) >> 4;
548             sp[1].fltr3 += (sp[1].fltr2 - sp[1].fltr3) >> 4;
549             sp[1].fltr4 += (sp[1].fltr3 - sp[1].fltr4) >> 4;
550             sp[1].value = (sp[1].fltr4 - sp[1].fltr5) >> 4;
551             sp[1].fltr5 += sp[1].value;
552             sp[1].fltr6 += (sp[1].value - sp[1].fltr6) >> 3;
553             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
554         }
555
556         checksum += (checksum << 1) + (*dst_l = sp[0].byte & 0xff);
557         sp[0].factor -= (sp[0].factor + 512) >> 10;
558         dst_l += 4;
559
560         if (stereo) {
561             checksum += (checksum << 1) + (*dst_r = filters[1].byte & 0xff);
562             filters[1].factor -= (filters[1].factor + 512) >> 10;
563             dst_r += 4;
564         }
565     }
566
567     if (wv_check_crc(s, checksum, 0)) {
568         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
569             return AVERROR_INVALIDDATA;
570
571         memset(dst_left, 0x69, s->samples * 4);
572
573         if (dst_r)
574             memset(dst_right, 0x69, s->samples * 4);
575     }
576
577     return 0;
578 }
579
580 static int wv_unpack_dsd_fast(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
581 {
582     uint8_t *dst_l = dst_left, *dst_r = dst_right;
583     uint8_t history_bits, max_probability;
584     int total_summed_probabilities  = 0;
585     int total_samples               = s->samples;
586     uint8_t *vlb                    = s->value_lookup_buffer;
587     int history_bins, p0, p1, chan;
588     uint32_t checksum               = 0xFFFFFFFF;
589     uint32_t low, high, value;
590
591     if (!bytestream2_get_bytes_left(&s->gbyte))
592         return AVERROR_INVALIDDATA;
593
594     history_bits = bytestream2_get_byte(&s->gbyte);
595
596     if (!bytestream2_get_bytes_left(&s->gbyte) || history_bits > MAX_HISTORY_BITS)
597         return AVERROR_INVALIDDATA;
598
599     history_bins = 1 << history_bits;
600     max_probability = bytestream2_get_byte(&s->gbyte);
601
602     if (max_probability < 0xff) {
603         uint8_t *outptr = (uint8_t *)s->probabilities;
604         uint8_t *outend = outptr + sizeof(*s->probabilities) * history_bins;
605
606         while (outptr < outend && bytestream2_get_bytes_left(&s->gbyte)) {
607             int code = bytestream2_get_byte(&s->gbyte);
608
609             if (code > max_probability) {
610                 int zcount = code - max_probability;
611
612                 while (outptr < outend && zcount--)
613                     *outptr++ = 0;
614             } else if (code) {
615                 *outptr++ = code;
616             }
617             else {
618                 break;
619             }
620         }
621
622         if (outptr < outend ||
623             (bytestream2_get_bytes_left(&s->gbyte) && bytestream2_get_byte(&s->gbyte)))
624                 return AVERROR_INVALIDDATA;
625     } else if (bytestream2_get_bytes_left(&s->gbyte) > (int)sizeof(*s->probabilities) * history_bins) {
626         bytestream2_get_buffer(&s->gbyte, (uint8_t *)s->probabilities,
627             sizeof(*s->probabilities) * history_bins);
628     } else {
629         return AVERROR_INVALIDDATA;
630     }
631
632     for (p0 = 0; p0 < history_bins; p0++) {
633         int32_t sum_values = 0;
634
635         for (int i = 0; i < 256; i++)
636             s->summed_probabilities[p0][i] = sum_values += s->probabilities[p0][i];
637
638         if (sum_values) {
639             total_summed_probabilities += sum_values;
640
641             if (total_summed_probabilities > history_bins * MAX_BIN_BYTES)
642                 return AVERROR_INVALIDDATA;
643
644             s->value_lookup[p0] = vlb;
645
646             for (int i = 0; i < 256; i++) {
647                 int c = s->probabilities[p0][i];
648
649                 while (c--)
650                     *vlb++ = i;
651             }
652         }
653     }
654
655     if (bytestream2_get_bytes_left(&s->gbyte) < 4)
656         return AVERROR_INVALIDDATA;
657
658     chan = p0 = p1 = 0;
659     low = 0; high = 0xffffffff;
660     value = bytestream2_get_be32(&s->gbyte);
661
662     if (dst_r)
663         total_samples *= 2;
664
665     while (total_samples--) {
666         unsigned int mult, index, code;
667
668         if (!s->summed_probabilities[p0][255])
669             return AVERROR_INVALIDDATA;
670
671         mult = (high - low) / s->summed_probabilities[p0][255];
672
673         if (!mult) {
674             if (bytestream2_get_bytes_left(&s->gbyte) >= 4)
675                 value = bytestream2_get_be32(&s->gbyte);
676
677             low = 0;
678             high = 0xffffffff;
679             mult = high / s->summed_probabilities[p0][255];
680
681             if (!mult)
682                 return AVERROR_INVALIDDATA;
683         }
684
685         index = (value - low) / mult;
686
687         if (index >= s->summed_probabilities[p0][255])
688             return AVERROR_INVALIDDATA;
689
690         if (!dst_r) {
691             if ((*dst_l = code = s->value_lookup[p0][index]))
692                 low += s->summed_probabilities[p0][code-1] * mult;
693
694             dst_l += 4;
695         } else {
696             if ((code = s->value_lookup[p0][index]))
697                 low += s->summed_probabilities[p0][code-1] * mult;
698
699             if (chan) {
700                 *dst_r = code;
701                 dst_r += 4;
702             }
703             else {
704                 *dst_l = code;
705                 dst_l += 4;
706             }
707
708             chan ^= 1;
709         }
710
711         high = low + s->probabilities[p0][code] * mult - 1;
712         checksum += (checksum << 1) + code;
713
714         if (!dst_r) {
715             p0 = code & (history_bins-1);
716         } else {
717             p0 = p1;
718             p1 = code & (history_bins-1);
719         }
720
721         while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
722             value = (value << 8) | bytestream2_get_byte(&s->gbyte);
723             high = (high << 8) | 0xff;
724             low <<= 8;
725         }
726     }
727
728     if (wv_check_crc(s, checksum, 0)) {
729         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
730             return AVERROR_INVALIDDATA;
731
732         memset(dst_left, 0x69, s->samples * 4);
733
734         if (dst_r)
735             memset(dst_right, 0x69, s->samples * 4);
736     }
737
738     return 0;
739 }
740
741 static int wv_unpack_dsd_copy(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
742 {
743     uint8_t *dst_l = dst_left, *dst_r = dst_right;
744     int total_samples           = s->samples;
745     uint32_t checksum           = 0xFFFFFFFF;
746
747     if (bytestream2_get_bytes_left(&s->gbyte) != total_samples * (dst_r ? 2 : 1))
748         return AVERROR_INVALIDDATA;
749
750     while (total_samples--) {
751         checksum += (checksum << 1) + (*dst_l = bytestream2_get_byte(&s->gbyte));
752         dst_l += 4;
753
754         if (dst_r) {
755             checksum += (checksum << 1) + (*dst_r = bytestream2_get_byte(&s->gbyte));
756             dst_r += 4;
757         }
758     }
759
760     if (wv_check_crc(s, checksum, 0)) {
761         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
762             return AVERROR_INVALIDDATA;
763
764         memset(dst_left, 0x69, s->samples * 4);
765
766         if (dst_r)
767             memset(dst_right, 0x69, s->samples * 4);
768     }
769
770     return 0;
771 }
772
773 static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
774                                    void *dst_l, void *dst_r, const int type)
775 {
776     int i, j, count = 0;
777     int last, t;
778     int A, B, L, L2, R, R2;
779     int pos                 = 0;
780     uint32_t crc            = 0xFFFFFFFF;
781     uint32_t crc_extra_bits = 0xFFFFFFFF;
782     int16_t *dst16_l        = dst_l;
783     int16_t *dst16_r        = dst_r;
784     int32_t *dst32_l        = dst_l;
785     int32_t *dst32_r        = dst_r;
786     float *dstfl_l          = dst_l;
787     float *dstfl_r          = dst_r;
788
789     s->one = s->zero = s->zeroes = 0;
790     do {
791         L = wv_get_value(s, gb, 0, &last);
792         if (last)
793             break;
794         R = wv_get_value(s, gb, 1, &last);
795         if (last)
796             break;
797         for (i = 0; i < s->terms; i++) {
798             t = s->decorr[i].value;
799             if (t > 0) {
800                 if (t > 8) {
801                     if (t & 1) {
802                         A = 2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
803                         B = 2U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
804                     } else {
805                         A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
806                         B = (int)(3U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
807                     }
808                     s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
809                     s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
810                     j                        = 0;
811                 } else {
812                     A = s->decorr[i].samplesA[pos];
813                     B = s->decorr[i].samplesB[pos];
814                     j = (pos + t) & 7;
815                 }
816                 if (type != AV_SAMPLE_FMT_S16P) {
817                     L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
818                     R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
819                 } else {
820                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
821                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)B + 512) >> 10);
822                 }
823                 if (A && L)
824                     s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
825                 if (B && R)
826                     s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
827                 s->decorr[i].samplesA[j] = L = L2;
828                 s->decorr[i].samplesB[j] = R = R2;
829             } else if (t == -1) {
830                 if (type != AV_SAMPLE_FMT_S16P)
831                     L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
832                 else
833                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)s->decorr[i].samplesA[0] + 512) >> 10);
834                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
835                 L = L2;
836                 if (type != AV_SAMPLE_FMT_S16P)
837                     R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
838                 else
839                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)L2 + 512) >> 10);
840                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
841                 R                        = R2;
842                 s->decorr[i].samplesA[0] = R;
843             } else {
844                 if (type != AV_SAMPLE_FMT_S16P)
845                     R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
846                 else
847                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)s->decorr[i].samplesB[0] + 512) >> 10);
848                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
849                 R = R2;
850
851                 if (t == -3) {
852                     R2                       = s->decorr[i].samplesA[0];
853                     s->decorr[i].samplesA[0] = R;
854                 }
855
856                 if (type != AV_SAMPLE_FMT_S16P)
857                     L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
858                 else
859                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)R2 + 512) >> 10);
860                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
861                 L                        = L2;
862                 s->decorr[i].samplesB[0] = L;
863             }
864         }
865
866         if (type == AV_SAMPLE_FMT_S16P) {
867             if (FFABS((int64_t)L) + FFABS((int64_t)R) > (1<<19)) {
868                 av_log(s->avctx, AV_LOG_ERROR, "sample %d %d too large\n", L, R);
869                 return AVERROR_INVALIDDATA;
870             }
871         }
872
873         pos = (pos + 1) & 7;
874         if (s->joint)
875             L += (unsigned)(R -= (unsigned)(L >> 1));
876         crc = (crc * 3 + L) * 3 + R;
877
878         if (type == AV_SAMPLE_FMT_FLTP) {
879             *dstfl_l++ = wv_get_value_float(s, &crc_extra_bits, L);
880             *dstfl_r++ = wv_get_value_float(s, &crc_extra_bits, R);
881         } else if (type == AV_SAMPLE_FMT_S32P) {
882             *dst32_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
883             *dst32_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
884         } else {
885             *dst16_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
886             *dst16_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
887         }
888         count++;
889     } while (!last && count < s->samples);
890
891     if (last && count < s->samples) {
892         int size = av_get_bytes_per_sample(type);
893         memset((uint8_t*)dst_l + count*size, 0, (s->samples-count)*size);
894         memset((uint8_t*)dst_r + count*size, 0, (s->samples-count)*size);
895     }
896
897     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
898         wv_check_crc(s, crc, crc_extra_bits))
899         return AVERROR_INVALIDDATA;
900
901     return 0;
902 }
903
904 static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
905                                  void *dst, const int type)
906 {
907     int i, j, count = 0;
908     int last, t;
909     int A, S, T;
910     int pos                  = 0;
911     uint32_t crc             = 0xFFFFFFFF;
912     uint32_t crc_extra_bits  = 0xFFFFFFFF;
913     int16_t *dst16           = dst;
914     int32_t *dst32           = dst;
915     float *dstfl             = dst;
916
917     s->one = s->zero = s->zeroes = 0;
918     do {
919         T = wv_get_value(s, gb, 0, &last);
920         S = 0;
921         if (last)
922             break;
923         for (i = 0; i < s->terms; i++) {
924             t = s->decorr[i].value;
925             if (t > 8) {
926                 if (t & 1)
927                     A =  2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
928                 else
929                     A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
930                 s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
931                 j                        = 0;
932             } else {
933                 A = s->decorr[i].samplesA[pos];
934                 j = (pos + t) & 7;
935             }
936             if (type != AV_SAMPLE_FMT_S16P)
937                 S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
938             else
939                 S = T + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
940             if (A && T)
941                 s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
942             s->decorr[i].samplesA[j] = T = S;
943         }
944         pos = (pos + 1) & 7;
945         crc = crc * 3 + S;
946
947         if (type == AV_SAMPLE_FMT_FLTP) {
948             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, S);
949         } else if (type == AV_SAMPLE_FMT_S32P) {
950             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, S);
951         } else {
952             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, S);
953         }
954         count++;
955     } while (!last && count < s->samples);
956
957     if (last && count < s->samples) {
958         int size = av_get_bytes_per_sample(type);
959         memset((uint8_t*)dst + count*size, 0, (s->samples-count)*size);
960     }
961
962     if (s->avctx->err_recognition & AV_EF_CRCCHECK) {
963         int ret = wv_check_crc(s, crc, crc_extra_bits);
964         if (ret < 0 && s->avctx->err_recognition & AV_EF_EXPLODE)
965             return ret;
966     }
967
968     return 0;
969 }
970
971 static av_cold int wv_alloc_frame_context(WavpackContext *c)
972 {
973     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
974         return -1;
975
976     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
977     if (!c->fdec[c->fdec_num])
978         return -1;
979     c->fdec_num++;
980     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
981
982     return 0;
983 }
984
985 static int wv_dsd_reset(WavpackContext *s, int channels)
986 {
987     int i;
988
989     s->dsdctx = NULL;
990     s->dsd_channels = 0;
991     av_buffer_unref(&s->dsd_ref);
992
993     if (!channels)
994         return 0;
995
996     if (channels > INT_MAX / sizeof(*s->dsdctx))
997         return AVERROR(EINVAL);
998
999     s->dsd_ref = av_buffer_allocz(channels * sizeof(*s->dsdctx));
1000     if (!s->dsd_ref)
1001         return AVERROR(ENOMEM);
1002     s->dsdctx = (DSDContext*)s->dsd_ref->data;
1003     s->dsd_channels = channels;
1004
1005     for (i = 0; i < channels; i++)
1006         memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
1007
1008     return 0;
1009 }
1010
1011 #if HAVE_THREADS
1012 static int update_thread_context(AVCodecContext *dst, const AVCodecContext *src)
1013 {
1014     WavpackContext *fsrc = src->priv_data;
1015     WavpackContext *fdst = dst->priv_data;
1016     int ret;
1017
1018     if (dst == src)
1019         return 0;
1020
1021     ff_thread_release_buffer(dst, &fdst->curr_frame);
1022     if (fsrc->curr_frame.f->data[0]) {
1023         if ((ret = ff_thread_ref_frame(&fdst->curr_frame, &fsrc->curr_frame)) < 0)
1024             return ret;
1025     }
1026
1027     fdst->dsdctx = NULL;
1028     fdst->dsd_channels = 0;
1029     ret = av_buffer_replace(&fdst->dsd_ref, fsrc->dsd_ref);
1030     if (ret < 0)
1031         return ret;
1032     if (fsrc->dsd_ref) {
1033         fdst->dsdctx = (DSDContext*)fdst->dsd_ref->data;
1034         fdst->dsd_channels = fsrc->dsd_channels;
1035     }
1036
1037     return 0;
1038 }
1039 #endif
1040
1041 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
1042 {
1043     WavpackContext *s = avctx->priv_data;
1044
1045     s->avctx = avctx;
1046
1047     s->fdec_num = 0;
1048
1049     s->curr_frame.f = av_frame_alloc();
1050     s->prev_frame.f = av_frame_alloc();
1051
1052     if (!s->curr_frame.f || !s->prev_frame.f)
1053         return AVERROR(ENOMEM);
1054
1055     ff_init_dsd_data();
1056
1057     return 0;
1058 }
1059
1060 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
1061 {
1062     WavpackContext *s = avctx->priv_data;
1063
1064     for (int i = 0; i < s->fdec_num; i++)
1065         av_freep(&s->fdec[i]);
1066     s->fdec_num = 0;
1067
1068     ff_thread_release_buffer(avctx, &s->curr_frame);
1069     av_frame_free(&s->curr_frame.f);
1070
1071     ff_thread_release_buffer(avctx, &s->prev_frame);
1072     av_frame_free(&s->prev_frame.f);
1073
1074     av_buffer_unref(&s->dsd_ref);
1075
1076     return 0;
1077 }
1078
1079 static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
1080                                 const uint8_t *buf, int buf_size)
1081 {
1082     WavpackContext *wc = avctx->priv_data;
1083     WavpackFrameContext *s;
1084     GetByteContext gb;
1085     enum AVSampleFormat sample_fmt;
1086     void *samples_l = NULL, *samples_r = NULL;
1087     int ret;
1088     int got_terms   = 0, got_weights = 0, got_samples = 0,
1089         got_entropy = 0, got_pcm     = 0, got_float   = 0, got_hybrid = 0;
1090     int got_dsd = 0;
1091     int i, j, id, size, ssize, weights, t;
1092     int bpp, chan = 0, orig_bpp, sample_rate = 0, rate_x = 1, dsd_mode = 0;
1093     int multiblock;
1094     uint64_t chmask = 0;
1095
1096     if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
1097         av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
1098         return AVERROR_INVALIDDATA;
1099     }
1100
1101     s = wc->fdec[block_no];
1102     if (!s) {
1103         av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
1104                block_no);
1105         return AVERROR_INVALIDDATA;
1106     }
1107
1108     memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
1109     memset(s->ch, 0, sizeof(s->ch));
1110     s->extra_bits     = 0;
1111     s->and            = s->or = s->shift = 0;
1112     s->got_extra_bits = 0;
1113
1114     bytestream2_init(&gb, buf, buf_size);
1115
1116     s->samples = bytestream2_get_le32(&gb);
1117     if (s->samples != wc->samples) {
1118         av_log(avctx, AV_LOG_ERROR, "Mismatching number of samples in "
1119                "a sequence: %d and %d\n", wc->samples, s->samples);
1120         return AVERROR_INVALIDDATA;
1121     }
1122     s->frame_flags = bytestream2_get_le32(&gb);
1123
1124     if (s->frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA))
1125         sample_fmt = AV_SAMPLE_FMT_FLTP;
1126     else if ((s->frame_flags & 0x03) <= 1)
1127         sample_fmt = AV_SAMPLE_FMT_S16P;
1128     else
1129         sample_fmt          = AV_SAMPLE_FMT_S32P;
1130
1131     if (wc->ch_offset && avctx->sample_fmt != sample_fmt)
1132         return AVERROR_INVALIDDATA;
1133
1134     bpp            = av_get_bytes_per_sample(sample_fmt);
1135     orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
1136     multiblock     = (s->frame_flags & WV_SINGLE_BLOCK) != WV_SINGLE_BLOCK;
1137
1138     s->stereo         = !(s->frame_flags & WV_MONO);
1139     s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
1140     s->joint          =   s->frame_flags & WV_JOINT_STEREO;
1141     s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
1142     s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
1143     s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
1144     if (s->post_shift < 0 || s->post_shift > 31) {
1145         return AVERROR_INVALIDDATA;
1146     }
1147     s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
1148     s->hybrid_minclip = ((-1UL << (orig_bpp - 1)));
1149     s->CRC            = bytestream2_get_le32(&gb);
1150
1151     // parse metadata blocks
1152     while (bytestream2_get_bytes_left(&gb)) {
1153         id   = bytestream2_get_byte(&gb);
1154         size = bytestream2_get_byte(&gb);
1155         if (id & WP_IDF_LONG)
1156             size |= (bytestream2_get_le16u(&gb)) << 8;
1157         size <<= 1; // size is specified in words
1158         ssize  = size;
1159         if (id & WP_IDF_ODD)
1160             size--;
1161         if (size < 0) {
1162             av_log(avctx, AV_LOG_ERROR,
1163                    "Got incorrect block %02X with size %i\n", id, size);
1164             break;
1165         }
1166         if (bytestream2_get_bytes_left(&gb) < ssize) {
1167             av_log(avctx, AV_LOG_ERROR,
1168                    "Block size %i is out of bounds\n", size);
1169             break;
1170         }
1171         switch (id & WP_IDF_MASK) {
1172         case WP_ID_DECTERMS:
1173             if (size > MAX_TERMS) {
1174                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
1175                 s->terms = 0;
1176                 bytestream2_skip(&gb, ssize);
1177                 continue;
1178             }
1179             s->terms = size;
1180             for (i = 0; i < s->terms; i++) {
1181                 uint8_t val = bytestream2_get_byte(&gb);
1182                 s->decorr[s->terms - i - 1].value = (val & 0x1F) - 5;
1183                 s->decorr[s->terms - i - 1].delta =  val >> 5;
1184             }
1185             got_terms = 1;
1186             break;
1187         case WP_ID_DECWEIGHTS:
1188             if (!got_terms) {
1189                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
1190                 continue;
1191             }
1192             weights = size >> s->stereo_in;
1193             if (weights > MAX_TERMS || weights > s->terms) {
1194                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
1195                 bytestream2_skip(&gb, ssize);
1196                 continue;
1197             }
1198             for (i = 0; i < weights; i++) {
1199                 t = (int8_t)bytestream2_get_byte(&gb);
1200                 s->decorr[s->terms - i - 1].weightA = t * (1 << 3);
1201                 if (s->decorr[s->terms - i - 1].weightA > 0)
1202                     s->decorr[s->terms - i - 1].weightA +=
1203                         (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
1204                 if (s->stereo_in) {
1205                     t = (int8_t)bytestream2_get_byte(&gb);
1206                     s->decorr[s->terms - i - 1].weightB = t * (1 << 3);
1207                     if (s->decorr[s->terms - i - 1].weightB > 0)
1208                         s->decorr[s->terms - i - 1].weightB +=
1209                             (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
1210                 }
1211             }
1212             got_weights = 1;
1213             break;
1214         case WP_ID_DECSAMPLES:
1215             if (!got_terms) {
1216                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
1217                 continue;
1218             }
1219             t = 0;
1220             for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
1221                 if (s->decorr[i].value > 8) {
1222                     s->decorr[i].samplesA[0] =
1223                         wp_exp2(bytestream2_get_le16(&gb));
1224                     s->decorr[i].samplesA[1] =
1225                         wp_exp2(bytestream2_get_le16(&gb));
1226
1227                     if (s->stereo_in) {
1228                         s->decorr[i].samplesB[0] =
1229                             wp_exp2(bytestream2_get_le16(&gb));
1230                         s->decorr[i].samplesB[1] =
1231                             wp_exp2(bytestream2_get_le16(&gb));
1232                         t                       += 4;
1233                     }
1234                     t += 4;
1235                 } else if (s->decorr[i].value < 0) {
1236                     s->decorr[i].samplesA[0] =
1237                         wp_exp2(bytestream2_get_le16(&gb));
1238                     s->decorr[i].samplesB[0] =
1239                         wp_exp2(bytestream2_get_le16(&gb));
1240                     t                       += 4;
1241                 } else {
1242                     for (j = 0; j < s->decorr[i].value; j++) {
1243                         s->decorr[i].samplesA[j] =
1244                             wp_exp2(bytestream2_get_le16(&gb));
1245                         if (s->stereo_in) {
1246                             s->decorr[i].samplesB[j] =
1247                                 wp_exp2(bytestream2_get_le16(&gb));
1248                         }
1249                     }
1250                     t += s->decorr[i].value * 2 * (s->stereo_in + 1);
1251                 }
1252             }
1253             got_samples = 1;
1254             break;
1255         case WP_ID_ENTROPY:
1256             if (size != 6 * (s->stereo_in + 1)) {
1257                 av_log(avctx, AV_LOG_ERROR,
1258                        "Entropy vars size should be %i, got %i.\n",
1259                        6 * (s->stereo_in + 1), size);
1260                 bytestream2_skip(&gb, ssize);
1261                 continue;
1262             }
1263             for (j = 0; j <= s->stereo_in; j++)
1264                 for (i = 0; i < 3; i++) {
1265                     s->ch[j].median[i] = wp_exp2(bytestream2_get_le16(&gb));
1266                 }
1267             got_entropy = 1;
1268             break;
1269         case WP_ID_HYBRID:
1270             if (s->hybrid_bitrate) {
1271                 for (i = 0; i <= s->stereo_in; i++) {
1272                     s->ch[i].slow_level = wp_exp2(bytestream2_get_le16(&gb));
1273                     size               -= 2;
1274                 }
1275             }
1276             for (i = 0; i < (s->stereo_in + 1); i++) {
1277                 s->ch[i].bitrate_acc = bytestream2_get_le16(&gb) << 16;
1278                 size                -= 2;
1279             }
1280             if (size > 0) {
1281                 for (i = 0; i < (s->stereo_in + 1); i++) {
1282                     s->ch[i].bitrate_delta =
1283                         wp_exp2((int16_t)bytestream2_get_le16(&gb));
1284                 }
1285             } else {
1286                 for (i = 0; i < (s->stereo_in + 1); i++)
1287                     s->ch[i].bitrate_delta = 0;
1288             }
1289             got_hybrid = 1;
1290             break;
1291         case WP_ID_INT32INFO: {
1292             uint8_t val[4];
1293             if (size != 4) {
1294                 av_log(avctx, AV_LOG_ERROR,
1295                        "Invalid INT32INFO, size = %i\n",
1296                        size);
1297                 bytestream2_skip(&gb, ssize - 4);
1298                 continue;
1299             }
1300             bytestream2_get_buffer(&gb, val, 4);
1301             if (val[0] > 30) {
1302                 av_log(avctx, AV_LOG_ERROR,
1303                        "Invalid INT32INFO, extra_bits = %d (> 30)\n", val[0]);
1304                 continue;
1305             } else if (val[0]) {
1306                 s->extra_bits = val[0];
1307             } else if (val[1]) {
1308                 s->shift = val[1];
1309             } else if (val[2]) {
1310                 s->and   = s->or = 1;
1311                 s->shift = val[2];
1312             } else if (val[3]) {
1313                 s->and   = 1;
1314                 s->shift = val[3];
1315             }
1316             if (s->shift > 31) {
1317                 av_log(avctx, AV_LOG_ERROR,
1318                        "Invalid INT32INFO, shift = %d (> 31)\n", s->shift);
1319                 s->and = s->or = s->shift = 0;
1320                 continue;
1321             }
1322             /* original WavPack decoder forces 32-bit lossy sound to be treated
1323              * as 24-bit one in order to have proper clipping */
1324             if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
1325                 s->post_shift      += 8;
1326                 s->shift           -= 8;
1327                 s->hybrid_maxclip >>= 8;
1328                 s->hybrid_minclip >>= 8;
1329             }
1330             break;
1331         }
1332         case WP_ID_FLOATINFO:
1333             if (size != 4) {
1334                 av_log(avctx, AV_LOG_ERROR,
1335                        "Invalid FLOATINFO, size = %i\n", size);
1336                 bytestream2_skip(&gb, ssize);
1337                 continue;
1338             }
1339             s->float_flag    = bytestream2_get_byte(&gb);
1340             s->float_shift   = bytestream2_get_byte(&gb);
1341             s->float_max_exp = bytestream2_get_byte(&gb);
1342             if (s->float_shift > 31) {
1343                 av_log(avctx, AV_LOG_ERROR,
1344                        "Invalid FLOATINFO, shift = %d (> 31)\n", s->float_shift);
1345                 s->float_shift = 0;
1346                 continue;
1347             }
1348             got_float        = 1;
1349             bytestream2_skip(&gb, 1);
1350             break;
1351         case WP_ID_DATA:
1352             if ((ret = init_get_bits8(&s->gb, gb.buffer, size)) < 0)
1353                 return ret;
1354             bytestream2_skip(&gb, size);
1355             got_pcm      = 1;
1356             break;
1357         case WP_ID_DSD_DATA:
1358             if (size < 2) {
1359                 av_log(avctx, AV_LOG_ERROR, "Invalid DSD_DATA, size = %i\n",
1360                        size);
1361                 bytestream2_skip(&gb, ssize);
1362                 continue;
1363             }
1364             rate_x = bytestream2_get_byte(&gb);
1365             if (rate_x > 30)
1366                 return AVERROR_INVALIDDATA;
1367             rate_x = 1 << rate_x;
1368             dsd_mode = bytestream2_get_byte(&gb);
1369             if (dsd_mode && dsd_mode != 1 && dsd_mode != 3) {
1370                 av_log(avctx, AV_LOG_ERROR, "Invalid DSD encoding mode: %d\n",
1371                     dsd_mode);
1372                 return AVERROR_INVALIDDATA;
1373             }
1374             bytestream2_init(&s->gbyte, gb.buffer, size-2);
1375             bytestream2_skip(&gb, size-2);
1376             got_dsd      = 1;
1377             break;
1378         case WP_ID_EXTRABITS:
1379             if (size <= 4) {
1380                 av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
1381                        size);
1382                 bytestream2_skip(&gb, size);
1383                 continue;
1384             }
1385             if ((ret = init_get_bits8(&s->gb_extra_bits, gb.buffer, size)) < 0)
1386                 return ret;
1387             s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
1388             bytestream2_skip(&gb, size);
1389             s->got_extra_bits  = 1;
1390             break;
1391         case WP_ID_CHANINFO:
1392             if (size <= 1) {
1393                 av_log(avctx, AV_LOG_ERROR,
1394                        "Insufficient channel information\n");
1395                 return AVERROR_INVALIDDATA;
1396             }
1397             chan = bytestream2_get_byte(&gb);
1398             switch (size - 2) {
1399             case 0:
1400                 chmask = bytestream2_get_byte(&gb);
1401                 break;
1402             case 1:
1403                 chmask = bytestream2_get_le16(&gb);
1404                 break;
1405             case 2:
1406                 chmask = bytestream2_get_le24(&gb);
1407                 break;
1408             case 3:
1409                 chmask = bytestream2_get_le32(&gb);
1410                 break;
1411             case 4:
1412                 size = bytestream2_get_byte(&gb);
1413                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1414                 chan  += 1;
1415                 if (avctx->channels != chan)
1416                     av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
1417                            " instead of %i.\n", chan, avctx->channels);
1418                 chmask = bytestream2_get_le24(&gb);
1419                 break;
1420             case 5:
1421                 size = bytestream2_get_byte(&gb);
1422                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1423                 chan  += 1;
1424                 if (avctx->channels != chan)
1425                     av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
1426                            " instead of %i.\n", chan, avctx->channels);
1427                 chmask = bytestream2_get_le32(&gb);
1428                 break;
1429             default:
1430                 av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
1431                        size);
1432                 chan   = avctx->channels;
1433                 chmask = avctx->channel_layout;
1434             }
1435             break;
1436         case WP_ID_SAMPLE_RATE:
1437             if (size != 3) {
1438                 av_log(avctx, AV_LOG_ERROR, "Invalid custom sample rate.\n");
1439                 return AVERROR_INVALIDDATA;
1440             }
1441             sample_rate = bytestream2_get_le24(&gb);
1442             break;
1443         default:
1444             bytestream2_skip(&gb, size);
1445         }
1446         if (id & WP_IDF_ODD)
1447             bytestream2_skip(&gb, 1);
1448     }
1449
1450     if (got_pcm) {
1451         if (!got_terms) {
1452             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
1453             return AVERROR_INVALIDDATA;
1454         }
1455         if (!got_weights) {
1456             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
1457             return AVERROR_INVALIDDATA;
1458         }
1459         if (!got_samples) {
1460             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
1461             return AVERROR_INVALIDDATA;
1462         }
1463         if (!got_entropy) {
1464             av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
1465             return AVERROR_INVALIDDATA;
1466         }
1467         if (s->hybrid && !got_hybrid) {
1468             av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
1469             return AVERROR_INVALIDDATA;
1470         }
1471         if (!got_float && sample_fmt == AV_SAMPLE_FMT_FLTP) {
1472             av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
1473             return AVERROR_INVALIDDATA;
1474         }
1475         if (s->got_extra_bits && sample_fmt != AV_SAMPLE_FMT_FLTP) {
1476             const int size   = get_bits_left(&s->gb_extra_bits);
1477             const int wanted = s->samples * s->extra_bits << s->stereo_in;
1478             if (size < wanted) {
1479                 av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
1480                 s->got_extra_bits = 0;
1481             }
1482         }
1483     }
1484
1485     if (!got_pcm && !got_dsd) {
1486         av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
1487         return AVERROR_INVALIDDATA;
1488     }
1489
1490     if ((got_pcm && wc->modulation != MODULATION_PCM) ||
1491         (got_dsd && wc->modulation != MODULATION_DSD)) {
1492             av_log(avctx, AV_LOG_ERROR, "Invalid PCM/DSD mix encountered\n");
1493             return AVERROR_INVALIDDATA;
1494     }
1495
1496     if (!wc->ch_offset) {
1497         int      new_channels = avctx->channels;
1498         uint64_t new_chmask   = avctx->channel_layout;
1499         int new_samplerate;
1500         int sr = (s->frame_flags >> 23) & 0xf;
1501         if (sr == 0xf) {
1502             if (!sample_rate) {
1503                 av_log(avctx, AV_LOG_ERROR, "Custom sample rate missing.\n");
1504                 return AVERROR_INVALIDDATA;
1505             }
1506             new_samplerate = sample_rate;
1507         } else
1508             new_samplerate = wv_rates[sr];
1509
1510         if (new_samplerate * (uint64_t)rate_x > INT_MAX)
1511             return AVERROR_INVALIDDATA;
1512         new_samplerate *= rate_x;
1513
1514         if (multiblock) {
1515             if (chan)
1516                 new_channels = chan;
1517             if (chmask)
1518                 new_chmask = chmask;
1519         } else {
1520             new_channels = s->stereo ? 2 : 1;
1521             new_chmask   = s->stereo ? AV_CH_LAYOUT_STEREO :
1522                                        AV_CH_LAYOUT_MONO;
1523         }
1524
1525         if (new_chmask &&
1526             av_get_channel_layout_nb_channels(new_chmask) != new_channels) {
1527             av_log(avctx, AV_LOG_ERROR, "Channel mask does not match the channel count\n");
1528             return AVERROR_INVALIDDATA;
1529         }
1530
1531         /* clear DSD state if stream properties change */
1532         if (new_channels   != wc->dsd_channels      ||
1533             new_chmask     != avctx->channel_layout ||
1534             new_samplerate != avctx->sample_rate    ||
1535             !!got_dsd      != !!wc->dsdctx) {
1536             ret = wv_dsd_reset(wc, got_dsd ? new_channels : 0);
1537             if (ret < 0) {
1538                 av_log(avctx, AV_LOG_ERROR, "Error reinitializing the DSD context\n");
1539                 return ret;
1540             }
1541             ff_thread_release_buffer(avctx, &wc->curr_frame);
1542         }
1543         avctx->channels            = new_channels;
1544         avctx->channel_layout      = new_chmask;
1545         avctx->sample_rate         = new_samplerate;
1546         avctx->sample_fmt          = sample_fmt;
1547         avctx->bits_per_raw_sample = orig_bpp;
1548
1549         ff_thread_release_buffer(avctx, &wc->prev_frame);
1550         FFSWAP(ThreadFrame, wc->curr_frame, wc->prev_frame);
1551
1552         /* get output buffer */
1553         wc->curr_frame.f->nb_samples = s->samples;
1554         if ((ret = ff_thread_get_buffer(avctx, &wc->curr_frame, AV_GET_BUFFER_FLAG_REF)) < 0)
1555             return ret;
1556
1557         wc->frame = wc->curr_frame.f;
1558         ff_thread_finish_setup(avctx);
1559     }
1560
1561     if (wc->ch_offset + s->stereo >= avctx->channels) {
1562         av_log(avctx, AV_LOG_WARNING, "Too many channels coded in a packet.\n");
1563         return ((avctx->err_recognition & AV_EF_EXPLODE) || !wc->ch_offset) ? AVERROR_INVALIDDATA : 0;
1564     }
1565
1566     samples_l = wc->frame->extended_data[wc->ch_offset];
1567     if (s->stereo)
1568         samples_r = wc->frame->extended_data[wc->ch_offset + 1];
1569
1570     wc->ch_offset += 1 + s->stereo;
1571
1572     if (s->stereo_in) {
1573         if (got_dsd) {
1574             if (dsd_mode == 3) {
1575                 ret = wv_unpack_dsd_high(s, samples_l, samples_r);
1576             } else if (dsd_mode == 1) {
1577                 ret = wv_unpack_dsd_fast(s, samples_l, samples_r);
1578             } else {
1579                 ret = wv_unpack_dsd_copy(s, samples_l, samples_r);
1580             }
1581         } else {
1582             ret = wv_unpack_stereo(s, &s->gb, samples_l, samples_r, avctx->sample_fmt);
1583         }
1584         if (ret < 0)
1585             return ret;
1586     } else {
1587         if (got_dsd) {
1588             if (dsd_mode == 3) {
1589                 ret = wv_unpack_dsd_high(s, samples_l, NULL);
1590             } else if (dsd_mode == 1) {
1591                 ret = wv_unpack_dsd_fast(s, samples_l, NULL);
1592             } else {
1593                 ret = wv_unpack_dsd_copy(s, samples_l, NULL);
1594             }
1595         } else {
1596             ret = wv_unpack_mono(s, &s->gb, samples_l, avctx->sample_fmt);
1597         }
1598         if (ret < 0)
1599             return ret;
1600
1601         if (s->stereo)
1602             memcpy(samples_r, samples_l, bpp * s->samples);
1603     }
1604
1605     return 0;
1606 }
1607
1608 static void wavpack_decode_flush(AVCodecContext *avctx)
1609 {
1610     WavpackContext *s = avctx->priv_data;
1611
1612     wv_dsd_reset(s, 0);
1613 }
1614
1615 static int dsd_channel(AVCodecContext *avctx, void *frmptr, int jobnr, int threadnr)
1616 {
1617     WavpackContext *s  = avctx->priv_data;
1618     AVFrame *frame = frmptr;
1619
1620     ff_dsd2pcm_translate (&s->dsdctx [jobnr], s->samples, 0,
1621         (uint8_t *)frame->extended_data[jobnr], 4,
1622         (float *)frame->extended_data[jobnr], 1);
1623
1624     return 0;
1625 }
1626
1627 static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
1628                                 int *got_frame_ptr, AVPacket *avpkt)
1629 {
1630     WavpackContext *s  = avctx->priv_data;
1631     const uint8_t *buf = avpkt->data;
1632     int buf_size       = avpkt->size;
1633     int frame_size, ret, frame_flags;
1634
1635     if (avpkt->size <= WV_HEADER_SIZE)
1636         return AVERROR_INVALIDDATA;
1637
1638     s->frame     = NULL;
1639     s->block     = 0;
1640     s->ch_offset = 0;
1641
1642     /* determine number of samples */
1643     s->samples  = AV_RL32(buf + 20);
1644     frame_flags = AV_RL32(buf + 24);
1645     if (s->samples <= 0 || s->samples > WV_MAX_SAMPLES) {
1646         av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1647                s->samples);
1648         return AVERROR_INVALIDDATA;
1649     }
1650
1651     s->modulation = (frame_flags & WV_DSD_DATA) ? MODULATION_DSD : MODULATION_PCM;
1652
1653     while (buf_size > WV_HEADER_SIZE) {
1654         frame_size = AV_RL32(buf + 4) - 12;
1655         buf       += 20;
1656         buf_size  -= 20;
1657         if (frame_size <= 0 || frame_size > buf_size) {
1658             av_log(avctx, AV_LOG_ERROR,
1659                    "Block %d has invalid size (size %d vs. %d bytes left)\n",
1660                    s->block, frame_size, buf_size);
1661             ret = AVERROR_INVALIDDATA;
1662             goto error;
1663         }
1664         if ((ret = wavpack_decode_block(avctx, s->block, buf, frame_size)) < 0)
1665             goto error;
1666         s->block++;
1667         buf      += frame_size;
1668         buf_size -= frame_size;
1669     }
1670
1671     if (s->ch_offset != avctx->channels) {
1672         av_log(avctx, AV_LOG_ERROR, "Not enough channels coded in a packet.\n");
1673         ret = AVERROR_INVALIDDATA;
1674         goto error;
1675     }
1676
1677     ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1678     ff_thread_release_buffer(avctx, &s->prev_frame);
1679
1680     if (s->modulation == MODULATION_DSD)
1681         avctx->execute2(avctx, dsd_channel, s->frame, NULL, avctx->channels);
1682
1683     ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1684
1685     if ((ret = av_frame_ref(data, s->frame)) < 0)
1686         return ret;
1687
1688     *got_frame_ptr = 1;
1689
1690     return avpkt->size;
1691
1692 error:
1693     if (s->frame) {
1694         ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1695         ff_thread_release_buffer(avctx, &s->prev_frame);
1696         ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1697     }
1698
1699     return ret;
1700 }
1701
1702 const AVCodec ff_wavpack_decoder = {
1703     .name           = "wavpack",
1704     .long_name      = NULL_IF_CONFIG_SMALL("WavPack"),
1705     .type           = AVMEDIA_TYPE_AUDIO,
1706     .id             = AV_CODEC_ID_WAVPACK,
1707     .priv_data_size = sizeof(WavpackContext),
1708     .init           = wavpack_decode_init,
1709     .close          = wavpack_decode_end,
1710     .decode         = wavpack_decode_frame,
1711     .flush          = wavpack_decode_flush,
1712     .update_thread_context = ONLY_IF_THREADS_ENABLED(update_thread_context),
1713     .capabilities   = AV_CODEC_CAP_DR1 | AV_CODEC_CAP_FRAME_THREADS |
1714                       AV_CODEC_CAP_SLICE_THREADS | AV_CODEC_CAP_CHANNEL_CONF,
1715     .caps_internal  = FF_CODEC_CAP_INIT_THREADSAFE | FF_CODEC_CAP_INIT_CLEANUP |
1716                       FF_CODEC_CAP_ALLOCATE_PROGRESS,
1717 };