]> git.sesse.net Git - ffmpeg/blob - libavcodec/wavpack.c
avcodec/wavpack: check for allocation failure
[ffmpeg] / libavcodec / wavpack.c
1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  * Copyright (c) 2020 David Bryant
5  *
6  * This file is part of FFmpeg.
7  *
8  * FFmpeg is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * FFmpeg is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with FFmpeg; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21  */
22
23 #include "libavutil/channel_layout.h"
24
25 #define BITSTREAM_READER_LE
26 #include "avcodec.h"
27 #include "bytestream.h"
28 #include "get_bits.h"
29 #include "internal.h"
30 #include "thread.h"
31 #include "unary.h"
32 #include "wavpack.h"
33 #include "dsd.h"
34
35 /**
36  * @file
37  * WavPack lossless audio decoder
38  */
39
40 #define DSD_BYTE_READY(low,high) (!(((low) ^ (high)) & 0xff000000))
41
42 #define PTABLE_BITS 8
43 #define PTABLE_BINS (1<<PTABLE_BITS)
44 #define PTABLE_MASK (PTABLE_BINS-1)
45
46 #define UP   0x010000fe
47 #define DOWN 0x00010000
48 #define DECAY 8
49
50 #define PRECISION 20
51 #define VALUE_ONE (1 << PRECISION)
52 #define PRECISION_USE 12
53
54 #define RATE_S 20
55
56 #define MAX_HISTORY_BITS    5
57 #define MAX_HISTORY_BINS    (1 << MAX_HISTORY_BITS)
58 #define MAX_BIN_BYTES       1280    // for value_lookup, per bin (2k - 512 - 256)
59
60 typedef enum {
61     MODULATION_PCM,     // pulse code modulation
62     MODULATION_DSD      // pulse density modulation (aka DSD)
63 } Modulation;
64
65 typedef struct WavpackFrameContext {
66     AVCodecContext *avctx;
67     int frame_flags;
68     int stereo, stereo_in;
69     int joint;
70     uint32_t CRC;
71     GetBitContext gb;
72     int got_extra_bits;
73     uint32_t crc_extra_bits;
74     GetBitContext gb_extra_bits;
75     int samples;
76     int terms;
77     Decorr decorr[MAX_TERMS];
78     int zero, one, zeroes;
79     int extra_bits;
80     int and, or, shift;
81     int post_shift;
82     int hybrid, hybrid_bitrate;
83     int hybrid_maxclip, hybrid_minclip;
84     int float_flag;
85     int float_shift;
86     int float_max_exp;
87     WvChannel ch[2];
88
89     GetByteContext gbyte;
90     int ptable [PTABLE_BINS];
91     uint8_t value_lookup_buffer[MAX_HISTORY_BINS*MAX_BIN_BYTES];
92     uint16_t summed_probabilities[MAX_HISTORY_BINS][256];
93     uint8_t probabilities[MAX_HISTORY_BINS][256];
94     uint8_t *value_lookup[MAX_HISTORY_BINS];
95 } WavpackFrameContext;
96
97 #define WV_MAX_FRAME_DECODERS 14
98
99 typedef struct WavpackContext {
100     AVCodecContext *avctx;
101
102     WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
103     int fdec_num;
104
105     int block;
106     int samples;
107     int ch_offset;
108
109     AVFrame *frame;
110     ThreadFrame curr_frame, prev_frame;
111     Modulation modulation;
112     DSDContext *dsdctx;
113 } WavpackContext;
114
115 #define LEVEL_DECAY(a)  (((a) + 0x80) >> 8)
116
117 static av_always_inline unsigned get_tail(GetBitContext *gb, int k)
118 {
119     int p, e, res;
120
121     if (k < 1)
122         return 0;
123     p   = av_log2(k);
124     e   = (1 << (p + 1)) - k - 1;
125     res = get_bitsz(gb, p);
126     if (res >= e)
127         res = (res << 1) - e + get_bits1(gb);
128     return res;
129 }
130
131 static int update_error_limit(WavpackFrameContext *ctx)
132 {
133     int i, br[2], sl[2];
134
135     for (i = 0; i <= ctx->stereo_in; i++) {
136         if (ctx->ch[i].bitrate_acc > UINT_MAX - ctx->ch[i].bitrate_delta)
137             return AVERROR_INVALIDDATA;
138         ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
139         br[i]                   = ctx->ch[i].bitrate_acc >> 16;
140         sl[i]                   = LEVEL_DECAY(ctx->ch[i].slow_level);
141     }
142     if (ctx->stereo_in && ctx->hybrid_bitrate) {
143         int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
144         if (balance > br[0]) {
145             br[1] = br[0] * 2;
146             br[0] = 0;
147         } else if (-balance > br[0]) {
148             br[0]  *= 2;
149             br[1]   = 0;
150         } else {
151             br[1] = br[0] + balance;
152             br[0] = br[0] - balance;
153         }
154     }
155     for (i = 0; i <= ctx->stereo_in; i++) {
156         if (ctx->hybrid_bitrate) {
157             if (sl[i] - br[i] > -0x100)
158                 ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
159             else
160                 ctx->ch[i].error_limit = 0;
161         } else {
162             ctx->ch[i].error_limit = wp_exp2(br[i]);
163         }
164     }
165
166     return 0;
167 }
168
169 static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
170                         int channel, int *last)
171 {
172     int t, t2;
173     int sign, base, add, ret;
174     WvChannel *c = &ctx->ch[channel];
175
176     *last = 0;
177
178     if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
179         !ctx->zero && !ctx->one) {
180         if (ctx->zeroes) {
181             ctx->zeroes--;
182             if (ctx->zeroes) {
183                 c->slow_level -= LEVEL_DECAY(c->slow_level);
184                 return 0;
185             }
186         } else {
187             t = get_unary_0_33(gb);
188             if (t >= 2) {
189                 if (t >= 32 || get_bits_left(gb) < t - 1)
190                     goto error;
191                 t = get_bits_long(gb, t - 1) | (1 << (t - 1));
192             } else {
193                 if (get_bits_left(gb) < 0)
194                     goto error;
195             }
196             ctx->zeroes = t;
197             if (ctx->zeroes) {
198                 memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
199                 memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
200                 c->slow_level -= LEVEL_DECAY(c->slow_level);
201                 return 0;
202             }
203         }
204     }
205
206     if (ctx->zero) {
207         t         = 0;
208         ctx->zero = 0;
209     } else {
210         t = get_unary_0_33(gb);
211         if (get_bits_left(gb) < 0)
212             goto error;
213         if (t == 16) {
214             t2 = get_unary_0_33(gb);
215             if (t2 < 2) {
216                 if (get_bits_left(gb) < 0)
217                     goto error;
218                 t += t2;
219             } else {
220                 if (t2 >= 32 || get_bits_left(gb) < t2 - 1)
221                     goto error;
222                 t += get_bits_long(gb, t2 - 1) | (1 << (t2 - 1));
223             }
224         }
225
226         if (ctx->one) {
227             ctx->one = t & 1;
228             t        = (t >> 1) + 1;
229         } else {
230             ctx->one = t & 1;
231             t      >>= 1;
232         }
233         ctx->zero = !ctx->one;
234     }
235
236     if (ctx->hybrid && !channel) {
237         if (update_error_limit(ctx) < 0)
238             goto error;
239     }
240
241     if (!t) {
242         base = 0;
243         add  = GET_MED(0) - 1;
244         DEC_MED(0);
245     } else if (t == 1) {
246         base = GET_MED(0);
247         add  = GET_MED(1) - 1;
248         INC_MED(0);
249         DEC_MED(1);
250     } else if (t == 2) {
251         base = GET_MED(0) + GET_MED(1);
252         add  = GET_MED(2) - 1;
253         INC_MED(0);
254         INC_MED(1);
255         DEC_MED(2);
256     } else {
257         base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2U);
258         add  = GET_MED(2) - 1;
259         INC_MED(0);
260         INC_MED(1);
261         INC_MED(2);
262     }
263     if (!c->error_limit) {
264         if (add >= 0x2000000U) {
265             av_log(ctx->avctx, AV_LOG_ERROR, "k %d is too large\n", add);
266             goto error;
267         }
268         ret = base + get_tail(gb, add);
269         if (get_bits_left(gb) <= 0)
270             goto error;
271     } else {
272         int mid = (base * 2U + add + 1) >> 1;
273         while (add > c->error_limit) {
274             if (get_bits_left(gb) <= 0)
275                 goto error;
276             if (get_bits1(gb)) {
277                 add -= (mid - (unsigned)base);
278                 base = mid;
279             } else
280                 add = mid - (unsigned)base - 1;
281             mid = (base * 2U + add + 1) >> 1;
282         }
283         ret = mid;
284     }
285     sign = get_bits1(gb);
286     if (ctx->hybrid_bitrate)
287         c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
288     return sign ? ~ret : ret;
289
290 error:
291     ret = get_bits_left(gb);
292     if (ret <= 0) {
293         av_log(ctx->avctx, AV_LOG_ERROR, "Too few bits (%d) left\n", ret);
294     }
295     *last = 1;
296     return 0;
297 }
298
299 static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
300                                        unsigned S)
301 {
302     unsigned bit;
303
304     if (s->extra_bits) {
305         S *= 1 << s->extra_bits;
306
307         if (s->got_extra_bits &&
308             get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
309             S   |= get_bits_long(&s->gb_extra_bits, s->extra_bits);
310             *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
311         }
312     }
313
314     bit = (S & s->and) | s->or;
315     bit = ((S + bit) << s->shift) - bit;
316
317     if (s->hybrid)
318         bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
319
320     return bit << s->post_shift;
321 }
322
323 static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
324 {
325     union {
326         float    f;
327         uint32_t u;
328     } value;
329
330     unsigned int sign;
331     int exp = s->float_max_exp;
332
333     if (s->got_extra_bits) {
334         const int max_bits  = 1 + 23 + 8 + 1;
335         const int left_bits = get_bits_left(&s->gb_extra_bits);
336
337         if (left_bits + 8 * AV_INPUT_BUFFER_PADDING_SIZE < max_bits)
338             return 0.0;
339     }
340
341     if (S) {
342         S  *= 1U << s->float_shift;
343         sign = S < 0;
344         if (sign)
345             S = -(unsigned)S;
346         if (S >= 0x1000000U) {
347             if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
348                 S = get_bits(&s->gb_extra_bits, 23);
349             else
350                 S = 0;
351             exp = 255;
352         } else if (exp) {
353             int shift = 23 - av_log2(S);
354             exp = s->float_max_exp;
355             if (exp <= shift)
356                 shift = --exp;
357             exp -= shift;
358
359             if (shift) {
360                 S <<= shift;
361                 if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
362                     (s->got_extra_bits &&
363                      (s->float_flag & WV_FLT_SHIFT_SAME) &&
364                      get_bits1(&s->gb_extra_bits))) {
365                     S |= (1 << shift) - 1;
366                 } else if (s->got_extra_bits &&
367                            (s->float_flag & WV_FLT_SHIFT_SENT)) {
368                     S |= get_bits(&s->gb_extra_bits, shift);
369                 }
370             }
371         } else {
372             exp = s->float_max_exp;
373         }
374         S &= 0x7fffff;
375     } else {
376         sign = 0;
377         exp  = 0;
378         if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
379             if (get_bits1(&s->gb_extra_bits)) {
380                 S = get_bits(&s->gb_extra_bits, 23);
381                 if (s->float_max_exp >= 25)
382                     exp = get_bits(&s->gb_extra_bits, 8);
383                 sign = get_bits1(&s->gb_extra_bits);
384             } else {
385                 if (s->float_flag & WV_FLT_ZERO_SIGN)
386                     sign = get_bits1(&s->gb_extra_bits);
387             }
388         }
389     }
390
391     *crc = *crc * 27 + S * 9 + exp * 3 + sign;
392
393     value.u = (sign << 31) | (exp << 23) | S;
394     return value.f;
395 }
396
397 static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
398                                uint32_t crc_extra_bits)
399 {
400     if (crc != s->CRC) {
401         av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
402         return AVERROR_INVALIDDATA;
403     }
404     if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
405         av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
406         return AVERROR_INVALIDDATA;
407     }
408
409     return 0;
410 }
411
412 static void init_ptable(int *table, int rate_i, int rate_s)
413 {
414     int value = 0x808000, rate = rate_i << 8;
415
416     for (int c = (rate + 128) >> 8; c--;)
417         value += (DOWN - value) >> DECAY;
418
419     for (int i = 0; i < PTABLE_BINS/2; i++) {
420         table[i] = value;
421         table[PTABLE_BINS-1-i] = 0x100ffff - value;
422
423         if (value > 0x010000) {
424             rate += (rate * rate_s + 128) >> 8;
425
426             for (int c = (rate + 64) >> 7; c--;)
427                 value += (DOWN - value) >> DECAY;
428         }
429     }
430 }
431
432 typedef struct {
433     int32_t value, fltr0, fltr1, fltr2, fltr3, fltr4, fltr5, fltr6, factor;
434     unsigned int byte;
435 } DSDfilters;
436
437 static int wv_unpack_dsd_high(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
438 {
439     uint32_t checksum = 0xFFFFFFFF;
440     uint8_t *dst_l = dst_left, *dst_r = dst_right;
441     int total_samples = s->samples, stereo = dst_r ? 1 : 0;
442     DSDfilters filters[2], *sp = filters;
443     int rate_i, rate_s;
444     uint32_t low, high, value;
445
446     if (bytestream2_get_bytes_left(&s->gbyte) < (stereo ? 20 : 13))
447         return AVERROR_INVALIDDATA;
448
449     rate_i = bytestream2_get_byte(&s->gbyte);
450     rate_s = bytestream2_get_byte(&s->gbyte);
451
452     if (rate_s != RATE_S)
453         return AVERROR_INVALIDDATA;
454
455     init_ptable(s->ptable, rate_i, rate_s);
456
457     for (int channel = 0; channel < stereo + 1; channel++) {
458         DSDfilters *sp = filters + channel;
459
460         sp->fltr1 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
461         sp->fltr2 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
462         sp->fltr3 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
463         sp->fltr4 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
464         sp->fltr5 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
465         sp->fltr6 = 0;
466         sp->factor = bytestream2_get_byte(&s->gbyte) & 0xff;
467         sp->factor |= (bytestream2_get_byte(&s->gbyte) << 8) & 0xff00;
468         sp->factor = (int32_t)((uint32_t)sp->factor << 16) >> 16;
469     }
470
471     value = bytestream2_get_be32(&s->gbyte);
472     high = 0xffffffff;
473     low = 0x0;
474
475     while (total_samples--) {
476         int bitcount = 8;
477
478         sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
479
480         if (stereo)
481             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
482
483         while (bitcount--) {
484             int32_t *pp = s->ptable + ((sp[0].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
485             uint32_t split = low + ((high - low) >> 8) * (*pp >> 16);
486
487             if (value <= split) {
488                 high = split;
489                 *pp += (UP - *pp) >> DECAY;
490                 sp[0].fltr0 = -1;
491             } else {
492                 low = split + 1;
493                 *pp += (DOWN - *pp) >> DECAY;
494                 sp[0].fltr0 = 0;
495             }
496
497             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
498                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
499                 high = (high << 8) | 0xff;
500                 low <<= 8;
501             }
502
503             sp[0].value += sp[0].fltr6 * 8;
504             sp[0].byte = (sp[0].byte << 1) | (sp[0].fltr0 & 1);
505             sp[0].factor += (((sp[0].value ^ sp[0].fltr0) >> 31) | 1) &
506                 ((sp[0].value ^ (sp[0].value - (sp[0].fltr6 * 16))) >> 31);
507             sp[0].fltr1 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr1) >> 6;
508             sp[0].fltr2 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr2) >> 4;
509             sp[0].fltr3 += (sp[0].fltr2 - sp[0].fltr3) >> 4;
510             sp[0].fltr4 += (sp[0].fltr3 - sp[0].fltr4) >> 4;
511             sp[0].value = (sp[0].fltr4 - sp[0].fltr5) >> 4;
512             sp[0].fltr5 += sp[0].value;
513             sp[0].fltr6 += (sp[0].value - sp[0].fltr6) >> 3;
514             sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
515
516             if (!stereo)
517                 continue;
518
519             pp = s->ptable + ((sp[1].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
520             split = low + ((high - low) >> 8) * (*pp >> 16);
521
522             if (value <= split) {
523                 high = split;
524                 *pp += (UP - *pp) >> DECAY;
525                 sp[1].fltr0 = -1;
526             } else {
527                 low = split + 1;
528                 *pp += (DOWN - *pp) >> DECAY;
529                 sp[1].fltr0 = 0;
530             }
531
532             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
533                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
534                 high = (high << 8) | 0xff;
535                 low <<= 8;
536             }
537
538             sp[1].value += sp[1].fltr6 * 8;
539             sp[1].byte = (sp[1].byte << 1) | (sp[1].fltr0 & 1);
540             sp[1].factor += (((sp[1].value ^ sp[1].fltr0) >> 31) | 1) &
541                 ((sp[1].value ^ (sp[1].value - (sp[1].fltr6 * 16))) >> 31);
542             sp[1].fltr1 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr1) >> 6;
543             sp[1].fltr2 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr2) >> 4;
544             sp[1].fltr3 += (sp[1].fltr2 - sp[1].fltr3) >> 4;
545             sp[1].fltr4 += (sp[1].fltr3 - sp[1].fltr4) >> 4;
546             sp[1].value = (sp[1].fltr4 - sp[1].fltr5) >> 4;
547             sp[1].fltr5 += sp[1].value;
548             sp[1].fltr6 += (sp[1].value - sp[1].fltr6) >> 3;
549             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
550         }
551
552         checksum += (checksum << 1) + (*dst_l = sp[0].byte & 0xff);
553         sp[0].factor -= (sp[0].factor + 512) >> 10;
554         dst_l += 4;
555
556         if (stereo) {
557             checksum += (checksum << 1) + (*dst_r = filters[1].byte & 0xff);
558             filters[1].factor -= (filters[1].factor + 512) >> 10;
559             dst_r += 4;
560         }
561     }
562
563     if (wv_check_crc(s, checksum, 0)) {
564         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
565             return AVERROR_INVALIDDATA;
566
567         memset(dst_left, 0x69, s->samples * 4);
568
569         if (dst_r)
570             memset(dst_right, 0x69, s->samples * 4);
571     }
572
573     return 0;
574 }
575
576 static int wv_unpack_dsd_fast(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
577 {
578     uint8_t *dst_l = dst_left, *dst_r = dst_right;
579     uint8_t history_bits, max_probability;
580     int total_summed_probabilities  = 0;
581     int total_samples               = s->samples;
582     uint8_t *vlb                    = s->value_lookup_buffer;
583     int history_bins, p0, p1, chan;
584     uint32_t checksum               = 0xFFFFFFFF;
585     uint32_t low, high, value;
586
587     if (!bytestream2_get_bytes_left(&s->gbyte))
588         return AVERROR_INVALIDDATA;
589
590     history_bits = bytestream2_get_byte(&s->gbyte);
591
592     if (!bytestream2_get_bytes_left(&s->gbyte) || history_bits > MAX_HISTORY_BITS)
593         return AVERROR_INVALIDDATA;
594
595     history_bins = 1 << history_bits;
596     max_probability = bytestream2_get_byte(&s->gbyte);
597
598     if (max_probability < 0xff) {
599         uint8_t *outptr = (uint8_t *)s->probabilities;
600         uint8_t *outend = outptr + sizeof(*s->probabilities) * history_bins;
601
602         while (outptr < outend && bytestream2_get_bytes_left(&s->gbyte)) {
603             int code = bytestream2_get_byte(&s->gbyte);
604
605             if (code > max_probability) {
606                 int zcount = code - max_probability;
607
608                 while (outptr < outend && zcount--)
609                     *outptr++ = 0;
610             } else if (code) {
611                 *outptr++ = code;
612             }
613             else {
614                 break;
615             }
616         }
617
618         if (outptr < outend ||
619             (bytestream2_get_bytes_left(&s->gbyte) && bytestream2_get_byte(&s->gbyte)))
620                 return AVERROR_INVALIDDATA;
621     } else if (bytestream2_get_bytes_left(&s->gbyte) > (int)sizeof(*s->probabilities) * history_bins) {
622         bytestream2_get_buffer(&s->gbyte, (uint8_t *)s->probabilities,
623             sizeof(*s->probabilities) * history_bins);
624     } else {
625         return AVERROR_INVALIDDATA;
626     }
627
628     for (p0 = 0; p0 < history_bins; p0++) {
629         int32_t sum_values = 0;
630
631         for (int i = 0; i < 256; i++)
632             s->summed_probabilities[p0][i] = sum_values += s->probabilities[p0][i];
633
634         if (sum_values) {
635             total_summed_probabilities += sum_values;
636
637             if (total_summed_probabilities > history_bins * MAX_BIN_BYTES)
638                 return AVERROR_INVALIDDATA;
639
640             s->value_lookup[p0] = vlb;
641
642             for (int i = 0; i < 256; i++) {
643                 int c = s->probabilities[p0][i];
644
645                 while (c--)
646                     *vlb++ = i;
647             }
648         }
649     }
650
651     if (bytestream2_get_bytes_left(&s->gbyte) < 4)
652         return AVERROR_INVALIDDATA;
653
654     chan = p0 = p1 = 0;
655     low = 0; high = 0xffffffff;
656     value = bytestream2_get_be32(&s->gbyte);
657
658     if (dst_r)
659         total_samples *= 2;
660
661     while (total_samples--) {
662         unsigned int mult, index, code;
663
664         if (!s->summed_probabilities[p0][255])
665             return AVERROR_INVALIDDATA;
666
667         mult = (high - low) / s->summed_probabilities[p0][255];
668
669         if (!mult) {
670             if (bytestream2_get_bytes_left(&s->gbyte) >= 4)
671                 value = bytestream2_get_be32(&s->gbyte);
672
673             low = 0;
674             high = 0xffffffff;
675             mult = high / s->summed_probabilities[p0][255];
676
677             if (!mult)
678                 return AVERROR_INVALIDDATA;
679         }
680
681         index = (value - low) / mult;
682
683         if (index >= s->summed_probabilities[p0][255])
684             return AVERROR_INVALIDDATA;
685
686         if (!dst_r) {
687             if ((*dst_l = code = s->value_lookup[p0][index]))
688                 low += s->summed_probabilities[p0][code-1] * mult;
689
690             dst_l += 4;
691         } else {
692             if ((code = s->value_lookup[p0][index]))
693                 low += s->summed_probabilities[p0][code-1] * mult;
694
695             if (chan) {
696                 *dst_r = code;
697                 dst_r += 4;
698             }
699             else {
700                 *dst_l = code;
701                 dst_l += 4;
702             }
703
704             chan ^= 1;
705         }
706
707         high = low + s->probabilities[p0][code] * mult - 1;
708         checksum += (checksum << 1) + code;
709
710         if (!dst_r) {
711             p0 = code & (history_bins-1);
712         } else {
713             p0 = p1;
714             p1 = code & (history_bins-1);
715         }
716
717         while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
718             value = (value << 8) | bytestream2_get_byte(&s->gbyte);
719             high = (high << 8) | 0xff;
720             low <<= 8;
721         }
722     }
723
724     if (wv_check_crc(s, checksum, 0)) {
725         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
726             return AVERROR_INVALIDDATA;
727
728         memset(dst_left, 0x69, s->samples * 4);
729
730         if (dst_r)
731             memset(dst_right, 0x69, s->samples * 4);
732     }
733
734     return 0;
735 }
736
737 static int wv_unpack_dsd_copy(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
738 {
739     uint8_t *dst_l = dst_left, *dst_r = dst_right;
740     int total_samples           = s->samples;
741     uint32_t checksum           = 0xFFFFFFFF;
742
743     if (bytestream2_get_bytes_left(&s->gbyte) != total_samples * (dst_r ? 2 : 1))
744         return AVERROR_INVALIDDATA;
745
746     while (total_samples--) {
747         checksum += (checksum << 1) + (*dst_l = bytestream2_get_byte(&s->gbyte));
748         dst_l += 4;
749
750         if (dst_r) {
751             checksum += (checksum << 1) + (*dst_r = bytestream2_get_byte(&s->gbyte));
752             dst_r += 4;
753         }
754     }
755
756     if (wv_check_crc(s, checksum, 0)) {
757         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
758             return AVERROR_INVALIDDATA;
759
760         memset(dst_left, 0x69, s->samples * 4);
761
762         if (dst_r)
763             memset(dst_right, 0x69, s->samples * 4);
764     }
765
766     return 0;
767 }
768
769 static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
770                                    void *dst_l, void *dst_r, const int type)
771 {
772     int i, j, count = 0;
773     int last, t;
774     int A, B, L, L2, R, R2;
775     int pos                 = 0;
776     uint32_t crc            = 0xFFFFFFFF;
777     uint32_t crc_extra_bits = 0xFFFFFFFF;
778     int16_t *dst16_l        = dst_l;
779     int16_t *dst16_r        = dst_r;
780     int32_t *dst32_l        = dst_l;
781     int32_t *dst32_r        = dst_r;
782     float *dstfl_l          = dst_l;
783     float *dstfl_r          = dst_r;
784
785     s->one = s->zero = s->zeroes = 0;
786     do {
787         L = wv_get_value(s, gb, 0, &last);
788         if (last)
789             break;
790         R = wv_get_value(s, gb, 1, &last);
791         if (last)
792             break;
793         for (i = 0; i < s->terms; i++) {
794             t = s->decorr[i].value;
795             if (t > 0) {
796                 if (t > 8) {
797                     if (t & 1) {
798                         A = 2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
799                         B = 2U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
800                     } else {
801                         A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
802                         B = (int)(3U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
803                     }
804                     s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
805                     s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
806                     j                        = 0;
807                 } else {
808                     A = s->decorr[i].samplesA[pos];
809                     B = s->decorr[i].samplesB[pos];
810                     j = (pos + t) & 7;
811                 }
812                 if (type != AV_SAMPLE_FMT_S16P) {
813                     L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
814                     R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
815                 } else {
816                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
817                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)B + 512) >> 10);
818                 }
819                 if (A && L)
820                     s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
821                 if (B && R)
822                     s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
823                 s->decorr[i].samplesA[j] = L = L2;
824                 s->decorr[i].samplesB[j] = R = R2;
825             } else if (t == -1) {
826                 if (type != AV_SAMPLE_FMT_S16P)
827                     L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
828                 else
829                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)s->decorr[i].samplesA[0] + 512) >> 10);
830                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
831                 L = L2;
832                 if (type != AV_SAMPLE_FMT_S16P)
833                     R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
834                 else
835                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)L2 + 512) >> 10);
836                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
837                 R                        = R2;
838                 s->decorr[i].samplesA[0] = R;
839             } else {
840                 if (type != AV_SAMPLE_FMT_S16P)
841                     R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
842                 else
843                     R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)s->decorr[i].samplesB[0] + 512) >> 10);
844                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
845                 R = R2;
846
847                 if (t == -3) {
848                     R2                       = s->decorr[i].samplesA[0];
849                     s->decorr[i].samplesA[0] = R;
850                 }
851
852                 if (type != AV_SAMPLE_FMT_S16P)
853                     L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
854                 else
855                     L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)R2 + 512) >> 10);
856                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
857                 L                        = L2;
858                 s->decorr[i].samplesB[0] = L;
859             }
860         }
861
862         if (type == AV_SAMPLE_FMT_S16P) {
863             if (FFABS((int64_t)L) + FFABS((int64_t)R) > (1<<19)) {
864                 av_log(s->avctx, AV_LOG_ERROR, "sample %d %d too large\n", L, R);
865                 return AVERROR_INVALIDDATA;
866             }
867         }
868
869         pos = (pos + 1) & 7;
870         if (s->joint)
871             L += (unsigned)(R -= (unsigned)(L >> 1));
872         crc = (crc * 3 + L) * 3 + R;
873
874         if (type == AV_SAMPLE_FMT_FLTP) {
875             *dstfl_l++ = wv_get_value_float(s, &crc_extra_bits, L);
876             *dstfl_r++ = wv_get_value_float(s, &crc_extra_bits, R);
877         } else if (type == AV_SAMPLE_FMT_S32P) {
878             *dst32_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
879             *dst32_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
880         } else {
881             *dst16_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
882             *dst16_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
883         }
884         count++;
885     } while (!last && count < s->samples);
886
887     if (last && count < s->samples) {
888         int size = av_get_bytes_per_sample(type);
889         memset((uint8_t*)dst_l + count*size, 0, (s->samples-count)*size);
890         memset((uint8_t*)dst_r + count*size, 0, (s->samples-count)*size);
891     }
892
893     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
894         wv_check_crc(s, crc, crc_extra_bits))
895         return AVERROR_INVALIDDATA;
896
897     return 0;
898 }
899
900 static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
901                                  void *dst, const int type)
902 {
903     int i, j, count = 0;
904     int last, t;
905     int A, S, T;
906     int pos                  = 0;
907     uint32_t crc             = 0xFFFFFFFF;
908     uint32_t crc_extra_bits  = 0xFFFFFFFF;
909     int16_t *dst16           = dst;
910     int32_t *dst32           = dst;
911     float *dstfl             = dst;
912
913     s->one = s->zero = s->zeroes = 0;
914     do {
915         T = wv_get_value(s, gb, 0, &last);
916         S = 0;
917         if (last)
918             break;
919         for (i = 0; i < s->terms; i++) {
920             t = s->decorr[i].value;
921             if (t > 8) {
922                 if (t & 1)
923                     A =  2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
924                 else
925                     A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
926                 s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
927                 j                        = 0;
928             } else {
929                 A = s->decorr[i].samplesA[pos];
930                 j = (pos + t) & 7;
931             }
932             if (type != AV_SAMPLE_FMT_S16P)
933                 S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
934             else
935                 S = T + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
936             if (A && T)
937                 s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
938             s->decorr[i].samplesA[j] = T = S;
939         }
940         pos = (pos + 1) & 7;
941         crc = crc * 3 + S;
942
943         if (type == AV_SAMPLE_FMT_FLTP) {
944             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, S);
945         } else if (type == AV_SAMPLE_FMT_S32P) {
946             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, S);
947         } else {
948             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, S);
949         }
950         count++;
951     } while (!last && count < s->samples);
952
953     if (last && count < s->samples) {
954         int size = av_get_bytes_per_sample(type);
955         memset((uint8_t*)dst + count*size, 0, (s->samples-count)*size);
956     }
957
958     if (s->avctx->err_recognition & AV_EF_CRCCHECK) {
959         int ret = wv_check_crc(s, crc, crc_extra_bits);
960         if (ret < 0 && s->avctx->err_recognition & AV_EF_EXPLODE)
961             return ret;
962     }
963
964     return 0;
965 }
966
967 static av_cold int wv_alloc_frame_context(WavpackContext *c)
968 {
969     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
970         return -1;
971
972     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
973     if (!c->fdec[c->fdec_num])
974         return -1;
975     c->fdec_num++;
976     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
977
978     return 0;
979 }
980
981 #if HAVE_THREADS
982 static int init_thread_copy(AVCodecContext *avctx)
983 {
984     WavpackContext *s = avctx->priv_data;
985     s->avctx = avctx;
986
987     s->curr_frame.f = av_frame_alloc();
988     s->prev_frame.f = av_frame_alloc();
989
990     if (!s->curr_frame.f || !s->prev_frame.f)
991         return AVERROR(ENOMEM);
992
993     return 0;
994 }
995
996 static int update_thread_context(AVCodecContext *dst, const AVCodecContext *src)
997 {
998     WavpackContext *fsrc = src->priv_data;
999     WavpackContext *fdst = dst->priv_data;
1000     int ret;
1001
1002     if (dst == src)
1003         return 0;
1004
1005     ff_thread_release_buffer(dst, &fdst->curr_frame);
1006     if (fsrc->curr_frame.f->data[0]) {
1007         if ((ret = ff_thread_ref_frame(&fdst->curr_frame, &fsrc->curr_frame)) < 0)
1008             return ret;
1009     }
1010
1011     return 0;
1012 }
1013 #endif
1014
1015 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
1016 {
1017     WavpackContext *s = avctx->priv_data;
1018
1019     s->avctx = avctx;
1020
1021     s->fdec_num = 0;
1022
1023     avctx->internal->allocate_progress = 1;
1024
1025     s->curr_frame.f = av_frame_alloc();
1026     s->prev_frame.f = av_frame_alloc();
1027
1028     // the DSD to PCM context is shared (and used serially) between all decoding threads
1029     s->dsdctx = av_calloc(avctx->channels, sizeof(DSDContext));
1030
1031     if (!s->curr_frame.f || !s->prev_frame.f || !s->dsdctx)
1032         return AVERROR(ENOMEM);
1033
1034     for (int i = 0; i < avctx->channels; i++)
1035         memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
1036
1037     ff_init_dsd_data();
1038
1039     return 0;
1040 }
1041
1042 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
1043 {
1044     WavpackContext *s = avctx->priv_data;
1045
1046     for (int i = 0; i < s->fdec_num; i++)
1047         av_freep(&s->fdec[i]);
1048     s->fdec_num = 0;
1049
1050     ff_thread_release_buffer(avctx, &s->curr_frame);
1051     av_frame_free(&s->curr_frame.f);
1052
1053     ff_thread_release_buffer(avctx, &s->prev_frame);
1054     av_frame_free(&s->prev_frame.f);
1055
1056     if (!avctx->internal->is_copy)
1057         av_freep(&s->dsdctx);
1058
1059     return 0;
1060 }
1061
1062 static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
1063                                 const uint8_t *buf, int buf_size)
1064 {
1065     WavpackContext *wc = avctx->priv_data;
1066     WavpackFrameContext *s;
1067     GetByteContext gb;
1068     void *samples_l = NULL, *samples_r = NULL;
1069     int ret;
1070     int got_terms   = 0, got_weights = 0, got_samples = 0,
1071         got_entropy = 0, got_pcm     = 0, got_float   = 0, got_hybrid = 0;
1072     int got_dsd = 0;
1073     int i, j, id, size, ssize, weights, t;
1074     int bpp, chan = 0, orig_bpp, sample_rate = 0, rate_x = 1, dsd_mode = 0;
1075     int multiblock;
1076     uint64_t chmask = 0;
1077
1078     if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
1079         av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
1080         return AVERROR_INVALIDDATA;
1081     }
1082
1083     s = wc->fdec[block_no];
1084     if (!s) {
1085         av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
1086                block_no);
1087         return AVERROR_INVALIDDATA;
1088     }
1089
1090     memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
1091     memset(s->ch, 0, sizeof(s->ch));
1092     s->extra_bits     = 0;
1093     s->and            = s->or = s->shift = 0;
1094     s->got_extra_bits = 0;
1095
1096     bytestream2_init(&gb, buf, buf_size);
1097
1098     s->samples = bytestream2_get_le32(&gb);
1099     if (s->samples != wc->samples) {
1100         av_log(avctx, AV_LOG_ERROR, "Mismatching number of samples in "
1101                "a sequence: %d and %d\n", wc->samples, s->samples);
1102         return AVERROR_INVALIDDATA;
1103     }
1104     s->frame_flags = bytestream2_get_le32(&gb);
1105     bpp            = av_get_bytes_per_sample(avctx->sample_fmt);
1106     orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
1107     multiblock     = (s->frame_flags & WV_SINGLE_BLOCK) != WV_SINGLE_BLOCK;
1108
1109     s->stereo         = !(s->frame_flags & WV_MONO);
1110     s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
1111     s->joint          =   s->frame_flags & WV_JOINT_STEREO;
1112     s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
1113     s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
1114     s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
1115     if (s->post_shift < 0 || s->post_shift > 31) {
1116         return AVERROR_INVALIDDATA;
1117     }
1118     s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
1119     s->hybrid_minclip = ((-1UL << (orig_bpp - 1)));
1120     s->CRC            = bytestream2_get_le32(&gb);
1121
1122     // parse metadata blocks
1123     while (bytestream2_get_bytes_left(&gb)) {
1124         id   = bytestream2_get_byte(&gb);
1125         size = bytestream2_get_byte(&gb);
1126         if (id & WP_IDF_LONG)
1127             size |= (bytestream2_get_le16u(&gb)) << 8;
1128         size <<= 1; // size is specified in words
1129         ssize  = size;
1130         if (id & WP_IDF_ODD)
1131             size--;
1132         if (size < 0) {
1133             av_log(avctx, AV_LOG_ERROR,
1134                    "Got incorrect block %02X with size %i\n", id, size);
1135             break;
1136         }
1137         if (bytestream2_get_bytes_left(&gb) < ssize) {
1138             av_log(avctx, AV_LOG_ERROR,
1139                    "Block size %i is out of bounds\n", size);
1140             break;
1141         }
1142         switch (id & WP_IDF_MASK) {
1143         case WP_ID_DECTERMS:
1144             if (size > MAX_TERMS) {
1145                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
1146                 s->terms = 0;
1147                 bytestream2_skip(&gb, ssize);
1148                 continue;
1149             }
1150             s->terms = size;
1151             for (i = 0; i < s->terms; i++) {
1152                 uint8_t val = bytestream2_get_byte(&gb);
1153                 s->decorr[s->terms - i - 1].value = (val & 0x1F) - 5;
1154                 s->decorr[s->terms - i - 1].delta =  val >> 5;
1155             }
1156             got_terms = 1;
1157             break;
1158         case WP_ID_DECWEIGHTS:
1159             if (!got_terms) {
1160                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
1161                 continue;
1162             }
1163             weights = size >> s->stereo_in;
1164             if (weights > MAX_TERMS || weights > s->terms) {
1165                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
1166                 bytestream2_skip(&gb, ssize);
1167                 continue;
1168             }
1169             for (i = 0; i < weights; i++) {
1170                 t = (int8_t)bytestream2_get_byte(&gb);
1171                 s->decorr[s->terms - i - 1].weightA = t * (1 << 3);
1172                 if (s->decorr[s->terms - i - 1].weightA > 0)
1173                     s->decorr[s->terms - i - 1].weightA +=
1174                         (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
1175                 if (s->stereo_in) {
1176                     t = (int8_t)bytestream2_get_byte(&gb);
1177                     s->decorr[s->terms - i - 1].weightB = t * (1 << 3);
1178                     if (s->decorr[s->terms - i - 1].weightB > 0)
1179                         s->decorr[s->terms - i - 1].weightB +=
1180                             (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
1181                 }
1182             }
1183             got_weights = 1;
1184             break;
1185         case WP_ID_DECSAMPLES:
1186             if (!got_terms) {
1187                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
1188                 continue;
1189             }
1190             t = 0;
1191             for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
1192                 if (s->decorr[i].value > 8) {
1193                     s->decorr[i].samplesA[0] =
1194                         wp_exp2(bytestream2_get_le16(&gb));
1195                     s->decorr[i].samplesA[1] =
1196                         wp_exp2(bytestream2_get_le16(&gb));
1197
1198                     if (s->stereo_in) {
1199                         s->decorr[i].samplesB[0] =
1200                             wp_exp2(bytestream2_get_le16(&gb));
1201                         s->decorr[i].samplesB[1] =
1202                             wp_exp2(bytestream2_get_le16(&gb));
1203                         t                       += 4;
1204                     }
1205                     t += 4;
1206                 } else if (s->decorr[i].value < 0) {
1207                     s->decorr[i].samplesA[0] =
1208                         wp_exp2(bytestream2_get_le16(&gb));
1209                     s->decorr[i].samplesB[0] =
1210                         wp_exp2(bytestream2_get_le16(&gb));
1211                     t                       += 4;
1212                 } else {
1213                     for (j = 0; j < s->decorr[i].value; j++) {
1214                         s->decorr[i].samplesA[j] =
1215                             wp_exp2(bytestream2_get_le16(&gb));
1216                         if (s->stereo_in) {
1217                             s->decorr[i].samplesB[j] =
1218                                 wp_exp2(bytestream2_get_le16(&gb));
1219                         }
1220                     }
1221                     t += s->decorr[i].value * 2 * (s->stereo_in + 1);
1222                 }
1223             }
1224             got_samples = 1;
1225             break;
1226         case WP_ID_ENTROPY:
1227             if (size != 6 * (s->stereo_in + 1)) {
1228                 av_log(avctx, AV_LOG_ERROR,
1229                        "Entropy vars size should be %i, got %i.\n",
1230                        6 * (s->stereo_in + 1), size);
1231                 bytestream2_skip(&gb, ssize);
1232                 continue;
1233             }
1234             for (j = 0; j <= s->stereo_in; j++)
1235                 for (i = 0; i < 3; i++) {
1236                     s->ch[j].median[i] = wp_exp2(bytestream2_get_le16(&gb));
1237                 }
1238             got_entropy = 1;
1239             break;
1240         case WP_ID_HYBRID:
1241             if (s->hybrid_bitrate) {
1242                 for (i = 0; i <= s->stereo_in; i++) {
1243                     s->ch[i].slow_level = wp_exp2(bytestream2_get_le16(&gb));
1244                     size               -= 2;
1245                 }
1246             }
1247             for (i = 0; i < (s->stereo_in + 1); i++) {
1248                 s->ch[i].bitrate_acc = bytestream2_get_le16(&gb) << 16;
1249                 size                -= 2;
1250             }
1251             if (size > 0) {
1252                 for (i = 0; i < (s->stereo_in + 1); i++) {
1253                     s->ch[i].bitrate_delta =
1254                         wp_exp2((int16_t)bytestream2_get_le16(&gb));
1255                 }
1256             } else {
1257                 for (i = 0; i < (s->stereo_in + 1); i++)
1258                     s->ch[i].bitrate_delta = 0;
1259             }
1260             got_hybrid = 1;
1261             break;
1262         case WP_ID_INT32INFO: {
1263             uint8_t val[4];
1264             if (size != 4) {
1265                 av_log(avctx, AV_LOG_ERROR,
1266                        "Invalid INT32INFO, size = %i\n",
1267                        size);
1268                 bytestream2_skip(&gb, ssize - 4);
1269                 continue;
1270             }
1271             bytestream2_get_buffer(&gb, val, 4);
1272             if (val[0] > 30) {
1273                 av_log(avctx, AV_LOG_ERROR,
1274                        "Invalid INT32INFO, extra_bits = %d (> 30)\n", val[0]);
1275                 continue;
1276             } else if (val[0]) {
1277                 s->extra_bits = val[0];
1278             } else if (val[1]) {
1279                 s->shift = val[1];
1280             } else if (val[2]) {
1281                 s->and   = s->or = 1;
1282                 s->shift = val[2];
1283             } else if (val[3]) {
1284                 s->and   = 1;
1285                 s->shift = val[3];
1286             }
1287             if (s->shift > 31) {
1288                 av_log(avctx, AV_LOG_ERROR,
1289                        "Invalid INT32INFO, shift = %d (> 31)\n", s->shift);
1290                 s->and = s->or = s->shift = 0;
1291                 continue;
1292             }
1293             /* original WavPack decoder forces 32-bit lossy sound to be treated
1294              * as 24-bit one in order to have proper clipping */
1295             if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
1296                 s->post_shift      += 8;
1297                 s->shift           -= 8;
1298                 s->hybrid_maxclip >>= 8;
1299                 s->hybrid_minclip >>= 8;
1300             }
1301             break;
1302         }
1303         case WP_ID_FLOATINFO:
1304             if (size != 4) {
1305                 av_log(avctx, AV_LOG_ERROR,
1306                        "Invalid FLOATINFO, size = %i\n", size);
1307                 bytestream2_skip(&gb, ssize);
1308                 continue;
1309             }
1310             s->float_flag    = bytestream2_get_byte(&gb);
1311             s->float_shift   = bytestream2_get_byte(&gb);
1312             s->float_max_exp = bytestream2_get_byte(&gb);
1313             if (s->float_shift > 31) {
1314                 av_log(avctx, AV_LOG_ERROR,
1315                        "Invalid FLOATINFO, shift = %d (> 31)\n", s->float_shift);
1316                 s->float_shift = 0;
1317                 continue;
1318             }
1319             got_float        = 1;
1320             bytestream2_skip(&gb, 1);
1321             break;
1322         case WP_ID_DATA:
1323             if ((ret = init_get_bits8(&s->gb, gb.buffer, size)) < 0)
1324                 return ret;
1325             bytestream2_skip(&gb, size);
1326             got_pcm      = 1;
1327             break;
1328         case WP_ID_DSD_DATA:
1329             if (size < 2) {
1330                 av_log(avctx, AV_LOG_ERROR, "Invalid DSD_DATA, size = %i\n",
1331                        size);
1332                 bytestream2_skip(&gb, ssize);
1333                 continue;
1334             }
1335             rate_x = 1 << bytestream2_get_byte(&gb);
1336             dsd_mode = bytestream2_get_byte(&gb);
1337             if (dsd_mode && dsd_mode != 1 && dsd_mode != 3) {
1338                 av_log(avctx, AV_LOG_ERROR, "Invalid DSD encoding mode: %d\n",
1339                     dsd_mode);
1340                 return AVERROR_INVALIDDATA;
1341             }
1342             bytestream2_init(&s->gbyte, gb.buffer, size-2);
1343             bytestream2_skip(&gb, size-2);
1344             got_dsd      = 1;
1345             break;
1346         case WP_ID_EXTRABITS:
1347             if (size <= 4) {
1348                 av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
1349                        size);
1350                 bytestream2_skip(&gb, size);
1351                 continue;
1352             }
1353             if ((ret = init_get_bits8(&s->gb_extra_bits, gb.buffer, size)) < 0)
1354                 return ret;
1355             s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
1356             bytestream2_skip(&gb, size);
1357             s->got_extra_bits  = 1;
1358             break;
1359         case WP_ID_CHANINFO:
1360             if (size <= 1) {
1361                 av_log(avctx, AV_LOG_ERROR,
1362                        "Insufficient channel information\n");
1363                 return AVERROR_INVALIDDATA;
1364             }
1365             chan = bytestream2_get_byte(&gb);
1366             switch (size - 2) {
1367             case 0:
1368                 chmask = bytestream2_get_byte(&gb);
1369                 break;
1370             case 1:
1371                 chmask = bytestream2_get_le16(&gb);
1372                 break;
1373             case 2:
1374                 chmask = bytestream2_get_le24(&gb);
1375                 break;
1376             case 3:
1377                 chmask = bytestream2_get_le32(&gb);
1378                 break;
1379             case 4:
1380                 size = bytestream2_get_byte(&gb);
1381                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1382                 chan  += 1;
1383                 if (avctx->channels != chan)
1384                     av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
1385                            " instead of %i.\n", chan, avctx->channels);
1386                 chmask = bytestream2_get_le24(&gb);
1387                 break;
1388             case 5:
1389                 size = bytestream2_get_byte(&gb);
1390                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1391                 chan  += 1;
1392                 if (avctx->channels != chan)
1393                     av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
1394                            " instead of %i.\n", chan, avctx->channels);
1395                 chmask = bytestream2_get_le32(&gb);
1396                 break;
1397             default:
1398                 av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
1399                        size);
1400                 chan   = avctx->channels;
1401                 chmask = avctx->channel_layout;
1402             }
1403             break;
1404         case WP_ID_SAMPLE_RATE:
1405             if (size != 3) {
1406                 av_log(avctx, AV_LOG_ERROR, "Invalid custom sample rate.\n");
1407                 return AVERROR_INVALIDDATA;
1408             }
1409             sample_rate = bytestream2_get_le24(&gb);
1410             break;
1411         default:
1412             bytestream2_skip(&gb, size);
1413         }
1414         if (id & WP_IDF_ODD)
1415             bytestream2_skip(&gb, 1);
1416     }
1417
1418     if (got_pcm) {
1419         if (!got_terms) {
1420             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
1421             return AVERROR_INVALIDDATA;
1422         }
1423         if (!got_weights) {
1424             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
1425             return AVERROR_INVALIDDATA;
1426         }
1427         if (!got_samples) {
1428             av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
1429             return AVERROR_INVALIDDATA;
1430         }
1431         if (!got_entropy) {
1432             av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
1433             return AVERROR_INVALIDDATA;
1434         }
1435         if (s->hybrid && !got_hybrid) {
1436             av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
1437             return AVERROR_INVALIDDATA;
1438         }
1439         if (!got_float && avctx->sample_fmt == AV_SAMPLE_FMT_FLTP) {
1440             av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
1441             return AVERROR_INVALIDDATA;
1442         }
1443         if (s->got_extra_bits && avctx->sample_fmt != AV_SAMPLE_FMT_FLTP) {
1444             const int size   = get_bits_left(&s->gb_extra_bits);
1445             const int wanted = s->samples * s->extra_bits << s->stereo_in;
1446             if (size < wanted) {
1447                 av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
1448                 s->got_extra_bits = 0;
1449             }
1450         }
1451     }
1452
1453     if (!got_pcm && !got_dsd) {
1454         av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
1455         return AVERROR_INVALIDDATA;
1456     }
1457
1458     if ((got_pcm && wc->modulation != MODULATION_PCM) ||
1459         (got_dsd && wc->modulation != MODULATION_DSD)) {
1460             av_log(avctx, AV_LOG_ERROR, "Invalid PCM/DSD mix encountered\n");
1461             return AVERROR_INVALIDDATA;
1462     }
1463
1464     if (!wc->ch_offset) {
1465         int sr = (s->frame_flags >> 23) & 0xf;
1466         if (sr == 0xf) {
1467             if (!sample_rate) {
1468                 av_log(avctx, AV_LOG_ERROR, "Custom sample rate missing.\n");
1469                 return AVERROR_INVALIDDATA;
1470             }
1471             avctx->sample_rate = sample_rate * rate_x;
1472         } else
1473             avctx->sample_rate = wv_rates[sr] * rate_x;
1474
1475         if (multiblock) {
1476             if (chan)
1477                 avctx->channels = chan;
1478             if (chmask)
1479                 avctx->channel_layout = chmask;
1480         } else {
1481             avctx->channels       = s->stereo ? 2 : 1;
1482             avctx->channel_layout = s->stereo ? AV_CH_LAYOUT_STEREO :
1483                                                 AV_CH_LAYOUT_MONO;
1484         }
1485
1486         ff_thread_release_buffer(avctx, &wc->prev_frame);
1487         FFSWAP(ThreadFrame, wc->curr_frame, wc->prev_frame);
1488
1489         /* get output buffer */
1490         wc->curr_frame.f->nb_samples = s->samples;
1491         if ((ret = ff_thread_get_buffer(avctx, &wc->curr_frame, AV_GET_BUFFER_FLAG_REF)) < 0)
1492             return ret;
1493
1494         wc->frame = wc->curr_frame.f;
1495         ff_thread_finish_setup(avctx);
1496     }
1497
1498     if (wc->ch_offset + s->stereo >= avctx->channels) {
1499         av_log(avctx, AV_LOG_WARNING, "Too many channels coded in a packet.\n");
1500         return ((avctx->err_recognition & AV_EF_EXPLODE) || !wc->ch_offset) ? AVERROR_INVALIDDATA : 0;
1501     }
1502
1503     samples_l = wc->frame->extended_data[wc->ch_offset];
1504     if (s->stereo)
1505         samples_r = wc->frame->extended_data[wc->ch_offset + 1];
1506
1507     wc->ch_offset += 1 + s->stereo;
1508
1509     if (s->stereo_in) {
1510         if (got_dsd) {
1511             if (dsd_mode == 3) {
1512                 ret = wv_unpack_dsd_high(s, samples_l, samples_r);
1513             } else if (dsd_mode == 1) {
1514                 ret = wv_unpack_dsd_fast(s, samples_l, samples_r);
1515             } else {
1516                 ret = wv_unpack_dsd_copy(s, samples_l, samples_r);
1517             }
1518         } else {
1519             ret = wv_unpack_stereo(s, &s->gb, samples_l, samples_r, avctx->sample_fmt);
1520         }
1521         if (ret < 0)
1522             return ret;
1523     } else {
1524         if (got_dsd) {
1525             if (dsd_mode == 3) {
1526                 ret = wv_unpack_dsd_high(s, samples_l, NULL);
1527             } else if (dsd_mode == 1) {
1528                 ret = wv_unpack_dsd_fast(s, samples_l, NULL);
1529             } else {
1530                 ret = wv_unpack_dsd_copy(s, samples_l, NULL);
1531             }
1532         } else {
1533             ret = wv_unpack_mono(s, &s->gb, samples_l, avctx->sample_fmt);
1534         }
1535         if (ret < 0)
1536             return ret;
1537
1538         if (s->stereo)
1539             memcpy(samples_r, samples_l, bpp * s->samples);
1540     }
1541
1542     return 0;
1543 }
1544
1545 static void wavpack_decode_flush(AVCodecContext *avctx)
1546 {
1547     WavpackContext *s = avctx->priv_data;
1548
1549     if (!avctx->internal->is_copy) {
1550         for (int i = 0; i < avctx->channels; i++)
1551             memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
1552     }
1553 }
1554
1555 static int dsd_channel(AVCodecContext *avctx, void *frmptr, int jobnr, int threadnr)
1556 {
1557     WavpackContext *s  = avctx->priv_data;
1558     AVFrame *frame = frmptr;
1559
1560     ff_dsd2pcm_translate (&s->dsdctx [jobnr], s->samples, 0,
1561         (uint8_t *)frame->extended_data[jobnr], 4,
1562         (float *)frame->extended_data[jobnr], 1);
1563
1564     return 0;
1565 }
1566
1567 static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
1568                                 int *got_frame_ptr, AVPacket *avpkt)
1569 {
1570     WavpackContext *s  = avctx->priv_data;
1571     const uint8_t *buf = avpkt->data;
1572     int buf_size       = avpkt->size;
1573     int frame_size, ret, frame_flags;
1574
1575     if (avpkt->size <= WV_HEADER_SIZE)
1576         return AVERROR_INVALIDDATA;
1577
1578     s->frame     = NULL;
1579     s->block     = 0;
1580     s->ch_offset = 0;
1581
1582     /* determine number of samples */
1583     s->samples  = AV_RL32(buf + 20);
1584     frame_flags = AV_RL32(buf + 24);
1585     if (s->samples <= 0 || s->samples > WV_MAX_SAMPLES) {
1586         av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1587                s->samples);
1588         return AVERROR_INVALIDDATA;
1589     }
1590
1591     s->modulation = (frame_flags & WV_DSD_DATA) ? MODULATION_DSD : MODULATION_PCM;
1592
1593     if (frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA)) {
1594         avctx->sample_fmt = AV_SAMPLE_FMT_FLTP;
1595     } else if ((frame_flags & 0x03) <= 1) {
1596         avctx->sample_fmt = AV_SAMPLE_FMT_S16P;
1597     } else {
1598         avctx->sample_fmt          = AV_SAMPLE_FMT_S32P;
1599         avctx->bits_per_raw_sample = ((frame_flags & 0x03) + 1) << 3;
1600     }
1601
1602     while (buf_size > WV_HEADER_SIZE) {
1603         frame_size = AV_RL32(buf + 4) - 12;
1604         buf       += 20;
1605         buf_size  -= 20;
1606         if (frame_size <= 0 || frame_size > buf_size) {
1607             av_log(avctx, AV_LOG_ERROR,
1608                    "Block %d has invalid size (size %d vs. %d bytes left)\n",
1609                    s->block, frame_size, buf_size);
1610             ret = AVERROR_INVALIDDATA;
1611             goto error;
1612         }
1613         if ((ret = wavpack_decode_block(avctx, s->block, buf, frame_size)) < 0)
1614             goto error;
1615         s->block++;
1616         buf      += frame_size;
1617         buf_size -= frame_size;
1618     }
1619
1620     if (s->ch_offset != avctx->channels) {
1621         av_log(avctx, AV_LOG_ERROR, "Not enough channels coded in a packet.\n");
1622         ret = AVERROR_INVALIDDATA;
1623         goto error;
1624     }
1625
1626     ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1627     ff_thread_release_buffer(avctx, &s->prev_frame);
1628
1629     if (s->modulation == MODULATION_DSD)
1630         avctx->execute2(avctx, dsd_channel, s->frame, NULL, avctx->channels);
1631
1632     ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1633
1634     if ((ret = av_frame_ref(data, s->frame)) < 0)
1635         return ret;
1636
1637     *got_frame_ptr = 1;
1638
1639     return avpkt->size;
1640
1641 error:
1642     if (s->frame) {
1643         ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1644         ff_thread_release_buffer(avctx, &s->prev_frame);
1645         ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1646     }
1647
1648     return ret;
1649 }
1650
1651 AVCodec ff_wavpack_decoder = {
1652     .name           = "wavpack",
1653     .long_name      = NULL_IF_CONFIG_SMALL("WavPack"),
1654     .type           = AVMEDIA_TYPE_AUDIO,
1655     .id             = AV_CODEC_ID_WAVPACK,
1656     .priv_data_size = sizeof(WavpackContext),
1657     .init           = wavpack_decode_init,
1658     .close          = wavpack_decode_end,
1659     .decode         = wavpack_decode_frame,
1660     .flush          = wavpack_decode_flush,
1661     .init_thread_copy = ONLY_IF_THREADS_ENABLED(init_thread_copy),
1662     .update_thread_context = ONLY_IF_THREADS_ENABLED(update_thread_context),
1663     .capabilities   = AV_CODEC_CAP_DR1 | AV_CODEC_CAP_FRAME_THREADS |
1664                       AV_CODEC_CAP_SLICE_THREADS
1665 };