]> git.sesse.net Git - ffmpeg/blob - libavcodec/wavpack.c
e6e98724911c6279e3ca0f509e860e85c568b8a2
[ffmpeg] / libavcodec / wavpack.c
1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  *
5  * This file is part of Libav.
6  *
7  * Libav is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * Libav is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with Libav; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  */
21
22 #define BITSTREAM_READER_LE
23
24 #include "libavutil/channel_layout.h"
25 #include "avcodec.h"
26 #include "get_bits.h"
27 #include "internal.h"
28 #include "unary.h"
29 #include "bytestream.h"
30
31 /**
32  * @file
33  * WavPack lossless audio decoder
34  */
35
36 #define WV_MONO           0x00000004
37 #define WV_JOINT_STEREO   0x00000010
38 #define WV_FALSE_STEREO   0x40000000
39
40 #define WV_HYBRID_MODE    0x00000008
41 #define WV_HYBRID_SHAPE   0x00000008
42 #define WV_HYBRID_BITRATE 0x00000200
43 #define WV_HYBRID_BALANCE 0x00000400
44
45 #define WV_FLT_SHIFT_ONES 0x01
46 #define WV_FLT_SHIFT_SAME 0x02
47 #define WV_FLT_SHIFT_SENT 0x04
48 #define WV_FLT_ZERO_SENT  0x08
49 #define WV_FLT_ZERO_SIGN  0x10
50
51 enum WP_ID_Flags {
52     WP_IDF_MASK   = 0x1F,
53     WP_IDF_IGNORE = 0x20,
54     WP_IDF_ODD    = 0x40,
55     WP_IDF_LONG   = 0x80
56 };
57
58 enum WP_ID {
59     WP_ID_DUMMY = 0,
60     WP_ID_ENCINFO,
61     WP_ID_DECTERMS,
62     WP_ID_DECWEIGHTS,
63     WP_ID_DECSAMPLES,
64     WP_ID_ENTROPY,
65     WP_ID_HYBRID,
66     WP_ID_SHAPING,
67     WP_ID_FLOATINFO,
68     WP_ID_INT32INFO,
69     WP_ID_DATA,
70     WP_ID_CORR,
71     WP_ID_EXTRABITS,
72     WP_ID_CHANINFO
73 };
74
75 typedef struct SavedContext {
76     int offset;
77     int size;
78     int bits_used;
79     uint32_t crc;
80 } SavedContext;
81
82 #define MAX_TERMS 16
83
84 typedef struct Decorr {
85     int delta;
86     int value;
87     int weightA;
88     int weightB;
89     int samplesA[8];
90     int samplesB[8];
91 } Decorr;
92
93 typedef struct WvChannel {
94     int median[3];
95     int slow_level, error_limit;
96     int bitrate_acc, bitrate_delta;
97 } WvChannel;
98
99 typedef struct WavpackFrameContext {
100     AVCodecContext *avctx;
101     int frame_flags;
102     int stereo, stereo_in;
103     int joint;
104     uint32_t CRC;
105     GetBitContext gb;
106     int got_extra_bits;
107     uint32_t crc_extra_bits;
108     GetBitContext gb_extra_bits;
109     int data_size; // in bits
110     int samples;
111     int terms;
112     Decorr decorr[MAX_TERMS];
113     int zero, one, zeroes;
114     int extra_bits;
115     int and, or, shift;
116     int post_shift;
117     int hybrid, hybrid_bitrate;
118     int hybrid_maxclip, hybrid_minclip;
119     int float_flag;
120     int float_shift;
121     int float_max_exp;
122     WvChannel ch[2];
123     int pos;
124     SavedContext sc, extra_sc;
125 } WavpackFrameContext;
126
127 #define WV_MAX_FRAME_DECODERS 14
128
129 typedef struct WavpackContext {
130     AVCodecContext *avctx;
131
132     WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
133     int fdec_num;
134
135     int multichannel;
136     int mkv_mode;
137     int block;
138     int samples;
139     int ch_offset;
140 } WavpackContext;
141
142 // exponent table copied from WavPack source
143 static const uint8_t wp_exp2_table[256] = {
144     0x00, 0x01, 0x01, 0x02, 0x03, 0x03, 0x04, 0x05, 0x06, 0x06, 0x07, 0x08, 0x08, 0x09, 0x0a, 0x0b,
145     0x0b, 0x0c, 0x0d, 0x0e, 0x0e, 0x0f, 0x10, 0x10, 0x11, 0x12, 0x13, 0x13, 0x14, 0x15, 0x16, 0x16,
146     0x17, 0x18, 0x19, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1d, 0x1e, 0x1f, 0x20, 0x20, 0x21, 0x22, 0x23,
147     0x24, 0x24, 0x25, 0x26, 0x27, 0x28, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
148     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3a, 0x3b, 0x3c, 0x3d,
149     0x3e, 0x3f, 0x40, 0x41, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x48, 0x49, 0x4a, 0x4b,
150     0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a,
151     0x5b, 0x5c, 0x5d, 0x5e, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69,
152     0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
153     0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x87, 0x88, 0x89, 0x8a,
154     0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b,
155     0x9c, 0x9d, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad,
156     0xaf, 0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbc, 0xbd, 0xbe, 0xbf, 0xc0,
157     0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc8, 0xc9, 0xca, 0xcb, 0xcd, 0xce, 0xcf, 0xd0, 0xd2, 0xd3, 0xd4,
158     0xd6, 0xd7, 0xd8, 0xd9, 0xdb, 0xdc, 0xdd, 0xde, 0xe0, 0xe1, 0xe2, 0xe4, 0xe5, 0xe6, 0xe8, 0xe9,
159     0xea, 0xec, 0xed, 0xee, 0xf0, 0xf1, 0xf2, 0xf4, 0xf5, 0xf6, 0xf8, 0xf9, 0xfa, 0xfc, 0xfd, 0xff
160 };
161
162 static const uint8_t wp_log2_table [] = {
163     0x00, 0x01, 0x03, 0x04, 0x06, 0x07, 0x09, 0x0a, 0x0b, 0x0d, 0x0e, 0x10, 0x11, 0x12, 0x14, 0x15,
164     0x16, 0x18, 0x19, 0x1a, 0x1c, 0x1d, 0x1e, 0x20, 0x21, 0x22, 0x24, 0x25, 0x26, 0x28, 0x29, 0x2a,
165     0x2c, 0x2d, 0x2e, 0x2f, 0x31, 0x32, 0x33, 0x34, 0x36, 0x37, 0x38, 0x39, 0x3b, 0x3c, 0x3d, 0x3e,
166     0x3f, 0x41, 0x42, 0x43, 0x44, 0x45, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4d, 0x4e, 0x4f, 0x50, 0x51,
167     0x52, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63,
168     0x64, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x74, 0x75,
169     0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85,
170     0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95,
171     0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4,
172     0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, 0xb0, 0xb1, 0xb2, 0xb2,
173     0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, 0xc0, 0xc0,
174     0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcb, 0xcc, 0xcd, 0xce,
175     0xcf, 0xd0, 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd8, 0xd9, 0xda, 0xdb,
176     0xdc, 0xdc, 0xdd, 0xde, 0xdf, 0xe0, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe4, 0xe5, 0xe6, 0xe7, 0xe7,
177     0xe8, 0xe9, 0xea, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xee, 0xef, 0xf0, 0xf1, 0xf1, 0xf2, 0xf3, 0xf4,
178     0xf4, 0xf5, 0xf6, 0xf7, 0xf7, 0xf8, 0xf9, 0xf9, 0xfa, 0xfb, 0xfc, 0xfc, 0xfd, 0xfe, 0xff, 0xff
179 };
180
181 static av_always_inline int wp_exp2(int16_t val)
182 {
183     int res, neg = 0;
184
185     if (val < 0) {
186         val = -val;
187         neg = 1;
188     }
189
190     res   = wp_exp2_table[val & 0xFF] | 0x100;
191     val >>= 8;
192     res   = (val > 9) ? (res << (val - 9)) : (res >> (9 - val));
193     return neg ? -res : res;
194 }
195
196 static av_always_inline int wp_log2(int32_t val)
197 {
198     int bits;
199
200     if (!val)
201         return 0;
202     if (val == 1)
203         return 256;
204     val += val >> 9;
205     bits = av_log2(val) + 1;
206     if (bits < 9)
207         return (bits << 8) + wp_log2_table[(val << (9 - bits)) & 0xFF];
208     else
209         return (bits << 8) + wp_log2_table[(val >> (bits - 9)) & 0xFF];
210 }
211
212 #define LEVEL_DECAY(a)  ((a + 0x80) >> 8)
213
214 // macros for manipulating median values
215 #define GET_MED(n) ((c->median[n] >> 4) + 1)
216 #define DEC_MED(n) c->median[n] -= ((c->median[n] + (128 >> n) - 2) / (128 >> n)) * 2
217 #define INC_MED(n) c->median[n] += ((c->median[n] + (128 >> n)    ) / (128 >> n)) * 5
218
219 // macros for applying weight
220 #define UPDATE_WEIGHT_CLIP(weight, delta, samples, in) \
221     if (samples && in) { \
222         if ((samples ^ in) < 0) { \
223             weight -= delta; \
224             if (weight < -1024) \
225                 weight = -1024; \
226         } else { \
227             weight += delta; \
228             if (weight > 1024) \
229                 weight = 1024; \
230         } \
231     }
232
233 static av_always_inline int get_tail(GetBitContext *gb, int k)
234 {
235     int p, e, res;
236
237     if (k < 1)
238         return 0;
239     p   = av_log2(k);
240     e   = (1 << (p + 1)) - k - 1;
241     res = p ? get_bits(gb, p) : 0;
242     if (res >= e)
243         res = (res << 1) - e + get_bits1(gb);
244     return res;
245 }
246
247 static void update_error_limit(WavpackFrameContext *ctx)
248 {
249     int i, br[2], sl[2];
250
251     for (i = 0; i <= ctx->stereo_in; i++) {
252         ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
253         br[i]                   = ctx->ch[i].bitrate_acc >> 16;
254         sl[i]                   = LEVEL_DECAY(ctx->ch[i].slow_level);
255     }
256     if (ctx->stereo_in && ctx->hybrid_bitrate) {
257         int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
258         if (balance > br[0]) {
259             br[1] = br[0] << 1;
260             br[0] = 0;
261         } else if (-balance > br[0]) {
262             br[0] <<= 1;
263             br[1]   = 0;
264         } else {
265             br[1] = br[0] + balance;
266             br[0] = br[0] - balance;
267         }
268     }
269     for (i = 0; i <= ctx->stereo_in; i++) {
270         if (ctx->hybrid_bitrate) {
271             if (sl[i] - br[i] > -0x100)
272                 ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
273             else
274                 ctx->ch[i].error_limit = 0;
275         } else {
276             ctx->ch[i].error_limit = wp_exp2(br[i]);
277         }
278     }
279 }
280
281 static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
282                         int channel, int *last)
283 {
284     int t, t2;
285     int sign, base, add, ret;
286     WvChannel *c = &ctx->ch[channel];
287
288     *last = 0;
289
290     if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
291         !ctx->zero && !ctx->one) {
292         if (ctx->zeroes) {
293             ctx->zeroes--;
294             if (ctx->zeroes) {
295                 c->slow_level -= LEVEL_DECAY(c->slow_level);
296                 return 0;
297             }
298         } else {
299             t = get_unary_0_33(gb);
300             if (t >= 2) {
301                 if (get_bits_left(gb) < t - 1)
302                     goto error;
303                 t = get_bits(gb, t - 1) | (1 << (t - 1));
304             } else {
305                 if (get_bits_left(gb) < 0)
306                     goto error;
307             }
308             ctx->zeroes = t;
309             if (ctx->zeroes) {
310                 memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
311                 memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
312                 c->slow_level -= LEVEL_DECAY(c->slow_level);
313                 return 0;
314             }
315         }
316     }
317
318     if (ctx->zero) {
319         t         = 0;
320         ctx->zero = 0;
321     } else {
322         t = get_unary_0_33(gb);
323         if (get_bits_left(gb) < 0)
324             goto error;
325         if (t == 16) {
326             t2 = get_unary_0_33(gb);
327             if (t2 < 2) {
328                 if (get_bits_left(gb) < 0)
329                     goto error;
330                 t += t2;
331             } else {
332                 if (get_bits_left(gb) < t2 - 1)
333                     goto error;
334                 t += get_bits(gb, t2 - 1) | (1 << (t2 - 1));
335             }
336         }
337
338         if (ctx->one) {
339             ctx->one = t & 1;
340             t        = (t >> 1) + 1;
341         } else {
342             ctx->one = t & 1;
343             t      >>= 1;
344         }
345         ctx->zero = !ctx->one;
346     }
347
348     if (ctx->hybrid && !channel)
349         update_error_limit(ctx);
350
351     if (!t) {
352         base = 0;
353         add  = GET_MED(0) - 1;
354         DEC_MED(0);
355     } else if (t == 1) {
356         base = GET_MED(0);
357         add  = GET_MED(1) - 1;
358         INC_MED(0);
359         DEC_MED(1);
360     } else if (t == 2) {
361         base = GET_MED(0) + GET_MED(1);
362         add  = GET_MED(2) - 1;
363         INC_MED(0);
364         INC_MED(1);
365         DEC_MED(2);
366     } else {
367         base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2);
368         add  = GET_MED(2) - 1;
369         INC_MED(0);
370         INC_MED(1);
371         INC_MED(2);
372     }
373     if (!c->error_limit) {
374         ret = base + get_tail(gb, add);
375         if (get_bits_left(gb) <= 0)
376             goto error;
377     } else {
378         int mid = (base * 2 + add + 1) >> 1;
379         while (add > c->error_limit) {
380             if (get_bits_left(gb) <= 0)
381                 goto error;
382             if (get_bits1(gb)) {
383                 add -= (mid - base);
384                 base = mid;
385             } else
386                 add = mid - base - 1;
387             mid = (base * 2 + add + 1) >> 1;
388         }
389         ret = mid;
390     }
391     sign = get_bits1(gb);
392     if (ctx->hybrid_bitrate)
393         c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
394     return sign ? ~ret : ret;
395
396 error:
397     *last = 1;
398     return 0;
399 }
400
401 static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
402                                        int S)
403 {
404     int bit;
405
406     if (s->extra_bits) {
407         S <<= s->extra_bits;
408
409         if (s->got_extra_bits &&
410             get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
411             S   |= get_bits(&s->gb_extra_bits, s->extra_bits);
412             *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
413         }
414     }
415
416     bit = (S & s->and) | s->or;
417     bit = ((S + bit) << s->shift) - bit;
418
419     if (s->hybrid)
420         bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
421
422     return bit << s->post_shift;
423 }
424
425 static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
426 {
427     union {
428         float    f;
429         uint32_t u;
430     } value;
431
432     unsigned int sign;
433     int exp = s->float_max_exp;
434
435     if (s->got_extra_bits) {
436         const int max_bits  = 1 + 23 + 8 + 1;
437         const int left_bits = get_bits_left(&s->gb_extra_bits);
438
439         if (left_bits + 8 * FF_INPUT_BUFFER_PADDING_SIZE < max_bits)
440             return 0.0;
441     }
442
443     if (S) {
444         S  <<= s->float_shift;
445         sign = S < 0;
446         if (sign)
447             S = -S;
448         if (S >= 0x1000000) {
449             if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
450                 S = get_bits(&s->gb_extra_bits, 23);
451             else
452                 S = 0;
453             exp = 255;
454         } else if (exp) {
455             int shift = 23 - av_log2(S);
456             exp = s->float_max_exp;
457             if (exp <= shift)
458                 shift = --exp;
459             exp -= shift;
460
461             if (shift) {
462                 S <<= shift;
463                 if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
464                     (s->got_extra_bits &&
465                      (s->float_flag & WV_FLT_SHIFT_SAME) &&
466                      get_bits1(&s->gb_extra_bits))) {
467                     S |= (1 << shift) - 1;
468                 } else if (s->got_extra_bits &&
469                            (s->float_flag & WV_FLT_SHIFT_SENT)) {
470                     S |= get_bits(&s->gb_extra_bits, shift);
471                 }
472             }
473         } else {
474             exp = s->float_max_exp;
475         }
476         S &= 0x7fffff;
477     } else {
478         sign = 0;
479         exp  = 0;
480         if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
481             if (get_bits1(&s->gb_extra_bits)) {
482                 S = get_bits(&s->gb_extra_bits, 23);
483                 if (s->float_max_exp >= 25)
484                     exp = get_bits(&s->gb_extra_bits, 8);
485                 sign = get_bits1(&s->gb_extra_bits);
486             } else {
487                 if (s->float_flag & WV_FLT_ZERO_SIGN)
488                     sign = get_bits1(&s->gb_extra_bits);
489             }
490         }
491     }
492
493     *crc = *crc * 27 + S * 9 + exp * 3 + sign;
494
495     value.u = (sign << 31) | (exp << 23) | S;
496     return value.f;
497 }
498
499 static void wv_reset_saved_context(WavpackFrameContext *s)
500 {
501     s->pos    = 0;
502     s->sc.crc = s->extra_sc.crc = 0xFFFFFFFF;
503 }
504
505 static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
506                                uint32_t crc_extra_bits)
507 {
508     if (crc != s->CRC) {
509         av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
510         return AVERROR_INVALIDDATA;
511     }
512     if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
513         av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
514         return AVERROR_INVALIDDATA;
515     }
516
517     return 0;
518 }
519
520 static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
521                                    void *dst, const int type)
522 {
523     int i, j, count = 0;
524     int last, t;
525     int A, B, L, L2, R, R2;
526     int pos                 = s->pos;
527     uint32_t crc            = s->sc.crc;
528     uint32_t crc_extra_bits = s->extra_sc.crc;
529     int16_t *dst16          = dst;
530     int32_t *dst32          = dst;
531     float *dstfl            = dst;
532     const int channel_pad   = s->avctx->channels - 2;
533
534     s->one = s->zero = s->zeroes = 0;
535     do {
536         L = wv_get_value(s, gb, 0, &last);
537         if (last)
538             break;
539         R = wv_get_value(s, gb, 1, &last);
540         if (last)
541             break;
542         for (i = 0; i < s->terms; i++) {
543             t = s->decorr[i].value;
544             if (t > 0) {
545                 if (t > 8) {
546                     if (t & 1) {
547                         A = 2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
548                         B = 2 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
549                     } else {
550                         A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
551                         B = (3 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
552                     }
553                     s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
554                     s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
555                     j                        = 0;
556                 } else {
557                     A = s->decorr[i].samplesA[pos];
558                     B = s->decorr[i].samplesB[pos];
559                     j = (pos + t) & 7;
560                 }
561                 if (type != AV_SAMPLE_FMT_S16) {
562                     L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
563                     R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
564                 } else {
565                     L2 = L + ((s->decorr[i].weightA * A + 512) >> 10);
566                     R2 = R + ((s->decorr[i].weightB * B + 512) >> 10);
567                 }
568                 if (A && L)
569                     s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
570                 if (B && R)
571                     s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
572                 s->decorr[i].samplesA[j] = L = L2;
573                 s->decorr[i].samplesB[j] = R = R2;
574             } else if (t == -1) {
575                 if (type != AV_SAMPLE_FMT_S16)
576                     L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
577                 else
578                     L2 = L + ((s->decorr[i].weightA * s->decorr[i].samplesA[0] + 512) >> 10);
579                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
580                 L = L2;
581                 if (type != AV_SAMPLE_FMT_S16)
582                     R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
583                 else
584                     R2 = R + ((s->decorr[i].weightB * L2 + 512) >> 10);
585                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
586                 R                        = R2;
587                 s->decorr[i].samplesA[0] = R;
588             } else {
589                 if (type != AV_SAMPLE_FMT_S16)
590                     R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
591                 else
592                     R2 = R + ((s->decorr[i].weightB * s->decorr[i].samplesB[0] + 512) >> 10);
593                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
594                 R = R2;
595
596                 if (t == -3) {
597                     R2                       = s->decorr[i].samplesA[0];
598                     s->decorr[i].samplesA[0] = R;
599                 }
600
601                 if (type != AV_SAMPLE_FMT_S16)
602                     L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
603                 else
604                     L2 = L + ((s->decorr[i].weightA * R2 + 512) >> 10);
605                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
606                 L                        = L2;
607                 s->decorr[i].samplesB[0] = L;
608             }
609         }
610         pos = (pos + 1) & 7;
611         if (s->joint)
612             L += (R -= (L >> 1));
613         crc = (crc * 3 + L) * 3 + R;
614
615         if (type == AV_SAMPLE_FMT_FLT) {
616             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, L);
617             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, R);
618             dstfl   += channel_pad;
619         } else if (type == AV_SAMPLE_FMT_S32) {
620             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, L);
621             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, R);
622             dst32   += channel_pad;
623         } else {
624             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, L);
625             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, R);
626             dst16   += channel_pad;
627         }
628         count++;
629     } while (!last && count < s->samples);
630
631     wv_reset_saved_context(s);
632     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
633         wv_check_crc(s, crc, crc_extra_bits))
634         return AVERROR_INVALIDDATA;
635
636     return count * 2;
637 }
638
639 static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
640                                  void *dst, const int type)
641 {
642     int i, j, count = 0;
643     int last, t;
644     int A, S, T;
645     int pos                  = s->pos;
646     uint32_t crc             = s->sc.crc;
647     uint32_t crc_extra_bits  = s->extra_sc.crc;
648     int16_t *dst16           = dst;
649     int32_t *dst32           = dst;
650     float *dstfl             = dst;
651     const int channel_stride = s->avctx->channels;
652
653     s->one = s->zero = s->zeroes = 0;
654     do {
655         T = wv_get_value(s, gb, 0, &last);
656         S = 0;
657         if (last)
658             break;
659         for (i = 0; i < s->terms; i++) {
660             t = s->decorr[i].value;
661             if (t > 8) {
662                 if (t & 1)
663                     A =  2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
664                 else
665                     A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
666                 s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
667                 j                        = 0;
668             } else {
669                 A = s->decorr[i].samplesA[pos];
670                 j = (pos + t) & 7;
671             }
672             if (type != AV_SAMPLE_FMT_S16)
673                 S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
674             else
675                 S = T + ((s->decorr[i].weightA * A + 512) >> 10);
676             if (A && T)
677                 s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
678             s->decorr[i].samplesA[j] = T = S;
679         }
680         pos = (pos + 1) & 7;
681         crc = crc * 3 + S;
682
683         if (type == AV_SAMPLE_FMT_FLT) {
684             *dstfl = wv_get_value_float(s, &crc_extra_bits, S);
685             dstfl += channel_stride;
686         } else if (type == AV_SAMPLE_FMT_S32) {
687             *dst32 = wv_get_value_integer(s, &crc_extra_bits, S);
688             dst32 += channel_stride;
689         } else {
690             *dst16 = wv_get_value_integer(s, &crc_extra_bits, S);
691             dst16 += channel_stride;
692         }
693         count++;
694     } while (!last && count < s->samples);
695
696     wv_reset_saved_context(s);
697     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
698         wv_check_crc(s, crc, crc_extra_bits))
699         return AVERROR_INVALIDDATA;
700
701     return count;
702 }
703
704 static av_cold int wv_alloc_frame_context(WavpackContext *c)
705 {
706     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
707         return -1;
708
709     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
710     if (!c->fdec[c->fdec_num])
711         return -1;
712     c->fdec_num++;
713     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
714     wv_reset_saved_context(c->fdec[c->fdec_num - 1]);
715
716     return 0;
717 }
718
719 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
720 {
721     WavpackContext *s = avctx->priv_data;
722
723     s->avctx = avctx;
724     if (avctx->bits_per_coded_sample <= 16)
725         avctx->sample_fmt = AV_SAMPLE_FMT_S16;
726     else
727         avctx->sample_fmt = AV_SAMPLE_FMT_S32;
728     if (avctx->channels <= 2 && !avctx->channel_layout)
729         avctx->channel_layout = (avctx->channels == 2) ? AV_CH_LAYOUT_STEREO
730                                                        : AV_CH_LAYOUT_MONO;
731
732     s->multichannel = avctx->channels > 2;
733     /* lavf demuxer does not provide extradata, Matroska stores 0x403
734      * there, use this to detect decoding mode for multichannel */
735     s->mkv_mode = 0;
736     if (s->multichannel && avctx->extradata && avctx->extradata_size == 2) {
737         int ver = AV_RL16(avctx->extradata);
738         if (ver >= 0x402 && ver <= 0x410)
739             s->mkv_mode = 1;
740     }
741
742     s->fdec_num = 0;
743
744     return 0;
745 }
746
747 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
748 {
749     WavpackContext *s = avctx->priv_data;
750     int i;
751
752     for (i = 0; i < s->fdec_num; i++)
753         av_freep(&s->fdec[i]);
754     s->fdec_num = 0;
755
756     return 0;
757 }
758
759 static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
760                                 void *data, int *got_frame_ptr,
761                                 const uint8_t *buf, int buf_size)
762 {
763     WavpackContext *wc = avctx->priv_data;
764     WavpackFrameContext *s;
765     GetByteContext gb;
766     void *samples = data;
767     int samplecount;
768     int got_terms   = 0, got_weights = 0, got_samples = 0,
769         got_entropy = 0, got_bs      = 0, got_float   = 0, got_hybrid = 0;
770     int i, j, id, size, ssize, weights, t;
771     int bpp, chan, chmask, orig_bpp;
772
773     if (buf_size == 0) {
774         *got_frame_ptr = 0;
775         return 0;
776     }
777
778     if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
779         av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
780         return AVERROR_INVALIDDATA;
781     }
782
783     s = wc->fdec[block_no];
784     if (!s) {
785         av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
786                block_no);
787         return AVERROR_INVALIDDATA;
788     }
789
790     memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
791     memset(s->ch, 0, sizeof(s->ch));
792     s->extra_bits     = 0;
793     s->and            = s->or = s->shift = 0;
794     s->got_extra_bits = 0;
795
796     bytestream2_init(&gb, buf, buf_size);
797
798     if (!wc->mkv_mode) {
799         s->samples = bytestream2_get_le32(&gb);
800         if (s->samples != wc->samples)
801             return AVERROR_INVALIDDATA;
802
803         if (!s->samples) {
804             *got_frame_ptr = 0;
805             return 0;
806         }
807     } else {
808         s->samples = wc->samples;
809     }
810     s->frame_flags = bytestream2_get_le32(&gb);
811     bpp            = av_get_bytes_per_sample(avctx->sample_fmt);
812     samples        = (uint8_t *)samples + bpp * wc->ch_offset;
813     orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
814
815     s->stereo         = !(s->frame_flags & WV_MONO);
816     s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
817     s->joint          =   s->frame_flags & WV_JOINT_STEREO;
818     s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
819     s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
820     s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
821     s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
822     s->hybrid_minclip = ((-1LL << (orig_bpp - 1)));
823     s->CRC            = bytestream2_get_le32(&gb);
824
825     if (wc->mkv_mode)
826         bytestream2_skip(&gb, 4);  // skip block size;
827
828     wc->ch_offset += 1 + s->stereo;
829
830     // parse metadata blocks
831     while (bytestream2_get_bytes_left(&gb)) {
832         id   = bytestream2_get_byte(&gb);
833         size = bytestream2_get_byte(&gb);
834         if (id & WP_IDF_LONG) {
835             size |= (bytestream2_get_byte(&gb)) << 8;
836             size |= (bytestream2_get_byte(&gb)) << 16;
837         }
838         size <<= 1; // size is specified in words
839         ssize  = size;
840         if (id & WP_IDF_ODD)
841             size--;
842         if (size < 0) {
843             av_log(avctx, AV_LOG_ERROR,
844                    "Got incorrect block %02X with size %i\n", id, size);
845             break;
846         }
847         if (bytestream2_get_bytes_left(&gb) < ssize) {
848             av_log(avctx, AV_LOG_ERROR,
849                    "Block size %i is out of bounds\n", size);
850             break;
851         }
852         if (id & WP_IDF_IGNORE) {
853             bytestream2_skip(&gb, ssize);
854             continue;
855         }
856         switch (id & WP_IDF_MASK) {
857         case WP_ID_DECTERMS:
858             if (size > MAX_TERMS) {
859                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
860                 s->terms = 0;
861                 bytestream2_skip(&gb, ssize);
862                 continue;
863             }
864             s->terms = size;
865             for (i = 0; i < s->terms; i++) {
866                 uint8_t val = bytestream2_get_byte(&gb);
867                 s->decorr[s->terms - i - 1].value = (val & 0x1F) - 5;
868                 s->decorr[s->terms - i - 1].delta =  val >> 5;
869             }
870             got_terms = 1;
871             break;
872         case WP_ID_DECWEIGHTS:
873             if (!got_terms) {
874                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
875                 continue;
876             }
877             weights = size >> s->stereo_in;
878             if (weights > MAX_TERMS || weights > s->terms) {
879                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
880                 bytestream2_skip(&gb, ssize);
881                 continue;
882             }
883             for (i = 0; i < weights; i++) {
884                 t = (int8_t)bytestream2_get_byte(&gb);
885                 s->decorr[s->terms - i - 1].weightA = t << 3;
886                 if (s->decorr[s->terms - i - 1].weightA > 0)
887                     s->decorr[s->terms - i - 1].weightA +=
888                         (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
889                 if (s->stereo_in) {
890                     t = (int8_t)bytestream2_get_byte(&gb);
891                     s->decorr[s->terms - i - 1].weightB = t << 3;
892                     if (s->decorr[s->terms - i - 1].weightB > 0)
893                         s->decorr[s->terms - i - 1].weightB +=
894                             (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
895                 }
896             }
897             got_weights = 1;
898             break;
899         case WP_ID_DECSAMPLES:
900             if (!got_terms) {
901                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
902                 continue;
903             }
904             t = 0;
905             for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
906                 if (s->decorr[i].value > 8) {
907                     s->decorr[i].samplesA[0] =
908                         wp_exp2(bytestream2_get_le16(&gb));
909                     s->decorr[i].samplesA[1] =
910                         wp_exp2(bytestream2_get_le16(&gb));
911
912                     if (s->stereo_in) {
913                         s->decorr[i].samplesB[0] =
914                             wp_exp2(bytestream2_get_le16(&gb));
915                         s->decorr[i].samplesB[1] =
916                             wp_exp2(bytestream2_get_le16(&gb));
917                         t                       += 4;
918                     }
919                     t += 4;
920                 } else if (s->decorr[i].value < 0) {
921                     s->decorr[i].samplesA[0] =
922                         wp_exp2(bytestream2_get_le16(&gb));
923                     s->decorr[i].samplesB[0] =
924                         wp_exp2(bytestream2_get_le16(&gb));
925                     t                       += 4;
926                 } else {
927                     for (j = 0; j < s->decorr[i].value; j++) {
928                         s->decorr[i].samplesA[j] =
929                             wp_exp2(bytestream2_get_le16(&gb));
930                         if (s->stereo_in) {
931                             s->decorr[i].samplesB[j] =
932                                 wp_exp2(bytestream2_get_le16(&gb));
933                         }
934                     }
935                     t += s->decorr[i].value * 2 * (s->stereo_in + 1);
936                 }
937             }
938             got_samples = 1;
939             break;
940         case WP_ID_ENTROPY:
941             if (size != 6 * (s->stereo_in + 1)) {
942                 av_log(avctx, AV_LOG_ERROR,
943                        "Entropy vars size should be %i, got %i",
944                        6 * (s->stereo_in + 1), size);
945                 bytestream2_skip(&gb, ssize);
946                 continue;
947             }
948             for (j = 0; j <= s->stereo_in; j++)
949                 for (i = 0; i < 3; i++) {
950                     s->ch[j].median[i] = wp_exp2(bytestream2_get_le16(&gb));
951                 }
952             got_entropy = 1;
953             break;
954         case WP_ID_HYBRID:
955             if (s->hybrid_bitrate) {
956                 for (i = 0; i <= s->stereo_in; i++) {
957                     s->ch[i].slow_level = wp_exp2(bytestream2_get_le16(&gb));
958                     size               -= 2;
959                 }
960             }
961             for (i = 0; i < (s->stereo_in + 1); i++) {
962                 s->ch[i].bitrate_acc = bytestream2_get_le16(&gb) << 16;
963                 size                -= 2;
964             }
965             if (size > 0) {
966                 for (i = 0; i < (s->stereo_in + 1); i++) {
967                     s->ch[i].bitrate_delta =
968                         wp_exp2((int16_t)bytestream2_get_le16(&gb));
969                 }
970             } else {
971                 for (i = 0; i < (s->stereo_in + 1); i++)
972                     s->ch[i].bitrate_delta = 0;
973             }
974             got_hybrid = 1;
975             break;
976         case WP_ID_INT32INFO: {
977             uint8_t val[4];
978             if (size != 4) {
979                 av_log(avctx, AV_LOG_ERROR,
980                        "Invalid INT32INFO, size = %i\n",
981                        size);
982                 bytestream2_skip(&gb, ssize - 4);
983                 continue;
984             }
985             bytestream2_get_buffer(&gb, val, 4);
986             if (val[0]) {
987                 s->extra_bits = val[0];
988             } else if (val[1]) {
989                 s->shift = val[1];
990             } else if (val[2]) {
991                 s->and   = s->or = 1;
992                 s->shift = val[2];
993             } else if (val[3]) {
994                 s->and   = 1;
995                 s->shift = val[3];
996             }
997             /* original WavPack decoder forces 32-bit lossy sound to be treated
998              * as 24-bit one in order to have proper clipping */
999             if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
1000                 s->post_shift      += 8;
1001                 s->shift           -= 8;
1002                 s->hybrid_maxclip >>= 8;
1003                 s->hybrid_minclip >>= 8;
1004             }
1005             break;
1006         }
1007         case WP_ID_FLOATINFO:
1008             if (size != 4) {
1009                 av_log(avctx, AV_LOG_ERROR,
1010                        "Invalid FLOATINFO, size = %i\n", size);
1011                 bytestream2_skip(&gb, ssize);
1012                 continue;
1013             }
1014             s->float_flag    = bytestream2_get_byte(&gb);
1015             s->float_shift   = bytestream2_get_byte(&gb);
1016             s->float_max_exp = bytestream2_get_byte(&gb);
1017             got_float        = 1;
1018             bytestream2_skip(&gb, 1);
1019             break;
1020         case WP_ID_DATA:
1021             s->sc.offset = bytestream2_tell(&gb);
1022             s->sc.size   = size * 8;
1023             init_get_bits(&s->gb, gb.buffer, size * 8);
1024             s->data_size = size * 8;
1025             bytestream2_skip(&gb, size);
1026             got_bs       = 1;
1027             break;
1028         case WP_ID_EXTRABITS:
1029             if (size <= 4) {
1030                 av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
1031                        size);
1032                 bytestream2_skip(&gb, size);
1033                 continue;
1034             }
1035             s->extra_sc.offset = bytestream2_tell(&gb);
1036             s->extra_sc.size   = size * 8;
1037             init_get_bits(&s->gb_extra_bits, gb.buffer, size * 8);
1038             s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
1039             bytestream2_skip(&gb, size);
1040             s->got_extra_bits  = 1;
1041             break;
1042         case WP_ID_CHANINFO:
1043             if (size <= 1) {
1044                 av_log(avctx, AV_LOG_ERROR,
1045                        "Insufficient channel information\n");
1046                 return AVERROR_INVALIDDATA;
1047             }
1048             chan = bytestream2_get_byte(&gb);
1049             switch (size - 2) {
1050             case 0:
1051                 chmask = bytestream2_get_byte(&gb);
1052                 break;
1053             case 1:
1054                 chmask = bytestream2_get_le16(&gb);
1055                 break;
1056             case 2:
1057                 chmask = bytestream2_get_le24(&gb);
1058                 break;
1059             case 3:
1060                 chmask = bytestream2_get_le32(&gb);;
1061                 break;
1062             case 5:
1063                 bytestream2_skip(&gb, 1);
1064                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
1065                 chmask = bytestream2_get_le16(&gb);
1066                 break;
1067             default:
1068                 av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
1069                        size);
1070                 chan   = avctx->channels;
1071                 chmask = avctx->channel_layout;
1072             }
1073             if (chan != avctx->channels) {
1074                 av_log(avctx, AV_LOG_ERROR,
1075                        "Block reports total %d channels, "
1076                        "decoder believes it's %d channels\n",
1077                        chan, avctx->channels);
1078                 return AVERROR_INVALIDDATA;
1079             }
1080             if (!avctx->channel_layout)
1081                 avctx->channel_layout = chmask;
1082             break;
1083         default:
1084             bytestream2_skip(&gb, size);
1085         }
1086         if (id & WP_IDF_ODD)
1087             bytestream2_skip(&gb, 1);
1088     }
1089
1090     if (!got_terms) {
1091         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
1092         return AVERROR_INVALIDDATA;
1093     }
1094     if (!got_weights) {
1095         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
1096         return AVERROR_INVALIDDATA;
1097     }
1098     if (!got_samples) {
1099         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
1100         return AVERROR_INVALIDDATA;
1101     }
1102     if (!got_entropy) {
1103         av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
1104         return AVERROR_INVALIDDATA;
1105     }
1106     if (s->hybrid && !got_hybrid) {
1107         av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
1108         return AVERROR_INVALIDDATA;
1109     }
1110     if (!got_bs) {
1111         av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
1112         return AVERROR_INVALIDDATA;
1113     }
1114     if (!got_float && avctx->sample_fmt == AV_SAMPLE_FMT_FLT) {
1115         av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
1116         return AVERROR_INVALIDDATA;
1117     }
1118     if (s->got_extra_bits && avctx->sample_fmt != AV_SAMPLE_FMT_FLT) {
1119         const int size   = get_bits_left(&s->gb_extra_bits);
1120         const int wanted = s->samples * s->extra_bits << s->stereo_in;
1121         if (size < wanted) {
1122             av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
1123             s->got_extra_bits = 0;
1124         }
1125     }
1126
1127     if (s->stereo_in) {
1128         if (avctx->sample_fmt == AV_SAMPLE_FMT_S16)
1129             samplecount = wv_unpack_stereo(s, &s->gb, samples, AV_SAMPLE_FMT_S16);
1130         else if (avctx->sample_fmt == AV_SAMPLE_FMT_S32)
1131             samplecount = wv_unpack_stereo(s, &s->gb, samples, AV_SAMPLE_FMT_S32);
1132         else
1133             samplecount = wv_unpack_stereo(s, &s->gb, samples, AV_SAMPLE_FMT_FLT);
1134
1135         if (samplecount < 0)
1136             return samplecount;
1137
1138         samplecount >>= 1;
1139     } else {
1140         const int channel_stride = avctx->channels;
1141
1142         if (avctx->sample_fmt == AV_SAMPLE_FMT_S16)
1143             samplecount = wv_unpack_mono(s, &s->gb, samples, AV_SAMPLE_FMT_S16);
1144         else if (avctx->sample_fmt == AV_SAMPLE_FMT_S32)
1145             samplecount = wv_unpack_mono(s, &s->gb, samples, AV_SAMPLE_FMT_S32);
1146         else
1147             samplecount = wv_unpack_mono(s, &s->gb, samples, AV_SAMPLE_FMT_FLT);
1148
1149         if (samplecount < 0)
1150             return samplecount;
1151
1152         if (s->stereo && avctx->sample_fmt == AV_SAMPLE_FMT_S16) {
1153             int16_t *dst = (int16_t *)samples + 1;
1154             int16_t *src = (int16_t *)samples;
1155             int cnt      = samplecount;
1156             while (cnt--) {
1157                 *dst = *src;
1158                 src += channel_stride;
1159                 dst += channel_stride;
1160             }
1161         } else if (s->stereo && avctx->sample_fmt == AV_SAMPLE_FMT_S32) {
1162             int32_t *dst = (int32_t *)samples + 1;
1163             int32_t *src = (int32_t *)samples;
1164             int cnt      = samplecount;
1165             while (cnt--) {
1166                 *dst = *src;
1167                 src += channel_stride;
1168                 dst += channel_stride;
1169             }
1170         } else if (s->stereo) {
1171             float *dst = (float *)samples + 1;
1172             float *src = (float *)samples;
1173             int cnt    = samplecount;
1174             while (cnt--) {
1175                 *dst = *src;
1176                 src += channel_stride;
1177                 dst += channel_stride;
1178             }
1179         }
1180     }
1181
1182     *got_frame_ptr = 1;
1183
1184     return samplecount * bpp;
1185 }
1186
1187 static void wavpack_decode_flush(AVCodecContext *avctx)
1188 {
1189     WavpackContext *s = avctx->priv_data;
1190     int i;
1191
1192     for (i = 0; i < s->fdec_num; i++)
1193         wv_reset_saved_context(s->fdec[i]);
1194 }
1195
1196 static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
1197                                 int *got_frame_ptr, AVPacket *avpkt)
1198 {
1199     WavpackContext *s  = avctx->priv_data;
1200     const uint8_t *buf = avpkt->data;
1201     int buf_size       = avpkt->size;
1202     AVFrame *frame     = data;
1203     int frame_size, ret, frame_flags;
1204     int samplecount = 0;
1205
1206     if (avpkt->size < 12 + s->multichannel * 4)
1207         return AVERROR_INVALIDDATA;
1208
1209     s->block     = 0;
1210     s->ch_offset = 0;
1211
1212     /* determine number of samples */
1213     if (s->mkv_mode) {
1214         s->samples  = AV_RL32(buf);
1215         buf        += 4;
1216         frame_flags = AV_RL32(buf);
1217     } else {
1218         if (s->multichannel) {
1219             s->samples  = AV_RL32(buf + 4);
1220             frame_flags = AV_RL32(buf + 8);
1221         } else {
1222             s->samples  = AV_RL32(buf);
1223             frame_flags = AV_RL32(buf + 4);
1224         }
1225     }
1226     if (s->samples <= 0) {
1227         av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1228                s->samples);
1229         return AVERROR_INVALIDDATA;
1230     }
1231
1232     if (frame_flags & 0x80) {
1233         avctx->sample_fmt = AV_SAMPLE_FMT_FLT;
1234     } else if ((frame_flags & 0x03) <= 1) {
1235         avctx->sample_fmt = AV_SAMPLE_FMT_S16;
1236     } else {
1237         avctx->sample_fmt          = AV_SAMPLE_FMT_S32;
1238         avctx->bits_per_raw_sample = ((frame_flags & 0x03) + 1) << 3;
1239     }
1240
1241     /* get output buffer */
1242     frame->nb_samples = s->samples;
1243     if ((ret = ff_get_buffer(avctx, frame, 0)) < 0) {
1244         av_log(avctx, AV_LOG_ERROR, "get_buffer() failed\n");
1245         return ret;
1246     }
1247
1248     while (buf_size > 0) {
1249         if (!s->multichannel) {
1250             frame_size = buf_size;
1251         } else {
1252             if (!s->mkv_mode) {
1253                 frame_size = AV_RL32(buf) - 12;
1254                 buf       += 4;
1255                 buf_size  -= 4;
1256             } else {
1257                 if (buf_size < 12) // MKV files can have zero flags after last block
1258                     break;
1259                 frame_size = AV_RL32(buf + 8) + 12;
1260             }
1261         }
1262         if (frame_size < 0 || frame_size > buf_size) {
1263             av_log(avctx, AV_LOG_ERROR,
1264                    "Block %d has invalid size (size %d vs. %d bytes left)\n",
1265                    s->block, frame_size, buf_size);
1266             wavpack_decode_flush(avctx);
1267             return AVERROR_INVALIDDATA;
1268         }
1269         if ((samplecount = wavpack_decode_block(avctx, s->block,
1270                                                 frame->data[0], got_frame_ptr,
1271                                                 buf, frame_size)) < 0) {
1272             wavpack_decode_flush(avctx);
1273             return samplecount;
1274         }
1275         s->block++;
1276         buf      += frame_size;
1277         buf_size -= frame_size;
1278     }
1279
1280     return avpkt->size;
1281 }
1282
1283 AVCodec ff_wavpack_decoder = {
1284     .name           = "wavpack",
1285     .type           = AVMEDIA_TYPE_AUDIO,
1286     .id             = AV_CODEC_ID_WAVPACK,
1287     .priv_data_size = sizeof(WavpackContext),
1288     .init           = wavpack_decode_init,
1289     .close          = wavpack_decode_end,
1290     .decode         = wavpack_decode_frame,
1291     .flush          = wavpack_decode_flush,
1292     .capabilities   = CODEC_CAP_SUBFRAMES | CODEC_CAP_DR1,
1293     .long_name      = NULL_IF_CONFIG_SMALL("WavPack"),
1294 };