]> git.sesse.net Git - ffmpeg/blob - libavcodec/wavpack.c
81b310f2976e3041ac0e9a9e10ea3875f32eb32c
[ffmpeg] / libavcodec / wavpack.c
1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  *
5  * This file is part of Libav.
6  *
7  * Libav is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * Libav is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with Libav; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  */
21
22 #define BITSTREAM_READER_LE
23
24 #include "libavutil/channel_layout.h"
25 #include "avcodec.h"
26 #include "get_bits.h"
27 #include "internal.h"
28 #include "unary.h"
29 #include "bytestream.h"
30
31 /**
32  * @file
33  * WavPack lossless audio decoder
34  */
35
36 #define WV_MONO           0x00000004
37 #define WV_JOINT_STEREO   0x00000010
38 #define WV_FALSE_STEREO   0x40000000
39
40 #define WV_HYBRID_MODE    0x00000008
41 #define WV_HYBRID_SHAPE   0x00000008
42 #define WV_HYBRID_BITRATE 0x00000200
43 #define WV_HYBRID_BALANCE 0x00000400
44
45 #define WV_FLT_SHIFT_ONES 0x01
46 #define WV_FLT_SHIFT_SAME 0x02
47 #define WV_FLT_SHIFT_SENT 0x04
48 #define WV_FLT_ZERO_SENT  0x08
49 #define WV_FLT_ZERO_SIGN  0x10
50
51 enum WP_ID_Flags {
52     WP_IDF_MASK   = 0x1F,
53     WP_IDF_IGNORE = 0x20,
54     WP_IDF_ODD    = 0x40,
55     WP_IDF_LONG   = 0x80
56 };
57
58 enum WP_ID {
59     WP_ID_DUMMY = 0,
60     WP_ID_ENCINFO,
61     WP_ID_DECTERMS,
62     WP_ID_DECWEIGHTS,
63     WP_ID_DECSAMPLES,
64     WP_ID_ENTROPY,
65     WP_ID_HYBRID,
66     WP_ID_SHAPING,
67     WP_ID_FLOATINFO,
68     WP_ID_INT32INFO,
69     WP_ID_DATA,
70     WP_ID_CORR,
71     WP_ID_EXTRABITS,
72     WP_ID_CHANINFO
73 };
74
75 typedef struct SavedContext {
76     int offset;
77     int size;
78     int bits_used;
79     uint32_t crc;
80 } SavedContext;
81
82 #define MAX_TERMS 16
83
84 typedef struct Decorr {
85     int delta;
86     int value;
87     int weightA;
88     int weightB;
89     int samplesA[8];
90     int samplesB[8];
91 } Decorr;
92
93 typedef struct WvChannel {
94     int median[3];
95     int slow_level, error_limit;
96     int bitrate_acc, bitrate_delta;
97 } WvChannel;
98
99 typedef struct WavpackFrameContext {
100     AVCodecContext *avctx;
101     int frame_flags;
102     int stereo, stereo_in;
103     int joint;
104     uint32_t CRC;
105     GetBitContext gb;
106     int got_extra_bits;
107     uint32_t crc_extra_bits;
108     GetBitContext gb_extra_bits;
109     int data_size; // in bits
110     int samples;
111     int terms;
112     Decorr decorr[MAX_TERMS];
113     int zero, one, zeroes;
114     int extra_bits;
115     int and, or, shift;
116     int post_shift;
117     int hybrid, hybrid_bitrate;
118     int hybrid_maxclip, hybrid_minclip;
119     int float_flag;
120     int float_shift;
121     int float_max_exp;
122     WvChannel ch[2];
123     int pos;
124     SavedContext sc, extra_sc;
125 } WavpackFrameContext;
126
127 #define WV_MAX_FRAME_DECODERS 14
128
129 typedef struct WavpackContext {
130     AVCodecContext *avctx;
131
132     WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
133     int fdec_num;
134
135     int multichannel;
136     int mkv_mode;
137     int block;
138     int samples;
139     int ch_offset;
140 } WavpackContext;
141
142 // exponent table copied from WavPack source
143 static const uint8_t wp_exp2_table[256] = {
144     0x00, 0x01, 0x01, 0x02, 0x03, 0x03, 0x04, 0x05, 0x06, 0x06, 0x07, 0x08, 0x08, 0x09, 0x0a, 0x0b,
145     0x0b, 0x0c, 0x0d, 0x0e, 0x0e, 0x0f, 0x10, 0x10, 0x11, 0x12, 0x13, 0x13, 0x14, 0x15, 0x16, 0x16,
146     0x17, 0x18, 0x19, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1d, 0x1e, 0x1f, 0x20, 0x20, 0x21, 0x22, 0x23,
147     0x24, 0x24, 0x25, 0x26, 0x27, 0x28, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
148     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3a, 0x3b, 0x3c, 0x3d,
149     0x3e, 0x3f, 0x40, 0x41, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x48, 0x49, 0x4a, 0x4b,
150     0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a,
151     0x5b, 0x5c, 0x5d, 0x5e, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69,
152     0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
153     0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x87, 0x88, 0x89, 0x8a,
154     0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b,
155     0x9c, 0x9d, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad,
156     0xaf, 0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbc, 0xbd, 0xbe, 0xbf, 0xc0,
157     0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc8, 0xc9, 0xca, 0xcb, 0xcd, 0xce, 0xcf, 0xd0, 0xd2, 0xd3, 0xd4,
158     0xd6, 0xd7, 0xd8, 0xd9, 0xdb, 0xdc, 0xdd, 0xde, 0xe0, 0xe1, 0xe2, 0xe4, 0xe5, 0xe6, 0xe8, 0xe9,
159     0xea, 0xec, 0xed, 0xee, 0xf0, 0xf1, 0xf2, 0xf4, 0xf5, 0xf6, 0xf8, 0xf9, 0xfa, 0xfc, 0xfd, 0xff
160 };
161
162 static const uint8_t wp_log2_table [] = {
163     0x00, 0x01, 0x03, 0x04, 0x06, 0x07, 0x09, 0x0a, 0x0b, 0x0d, 0x0e, 0x10, 0x11, 0x12, 0x14, 0x15,
164     0x16, 0x18, 0x19, 0x1a, 0x1c, 0x1d, 0x1e, 0x20, 0x21, 0x22, 0x24, 0x25, 0x26, 0x28, 0x29, 0x2a,
165     0x2c, 0x2d, 0x2e, 0x2f, 0x31, 0x32, 0x33, 0x34, 0x36, 0x37, 0x38, 0x39, 0x3b, 0x3c, 0x3d, 0x3e,
166     0x3f, 0x41, 0x42, 0x43, 0x44, 0x45, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4d, 0x4e, 0x4f, 0x50, 0x51,
167     0x52, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63,
168     0x64, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x74, 0x75,
169     0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85,
170     0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95,
171     0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4,
172     0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, 0xb0, 0xb1, 0xb2, 0xb2,
173     0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, 0xc0, 0xc0,
174     0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcb, 0xcc, 0xcd, 0xce,
175     0xcf, 0xd0, 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd8, 0xd9, 0xda, 0xdb,
176     0xdc, 0xdc, 0xdd, 0xde, 0xdf, 0xe0, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe4, 0xe5, 0xe6, 0xe7, 0xe7,
177     0xe8, 0xe9, 0xea, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xee, 0xef, 0xf0, 0xf1, 0xf1, 0xf2, 0xf3, 0xf4,
178     0xf4, 0xf5, 0xf6, 0xf7, 0xf7, 0xf8, 0xf9, 0xf9, 0xfa, 0xfb, 0xfc, 0xfc, 0xfd, 0xfe, 0xff, 0xff
179 };
180
181 static av_always_inline int wp_exp2(int16_t val)
182 {
183     int res, neg = 0;
184
185     if (val < 0) {
186         val = -val;
187         neg = 1;
188     }
189
190     res   = wp_exp2_table[val & 0xFF] | 0x100;
191     val >>= 8;
192     res   = (val > 9) ? (res << (val - 9)) : (res >> (9 - val));
193     return neg ? -res : res;
194 }
195
196 static av_always_inline int wp_log2(int32_t val)
197 {
198     int bits;
199
200     if (!val)
201         return 0;
202     if (val == 1)
203         return 256;
204     val += val >> 9;
205     bits = av_log2(val) + 1;
206     if (bits < 9)
207         return (bits << 8) + wp_log2_table[(val << (9 - bits)) & 0xFF];
208     else
209         return (bits << 8) + wp_log2_table[(val >> (bits - 9)) & 0xFF];
210 }
211
212 #define LEVEL_DECAY(a)  ((a + 0x80) >> 8)
213
214 // macros for manipulating median values
215 #define GET_MED(n) ((c->median[n] >> 4) + 1)
216 #define DEC_MED(n) c->median[n] -= ((c->median[n] + (128 >> n) - 2) / (128 >> n)) * 2
217 #define INC_MED(n) c->median[n] += ((c->median[n] + (128 >> n)    ) / (128 >> n)) * 5
218
219 // macros for applying weight
220 #define UPDATE_WEIGHT_CLIP(weight, delta, samples, in) \
221     if (samples && in) { \
222         if ((samples ^ in) < 0) { \
223             weight -= delta; \
224             if (weight < -1024) \
225                 weight = -1024; \
226         } else { \
227             weight += delta; \
228             if (weight > 1024) \
229                 weight = 1024; \
230         } \
231     }
232
233 static av_always_inline int get_tail(GetBitContext *gb, int k)
234 {
235     int p, e, res;
236
237     if (k < 1)
238         return 0;
239     p   = av_log2(k);
240     e   = (1 << (p + 1)) - k - 1;
241     res = p ? get_bits(gb, p) : 0;
242     if (res >= e)
243         res = (res << 1) - e + get_bits1(gb);
244     return res;
245 }
246
247 static void update_error_limit(WavpackFrameContext *ctx)
248 {
249     int i, br[2], sl[2];
250
251     for (i = 0; i <= ctx->stereo_in; i++) {
252         ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
253         br[i]                   = ctx->ch[i].bitrate_acc >> 16;
254         sl[i]                   = LEVEL_DECAY(ctx->ch[i].slow_level);
255     }
256     if (ctx->stereo_in && ctx->hybrid_bitrate) {
257         int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
258         if (balance > br[0]) {
259             br[1] = br[0] << 1;
260             br[0] = 0;
261         } else if (-balance > br[0]) {
262             br[0] <<= 1;
263             br[1]   = 0;
264         } else {
265             br[1] = br[0] + balance;
266             br[0] = br[0] - balance;
267         }
268     }
269     for (i = 0; i <= ctx->stereo_in; i++) {
270         if (ctx->hybrid_bitrate) {
271             if (sl[i] - br[i] > -0x100)
272                 ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
273             else
274                 ctx->ch[i].error_limit = 0;
275         } else {
276             ctx->ch[i].error_limit = wp_exp2(br[i]);
277         }
278     }
279 }
280
281 static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
282                         int channel, int *last)
283 {
284     int t, t2;
285     int sign, base, add, ret;
286     WvChannel *c = &ctx->ch[channel];
287
288     *last = 0;
289
290     if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
291         !ctx->zero && !ctx->one) {
292         if (ctx->zeroes) {
293             ctx->zeroes--;
294             if (ctx->zeroes) {
295                 c->slow_level -= LEVEL_DECAY(c->slow_level);
296                 return 0;
297             }
298         } else {
299             t = get_unary_0_33(gb);
300             if (t >= 2) {
301                 if (get_bits_left(gb) < t - 1)
302                     goto error;
303                 t = get_bits(gb, t - 1) | (1 << (t - 1));
304             } else {
305                 if (get_bits_left(gb) < 0)
306                     goto error;
307             }
308             ctx->zeroes = t;
309             if (ctx->zeroes) {
310                 memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
311                 memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
312                 c->slow_level -= LEVEL_DECAY(c->slow_level);
313                 return 0;
314             }
315         }
316     }
317
318     if (ctx->zero) {
319         t         = 0;
320         ctx->zero = 0;
321     } else {
322         t = get_unary_0_33(gb);
323         if (get_bits_left(gb) < 0)
324             goto error;
325         if (t == 16) {
326             t2 = get_unary_0_33(gb);
327             if (t2 < 2) {
328                 if (get_bits_left(gb) < 0)
329                     goto error;
330                 t += t2;
331             } else {
332                 if (get_bits_left(gb) < t2 - 1)
333                     goto error;
334                 t += get_bits(gb, t2 - 1) | (1 << (t2 - 1));
335             }
336         }
337
338         if (ctx->one) {
339             ctx->one = t & 1;
340             t        = (t >> 1) + 1;
341         } else {
342             ctx->one = t & 1;
343             t      >>= 1;
344         }
345         ctx->zero = !ctx->one;
346     }
347
348     if (ctx->hybrid && !channel)
349         update_error_limit(ctx);
350
351     if (!t) {
352         base = 0;
353         add  = GET_MED(0) - 1;
354         DEC_MED(0);
355     } else if (t == 1) {
356         base = GET_MED(0);
357         add  = GET_MED(1) - 1;
358         INC_MED(0);
359         DEC_MED(1);
360     } else if (t == 2) {
361         base = GET_MED(0) + GET_MED(1);
362         add  = GET_MED(2) - 1;
363         INC_MED(0);
364         INC_MED(1);
365         DEC_MED(2);
366     } else {
367         base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2);
368         add  = GET_MED(2) - 1;
369         INC_MED(0);
370         INC_MED(1);
371         INC_MED(2);
372     }
373     if (!c->error_limit) {
374         ret = base + get_tail(gb, add);
375         if (get_bits_left(gb) <= 0)
376             goto error;
377     } else {
378         int mid = (base * 2 + add + 1) >> 1;
379         while (add > c->error_limit) {
380             if (get_bits_left(gb) <= 0)
381                 goto error;
382             if (get_bits1(gb)) {
383                 add -= (mid - base);
384                 base = mid;
385             } else
386                 add = mid - base - 1;
387             mid = (base * 2 + add + 1) >> 1;
388         }
389         ret = mid;
390     }
391     sign = get_bits1(gb);
392     if (ctx->hybrid_bitrate)
393         c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
394     return sign ? ~ret : ret;
395
396 error:
397     *last = 1;
398     return 0;
399 }
400
401 static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
402                                        int S)
403 {
404     int bit;
405
406     if (s->extra_bits) {
407         S <<= s->extra_bits;
408
409         if (s->got_extra_bits &&
410             get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
411             S   |= get_bits(&s->gb_extra_bits, s->extra_bits);
412             *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
413         }
414     }
415
416     bit = (S & s->and) | s->or;
417     bit = ((S + bit) << s->shift) - bit;
418
419     if (s->hybrid)
420         bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
421
422     return bit << s->post_shift;
423 }
424
425 static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
426 {
427     union {
428         float    f;
429         uint32_t u;
430     } value;
431
432     unsigned int sign;
433     int exp = s->float_max_exp;
434
435     if (s->got_extra_bits) {
436         const int max_bits  = 1 + 23 + 8 + 1;
437         const int left_bits = get_bits_left(&s->gb_extra_bits);
438
439         if (left_bits + 8 * FF_INPUT_BUFFER_PADDING_SIZE < max_bits)
440             return 0.0;
441     }
442
443     if (S) {
444         S  <<= s->float_shift;
445         sign = S < 0;
446         if (sign)
447             S = -S;
448         if (S >= 0x1000000) {
449             if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
450                 S = get_bits(&s->gb_extra_bits, 23);
451             else
452                 S = 0;
453             exp = 255;
454         } else if (exp) {
455             int shift = 23 - av_log2(S);
456             exp = s->float_max_exp;
457             if (exp <= shift)
458                 shift = --exp;
459             exp -= shift;
460
461             if (shift) {
462                 S <<= shift;
463                 if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
464                     (s->got_extra_bits &&
465                      (s->float_flag & WV_FLT_SHIFT_SAME) &&
466                      get_bits1(&s->gb_extra_bits))) {
467                     S |= (1 << shift) - 1;
468                 } else if (s->got_extra_bits &&
469                            (s->float_flag & WV_FLT_SHIFT_SENT)) {
470                     S |= get_bits(&s->gb_extra_bits, shift);
471                 }
472             }
473         } else {
474             exp = s->float_max_exp;
475         }
476         S &= 0x7fffff;
477     } else {
478         sign = 0;
479         exp  = 0;
480         if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
481             if (get_bits1(&s->gb_extra_bits)) {
482                 S = get_bits(&s->gb_extra_bits, 23);
483                 if (s->float_max_exp >= 25)
484                     exp = get_bits(&s->gb_extra_bits, 8);
485                 sign = get_bits1(&s->gb_extra_bits);
486             } else {
487                 if (s->float_flag & WV_FLT_ZERO_SIGN)
488                     sign = get_bits1(&s->gb_extra_bits);
489             }
490         }
491     }
492
493     *crc = *crc * 27 + S * 9 + exp * 3 + sign;
494
495     value.u = (sign << 31) | (exp << 23) | S;
496     return value.f;
497 }
498
499 static void wv_reset_saved_context(WavpackFrameContext *s)
500 {
501     s->pos    = 0;
502     s->sc.crc = s->extra_sc.crc = 0xFFFFFFFF;
503 }
504
505 static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
506                                uint32_t crc_extra_bits)
507 {
508     if (crc != s->CRC) {
509         av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
510         return AVERROR_INVALIDDATA;
511     }
512     if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
513         av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
514         return AVERROR_INVALIDDATA;
515     }
516
517     return 0;
518 }
519
520 static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
521                                    void *dst_l, void *dst_r, const int type)
522 {
523     int i, j, count = 0;
524     int last, t;
525     int A, B, L, L2, R, R2;
526     int pos                 = s->pos;
527     uint32_t crc            = s->sc.crc;
528     uint32_t crc_extra_bits = s->extra_sc.crc;
529     int16_t *dst16_l        = dst_l;
530     int16_t *dst16_r        = dst_r;
531     int32_t *dst32_l        = dst_l;
532     int32_t *dst32_r        = dst_r;
533     float *dstfl_l          = dst_l;
534     float *dstfl_r          = dst_r;
535
536     s->one = s->zero = s->zeroes = 0;
537     do {
538         L = wv_get_value(s, gb, 0, &last);
539         if (last)
540             break;
541         R = wv_get_value(s, gb, 1, &last);
542         if (last)
543             break;
544         for (i = 0; i < s->terms; i++) {
545             t = s->decorr[i].value;
546             if (t > 0) {
547                 if (t > 8) {
548                     if (t & 1) {
549                         A = 2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
550                         B = 2 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
551                     } else {
552                         A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
553                         B = (3 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
554                     }
555                     s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
556                     s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
557                     j                        = 0;
558                 } else {
559                     A = s->decorr[i].samplesA[pos];
560                     B = s->decorr[i].samplesB[pos];
561                     j = (pos + t) & 7;
562                 }
563                 if (type != AV_SAMPLE_FMT_S16P) {
564                     L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
565                     R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
566                 } else {
567                     L2 = L + ((s->decorr[i].weightA * A + 512) >> 10);
568                     R2 = R + ((s->decorr[i].weightB * B + 512) >> 10);
569                 }
570                 if (A && L)
571                     s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
572                 if (B && R)
573                     s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
574                 s->decorr[i].samplesA[j] = L = L2;
575                 s->decorr[i].samplesB[j] = R = R2;
576             } else if (t == -1) {
577                 if (type != AV_SAMPLE_FMT_S16P)
578                     L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
579                 else
580                     L2 = L + ((s->decorr[i].weightA * s->decorr[i].samplesA[0] + 512) >> 10);
581                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
582                 L = L2;
583                 if (type != AV_SAMPLE_FMT_S16P)
584                     R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
585                 else
586                     R2 = R + ((s->decorr[i].weightB * L2 + 512) >> 10);
587                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
588                 R                        = R2;
589                 s->decorr[i].samplesA[0] = R;
590             } else {
591                 if (type != AV_SAMPLE_FMT_S16P)
592                     R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
593                 else
594                     R2 = R + ((s->decorr[i].weightB * s->decorr[i].samplesB[0] + 512) >> 10);
595                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
596                 R = R2;
597
598                 if (t == -3) {
599                     R2                       = s->decorr[i].samplesA[0];
600                     s->decorr[i].samplesA[0] = R;
601                 }
602
603                 if (type != AV_SAMPLE_FMT_S16P)
604                     L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
605                 else
606                     L2 = L + ((s->decorr[i].weightA * R2 + 512) >> 10);
607                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
608                 L                        = L2;
609                 s->decorr[i].samplesB[0] = L;
610             }
611         }
612         pos = (pos + 1) & 7;
613         if (s->joint)
614             L += (R -= (L >> 1));
615         crc = (crc * 3 + L) * 3 + R;
616
617         if (type == AV_SAMPLE_FMT_FLTP) {
618             *dstfl_l++ = wv_get_value_float(s, &crc_extra_bits, L);
619             *dstfl_r++ = wv_get_value_float(s, &crc_extra_bits, R);
620         } else if (type == AV_SAMPLE_FMT_S32P) {
621             *dst32_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
622             *dst32_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
623         } else {
624             *dst16_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
625             *dst16_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
626         }
627         count++;
628     } while (!last && count < s->samples);
629
630     wv_reset_saved_context(s);
631     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
632         wv_check_crc(s, crc, crc_extra_bits))
633         return AVERROR_INVALIDDATA;
634
635     return 0;
636 }
637
638 static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
639                                  void *dst, const int type)
640 {
641     int i, j, count = 0;
642     int last, t;
643     int A, S, T;
644     int pos                  = s->pos;
645     uint32_t crc             = s->sc.crc;
646     uint32_t crc_extra_bits  = s->extra_sc.crc;
647     int16_t *dst16           = dst;
648     int32_t *dst32           = dst;
649     float *dstfl             = dst;
650
651     s->one = s->zero = s->zeroes = 0;
652     do {
653         T = wv_get_value(s, gb, 0, &last);
654         S = 0;
655         if (last)
656             break;
657         for (i = 0; i < s->terms; i++) {
658             t = s->decorr[i].value;
659             if (t > 8) {
660                 if (t & 1)
661                     A =  2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
662                 else
663                     A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
664                 s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
665                 j                        = 0;
666             } else {
667                 A = s->decorr[i].samplesA[pos];
668                 j = (pos + t) & 7;
669             }
670             if (type != AV_SAMPLE_FMT_S16P)
671                 S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
672             else
673                 S = T + ((s->decorr[i].weightA * A + 512) >> 10);
674             if (A && T)
675                 s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
676             s->decorr[i].samplesA[j] = T = S;
677         }
678         pos = (pos + 1) & 7;
679         crc = crc * 3 + S;
680
681         if (type == AV_SAMPLE_FMT_FLTP) {
682             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, S);
683         } else if (type == AV_SAMPLE_FMT_S32P) {
684             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, S);
685         } else {
686             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, S);
687         }
688         count++;
689     } while (!last && count < s->samples);
690
691     wv_reset_saved_context(s);
692     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
693         wv_check_crc(s, crc, crc_extra_bits))
694         return AVERROR_INVALIDDATA;
695
696     return 0;
697 }
698
699 static av_cold int wv_alloc_frame_context(WavpackContext *c)
700 {
701     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
702         return -1;
703
704     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
705     if (!c->fdec[c->fdec_num])
706         return -1;
707     c->fdec_num++;
708     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
709     wv_reset_saved_context(c->fdec[c->fdec_num - 1]);
710
711     return 0;
712 }
713
714 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
715 {
716     WavpackContext *s = avctx->priv_data;
717
718     s->avctx = avctx;
719     if (avctx->bits_per_coded_sample <= 16)
720         avctx->sample_fmt = AV_SAMPLE_FMT_S16P;
721     else
722         avctx->sample_fmt = AV_SAMPLE_FMT_S32P;
723     if (avctx->channels <= 2 && !avctx->channel_layout)
724         avctx->channel_layout = (avctx->channels == 2) ? AV_CH_LAYOUT_STEREO
725                                                        : AV_CH_LAYOUT_MONO;
726
727     s->multichannel = avctx->channels > 2;
728     /* lavf demuxer does not provide extradata, Matroska stores 0x403
729      * there, use this to detect decoding mode for multichannel */
730     s->mkv_mode = 0;
731     if (s->multichannel && avctx->extradata && avctx->extradata_size == 2) {
732         int ver = AV_RL16(avctx->extradata);
733         if (ver >= 0x402 && ver <= 0x410)
734             s->mkv_mode = 1;
735     }
736
737     s->fdec_num = 0;
738
739     return 0;
740 }
741
742 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
743 {
744     WavpackContext *s = avctx->priv_data;
745     int i;
746
747     for (i = 0; i < s->fdec_num; i++)
748         av_freep(&s->fdec[i]);
749     s->fdec_num = 0;
750
751     return 0;
752 }
753
754 static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
755                                 uint8_t **data, int *got_frame_ptr,
756                                 const uint8_t *buf, int buf_size)
757 {
758     WavpackContext *wc = avctx->priv_data;
759     WavpackFrameContext *s;
760     GetByteContext gb;
761     void *samples_l, *samples_r;
762     int ret;
763     int got_terms   = 0, got_weights = 0, got_samples = 0,
764         got_entropy = 0, got_bs      = 0, got_float   = 0, got_hybrid = 0;
765     int i, j, id, size, ssize, weights, t;
766     int bpp, chan, chmask, orig_bpp;
767
768     if (buf_size == 0) {
769         *got_frame_ptr = 0;
770         return 0;
771     }
772
773     if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
774         av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
775         return AVERROR_INVALIDDATA;
776     }
777
778     s = wc->fdec[block_no];
779     if (!s) {
780         av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
781                block_no);
782         return AVERROR_INVALIDDATA;
783     }
784
785     memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
786     memset(s->ch, 0, sizeof(s->ch));
787     s->extra_bits     = 0;
788     s->and            = s->or = s->shift = 0;
789     s->got_extra_bits = 0;
790
791     bytestream2_init(&gb, buf, buf_size);
792
793     if (!wc->mkv_mode) {
794         s->samples = bytestream2_get_le32(&gb);
795         if (s->samples != wc->samples)
796             return AVERROR_INVALIDDATA;
797
798         if (!s->samples) {
799             *got_frame_ptr = 0;
800             return 0;
801         }
802     } else {
803         s->samples = wc->samples;
804     }
805     s->frame_flags = bytestream2_get_le32(&gb);
806     bpp            = av_get_bytes_per_sample(avctx->sample_fmt);
807     orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
808
809     s->stereo         = !(s->frame_flags & WV_MONO);
810     s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
811     s->joint          =   s->frame_flags & WV_JOINT_STEREO;
812     s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
813     s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
814     s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
815     s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
816     s->hybrid_minclip = ((-1LL << (orig_bpp - 1)));
817     s->CRC            = bytestream2_get_le32(&gb);
818
819     samples_l = data[wc->ch_offset];
820     if (s->stereo)
821         samples_r = data[wc->ch_offset + 1];
822
823     if (wc->mkv_mode)
824         bytestream2_skip(&gb, 4);  // skip block size;
825
826     wc->ch_offset += 1 + s->stereo;
827
828     // parse metadata blocks
829     while (bytestream2_get_bytes_left(&gb)) {
830         id   = bytestream2_get_byte(&gb);
831         size = bytestream2_get_byte(&gb);
832         if (id & WP_IDF_LONG) {
833             size |= (bytestream2_get_byte(&gb)) << 8;
834             size |= (bytestream2_get_byte(&gb)) << 16;
835         }
836         size <<= 1; // size is specified in words
837         ssize  = size;
838         if (id & WP_IDF_ODD)
839             size--;
840         if (size < 0) {
841             av_log(avctx, AV_LOG_ERROR,
842                    "Got incorrect block %02X with size %i\n", id, size);
843             break;
844         }
845         if (bytestream2_get_bytes_left(&gb) < ssize) {
846             av_log(avctx, AV_LOG_ERROR,
847                    "Block size %i is out of bounds\n", size);
848             break;
849         }
850         if (id & WP_IDF_IGNORE) {
851             bytestream2_skip(&gb, ssize);
852             continue;
853         }
854         switch (id & WP_IDF_MASK) {
855         case WP_ID_DECTERMS:
856             if (size > MAX_TERMS) {
857                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
858                 s->terms = 0;
859                 bytestream2_skip(&gb, ssize);
860                 continue;
861             }
862             s->terms = size;
863             for (i = 0; i < s->terms; i++) {
864                 uint8_t val = bytestream2_get_byte(&gb);
865                 s->decorr[s->terms - i - 1].value = (val & 0x1F) - 5;
866                 s->decorr[s->terms - i - 1].delta =  val >> 5;
867             }
868             got_terms = 1;
869             break;
870         case WP_ID_DECWEIGHTS:
871             if (!got_terms) {
872                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
873                 continue;
874             }
875             weights = size >> s->stereo_in;
876             if (weights > MAX_TERMS || weights > s->terms) {
877                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
878                 bytestream2_skip(&gb, ssize);
879                 continue;
880             }
881             for (i = 0; i < weights; i++) {
882                 t = (int8_t)bytestream2_get_byte(&gb);
883                 s->decorr[s->terms - i - 1].weightA = t << 3;
884                 if (s->decorr[s->terms - i - 1].weightA > 0)
885                     s->decorr[s->terms - i - 1].weightA +=
886                         (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
887                 if (s->stereo_in) {
888                     t = (int8_t)bytestream2_get_byte(&gb);
889                     s->decorr[s->terms - i - 1].weightB = t << 3;
890                     if (s->decorr[s->terms - i - 1].weightB > 0)
891                         s->decorr[s->terms - i - 1].weightB +=
892                             (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
893                 }
894             }
895             got_weights = 1;
896             break;
897         case WP_ID_DECSAMPLES:
898             if (!got_terms) {
899                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
900                 continue;
901             }
902             t = 0;
903             for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
904                 if (s->decorr[i].value > 8) {
905                     s->decorr[i].samplesA[0] =
906                         wp_exp2(bytestream2_get_le16(&gb));
907                     s->decorr[i].samplesA[1] =
908                         wp_exp2(bytestream2_get_le16(&gb));
909
910                     if (s->stereo_in) {
911                         s->decorr[i].samplesB[0] =
912                             wp_exp2(bytestream2_get_le16(&gb));
913                         s->decorr[i].samplesB[1] =
914                             wp_exp2(bytestream2_get_le16(&gb));
915                         t                       += 4;
916                     }
917                     t += 4;
918                 } else if (s->decorr[i].value < 0) {
919                     s->decorr[i].samplesA[0] =
920                         wp_exp2(bytestream2_get_le16(&gb));
921                     s->decorr[i].samplesB[0] =
922                         wp_exp2(bytestream2_get_le16(&gb));
923                     t                       += 4;
924                 } else {
925                     for (j = 0; j < s->decorr[i].value; j++) {
926                         s->decorr[i].samplesA[j] =
927                             wp_exp2(bytestream2_get_le16(&gb));
928                         if (s->stereo_in) {
929                             s->decorr[i].samplesB[j] =
930                                 wp_exp2(bytestream2_get_le16(&gb));
931                         }
932                     }
933                     t += s->decorr[i].value * 2 * (s->stereo_in + 1);
934                 }
935             }
936             got_samples = 1;
937             break;
938         case WP_ID_ENTROPY:
939             if (size != 6 * (s->stereo_in + 1)) {
940                 av_log(avctx, AV_LOG_ERROR,
941                        "Entropy vars size should be %i, got %i",
942                        6 * (s->stereo_in + 1), size);
943                 bytestream2_skip(&gb, ssize);
944                 continue;
945             }
946             for (j = 0; j <= s->stereo_in; j++)
947                 for (i = 0; i < 3; i++) {
948                     s->ch[j].median[i] = wp_exp2(bytestream2_get_le16(&gb));
949                 }
950             got_entropy = 1;
951             break;
952         case WP_ID_HYBRID:
953             if (s->hybrid_bitrate) {
954                 for (i = 0; i <= s->stereo_in; i++) {
955                     s->ch[i].slow_level = wp_exp2(bytestream2_get_le16(&gb));
956                     size               -= 2;
957                 }
958             }
959             for (i = 0; i < (s->stereo_in + 1); i++) {
960                 s->ch[i].bitrate_acc = bytestream2_get_le16(&gb) << 16;
961                 size                -= 2;
962             }
963             if (size > 0) {
964                 for (i = 0; i < (s->stereo_in + 1); i++) {
965                     s->ch[i].bitrate_delta =
966                         wp_exp2((int16_t)bytestream2_get_le16(&gb));
967                 }
968             } else {
969                 for (i = 0; i < (s->stereo_in + 1); i++)
970                     s->ch[i].bitrate_delta = 0;
971             }
972             got_hybrid = 1;
973             break;
974         case WP_ID_INT32INFO: {
975             uint8_t val[4];
976             if (size != 4) {
977                 av_log(avctx, AV_LOG_ERROR,
978                        "Invalid INT32INFO, size = %i\n",
979                        size);
980                 bytestream2_skip(&gb, ssize - 4);
981                 continue;
982             }
983             bytestream2_get_buffer(&gb, val, 4);
984             if (val[0]) {
985                 s->extra_bits = val[0];
986             } else if (val[1]) {
987                 s->shift = val[1];
988             } else if (val[2]) {
989                 s->and   = s->or = 1;
990                 s->shift = val[2];
991             } else if (val[3]) {
992                 s->and   = 1;
993                 s->shift = val[3];
994             }
995             /* original WavPack decoder forces 32-bit lossy sound to be treated
996              * as 24-bit one in order to have proper clipping */
997             if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
998                 s->post_shift      += 8;
999                 s->shift           -= 8;
1000                 s->hybrid_maxclip >>= 8;
1001                 s->hybrid_minclip >>= 8;
1002             }
1003             break;
1004         }
1005         case WP_ID_FLOATINFO:
1006             if (size != 4) {
1007                 av_log(avctx, AV_LOG_ERROR,
1008                        "Invalid FLOATINFO, size = %i\n", size);
1009                 bytestream2_skip(&gb, ssize);
1010                 continue;
1011             }
1012             s->float_flag    = bytestream2_get_byte(&gb);
1013             s->float_shift   = bytestream2_get_byte(&gb);
1014             s->float_max_exp = bytestream2_get_byte(&gb);
1015             got_float        = 1;
1016             bytestream2_skip(&gb, 1);
1017             break;
1018         case WP_ID_DATA:
1019             s->sc.offset = bytestream2_tell(&gb);
1020             s->sc.size   = size * 8;
1021             init_get_bits(&s->gb, gb.buffer, size * 8);
1022             s->data_size = size * 8;
1023             bytestream2_skip(&gb, size);
1024             got_bs       = 1;
1025             break;
1026         case WP_ID_EXTRABITS:
1027             if (size <= 4) {
1028                 av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
1029                        size);
1030                 bytestream2_skip(&gb, size);
1031                 continue;
1032             }
1033             s->extra_sc.offset = bytestream2_tell(&gb);
1034             s->extra_sc.size   = size * 8;
1035             init_get_bits(&s->gb_extra_bits, gb.buffer, size * 8);
1036             s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
1037             bytestream2_skip(&gb, size);
1038             s->got_extra_bits  = 1;
1039             break;
1040         case WP_ID_CHANINFO:
1041             if (size <= 1) {
1042                 av_log(avctx, AV_LOG_ERROR,
1043                        "Insufficient channel information\n");
1044                 return AVERROR_INVALIDDATA;
1045             }
1046             chan = bytestream2_get_byte(&gb);
1047             switch (size - 2) {
1048             case 0:
1049                 chmask = bytestream2_get_byte(&gb);
1050                 break;
1051             case 1:
1052                 chmask = bytestream2_get_le16(&gb);
1053                 break;
1054             case 2:
1055                 chmask = bytestream2_get_le24(&gb);
1056                 break;
1057             case 3:
1058                 chmask = bytestream2_get_le32(&gb);;
1059                 break;
1060             case 5:
1061                 bytestream2_skip(&gb, 1);
1062                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1063                 chmask = bytestream2_get_le16(&gb);
1064                 break;
1065             default:
1066                 av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
1067                        size);
1068                 chan   = avctx->channels;
1069                 chmask = avctx->channel_layout;
1070             }
1071             if (chan != avctx->channels) {
1072                 av_log(avctx, AV_LOG_ERROR,
1073                        "Block reports total %d channels, "
1074                        "decoder believes it's %d channels\n",
1075                        chan, avctx->channels);
1076                 return AVERROR_INVALIDDATA;
1077             }
1078             if (!avctx->channel_layout)
1079                 avctx->channel_layout = chmask;
1080             break;
1081         default:
1082             bytestream2_skip(&gb, size);
1083         }
1084         if (id & WP_IDF_ODD)
1085             bytestream2_skip(&gb, 1);
1086     }
1087
1088     if (!got_terms) {
1089         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
1090         return AVERROR_INVALIDDATA;
1091     }
1092     if (!got_weights) {
1093         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
1094         return AVERROR_INVALIDDATA;
1095     }
1096     if (!got_samples) {
1097         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
1098         return AVERROR_INVALIDDATA;
1099     }
1100     if (!got_entropy) {
1101         av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
1102         return AVERROR_INVALIDDATA;
1103     }
1104     if (s->hybrid && !got_hybrid) {
1105         av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
1106         return AVERROR_INVALIDDATA;
1107     }
1108     if (!got_bs) {
1109         av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
1110         return AVERROR_INVALIDDATA;
1111     }
1112     if (!got_float && avctx->sample_fmt == AV_SAMPLE_FMT_FLTP) {
1113         av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
1114         return AVERROR_INVALIDDATA;
1115     }
1116     if (s->got_extra_bits && avctx->sample_fmt != AV_SAMPLE_FMT_FLTP) {
1117         const int size   = get_bits_left(&s->gb_extra_bits);
1118         const int wanted = s->samples * s->extra_bits << s->stereo_in;
1119         if (size < wanted) {
1120             av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
1121             s->got_extra_bits = 0;
1122         }
1123     }
1124
1125     if (s->stereo_in) {
1126         ret = wv_unpack_stereo(s, &s->gb, samples_l, samples_r, avctx->sample_fmt);
1127         if (ret < 0)
1128             return ret;
1129     } else {
1130         ret = wv_unpack_mono(s, &s->gb, samples_l, avctx->sample_fmt);
1131         if (ret < 0)
1132             return ret;
1133
1134         if (s->stereo)
1135             memcpy(samples_r, samples_l, bpp * s->samples);
1136     }
1137
1138     *got_frame_ptr = 1;
1139
1140     return 0;
1141 }
1142
1143 static void wavpack_decode_flush(AVCodecContext *avctx)
1144 {
1145     WavpackContext *s = avctx->priv_data;
1146     int i;
1147
1148     for (i = 0; i < s->fdec_num; i++)
1149         wv_reset_saved_context(s->fdec[i]);
1150 }
1151
1152 static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
1153                                 int *got_frame_ptr, AVPacket *avpkt)
1154 {
1155     WavpackContext *s  = avctx->priv_data;
1156     const uint8_t *buf = avpkt->data;
1157     int buf_size       = avpkt->size;
1158     AVFrame *frame     = data;
1159     int frame_size, ret, frame_flags;
1160
1161     if (avpkt->size < 12 + s->multichannel * 4)
1162         return AVERROR_INVALIDDATA;
1163
1164     s->block     = 0;
1165     s->ch_offset = 0;
1166
1167     /* determine number of samples */
1168     if (s->mkv_mode) {
1169         s->samples  = AV_RL32(buf);
1170         buf        += 4;
1171         frame_flags = AV_RL32(buf);
1172     } else {
1173         if (s->multichannel) {
1174             s->samples  = AV_RL32(buf + 4);
1175             frame_flags = AV_RL32(buf + 8);
1176         } else {
1177             s->samples  = AV_RL32(buf);
1178             frame_flags = AV_RL32(buf + 4);
1179         }
1180     }
1181     if (s->samples <= 0) {
1182         av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1183                s->samples);
1184         return AVERROR_INVALIDDATA;
1185     }
1186
1187     if (frame_flags & 0x80) {
1188         avctx->sample_fmt = AV_SAMPLE_FMT_FLTP;
1189     } else if ((frame_flags & 0x03) <= 1) {
1190         avctx->sample_fmt = AV_SAMPLE_FMT_S16P;
1191     } else {
1192         avctx->sample_fmt          = AV_SAMPLE_FMT_S32P;
1193         avctx->bits_per_raw_sample = ((frame_flags & 0x03) + 1) << 3;
1194     }
1195
1196     /* get output buffer */
1197     frame->nb_samples = s->samples;
1198     if ((ret = ff_get_buffer(avctx, frame, 0)) < 0) {
1199         av_log(avctx, AV_LOG_ERROR, "get_buffer() failed\n");
1200         return ret;
1201     }
1202
1203     while (buf_size > 0) {
1204         if (!s->multichannel) {
1205             frame_size = buf_size;
1206         } else {
1207             if (!s->mkv_mode) {
1208                 frame_size = AV_RL32(buf) - 12;
1209                 buf       += 4;
1210                 buf_size  -= 4;
1211             } else {
1212                 if (buf_size < 12) // MKV files can have zero flags after last block
1213                     break;
1214                 frame_size = AV_RL32(buf + 8) + 12;
1215             }
1216         }
1217         if (frame_size < 0 || frame_size > buf_size) {
1218             av_log(avctx, AV_LOG_ERROR,
1219                    "Block %d has invalid size (size %d vs. %d bytes left)\n",
1220                    s->block, frame_size, buf_size);
1221             wavpack_decode_flush(avctx);
1222             return AVERROR_INVALIDDATA;
1223         }
1224         if ((ret = wavpack_decode_block(avctx, s->block,
1225                                         frame->extended_data, got_frame_ptr,
1226                                         buf, frame_size)) < 0) {
1227             wavpack_decode_flush(avctx);
1228             return ret;
1229         }
1230         s->block++;
1231         buf      += frame_size;
1232         buf_size -= frame_size;
1233     }
1234
1235     return avpkt->size;
1236 }
1237
1238 AVCodec ff_wavpack_decoder = {
1239     .name           = "wavpack",
1240     .type           = AVMEDIA_TYPE_AUDIO,
1241     .id             = AV_CODEC_ID_WAVPACK,
1242     .priv_data_size = sizeof(WavpackContext),
1243     .init           = wavpack_decode_init,
1244     .close          = wavpack_decode_end,
1245     .decode         = wavpack_decode_frame,
1246     .flush          = wavpack_decode_flush,
1247     .capabilities   = CODEC_CAP_DR1,
1248     .long_name      = NULL_IF_CONFIG_SMALL("WavPack"),
1249 };