]> git.sesse.net Git - ffmpeg/blob - libavcodec/wavpack.c
c618537309816c310f268404ab04cf6676b019b2
[ffmpeg] / libavcodec / wavpack.c
1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  *
5  * This file is part of FFmpeg.
6  *
7  * FFmpeg is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * FFmpeg is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with FFmpeg; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  */
21
22 #define BITSTREAM_READER_LE
23
24 #include "libavutil/channel_layout.h"
25 #include "avcodec.h"
26 #include "get_bits.h"
27 #include "internal.h"
28 #include "unary.h"
29
30 /**
31  * @file
32  * WavPack lossless audio decoder
33  */
34
35 #define WV_MONO           0x00000004
36 #define WV_JOINT_STEREO   0x00000010
37 #define WV_FALSE_STEREO   0x40000000
38
39 #define WV_HYBRID_MODE    0x00000008
40 #define WV_HYBRID_SHAPE   0x00000008
41 #define WV_HYBRID_BITRATE 0x00000200
42 #define WV_HYBRID_BALANCE 0x00000400
43
44 #define WV_FLT_SHIFT_ONES 0x01
45 #define WV_FLT_SHIFT_SAME 0x02
46 #define WV_FLT_SHIFT_SENT 0x04
47 #define WV_FLT_ZERO_SENT  0x08
48 #define WV_FLT_ZERO_SIGN  0x10
49
50 #define WV_MAX_SAMPLES    131072
51
52 enum WP_ID_Flags {
53     WP_IDF_MASK   = 0x1F,
54     WP_IDF_IGNORE = 0x20,
55     WP_IDF_ODD    = 0x40,
56     WP_IDF_LONG   = 0x80
57 };
58
59 enum WP_ID {
60     WP_ID_DUMMY = 0,
61     WP_ID_ENCINFO,
62     WP_ID_DECTERMS,
63     WP_ID_DECWEIGHTS,
64     WP_ID_DECSAMPLES,
65     WP_ID_ENTROPY,
66     WP_ID_HYBRID,
67     WP_ID_SHAPING,
68     WP_ID_FLOATINFO,
69     WP_ID_INT32INFO,
70     WP_ID_DATA,
71     WP_ID_CORR,
72     WP_ID_EXTRABITS,
73     WP_ID_CHANINFO
74 };
75
76 typedef struct SavedContext {
77     int offset;
78     int size;
79     int bits_used;
80     uint32_t crc;
81 } SavedContext;
82
83 #define MAX_TERMS 16
84
85 typedef struct Decorr {
86     int delta;
87     int value;
88     int weightA;
89     int weightB;
90     int samplesA[8];
91     int samplesB[8];
92 } Decorr;
93
94 typedef struct WvChannel {
95     int median[3];
96     int slow_level, error_limit;
97     int bitrate_acc, bitrate_delta;
98 } WvChannel;
99
100 typedef struct WavpackFrameContext {
101     AVCodecContext *avctx;
102     int frame_flags;
103     int stereo, stereo_in;
104     int joint;
105     uint32_t CRC;
106     GetBitContext gb;
107     int got_extra_bits;
108     uint32_t crc_extra_bits;
109     GetBitContext gb_extra_bits;
110     int data_size; // in bits
111     int samples;
112     int terms;
113     Decorr decorr[MAX_TERMS];
114     int zero, one, zeroes;
115     int extra_bits;
116     int and, or, shift;
117     int post_shift;
118     int hybrid, hybrid_bitrate;
119     int hybrid_maxclip, hybrid_minclip;
120     int float_flag;
121     int float_shift;
122     int float_max_exp;
123     WvChannel ch[2];
124     int pos;
125     SavedContext sc, extra_sc;
126 } WavpackFrameContext;
127
128 #define WV_MAX_FRAME_DECODERS 14
129
130 typedef struct WavpackContext {
131     AVCodecContext *avctx;
132
133     WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
134     int fdec_num;
135
136     int multichannel;
137     int mkv_mode;
138     int block;
139     int samples;
140     int ch_offset;
141 } WavpackContext;
142
143 // exponent table copied from WavPack source
144 static const uint8_t wp_exp2_table[256] = {
145     0x00, 0x01, 0x01, 0x02, 0x03, 0x03, 0x04, 0x05, 0x06, 0x06, 0x07, 0x08, 0x08, 0x09, 0x0a, 0x0b,
146     0x0b, 0x0c, 0x0d, 0x0e, 0x0e, 0x0f, 0x10, 0x10, 0x11, 0x12, 0x13, 0x13, 0x14, 0x15, 0x16, 0x16,
147     0x17, 0x18, 0x19, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1d, 0x1e, 0x1f, 0x20, 0x20, 0x21, 0x22, 0x23,
148     0x24, 0x24, 0x25, 0x26, 0x27, 0x28, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
149     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3a, 0x3b, 0x3c, 0x3d,
150     0x3e, 0x3f, 0x40, 0x41, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x48, 0x49, 0x4a, 0x4b,
151     0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a,
152     0x5b, 0x5c, 0x5d, 0x5e, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69,
153     0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
154     0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x87, 0x88, 0x89, 0x8a,
155     0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b,
156     0x9c, 0x9d, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad,
157     0xaf, 0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbc, 0xbd, 0xbe, 0xbf, 0xc0,
158     0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc8, 0xc9, 0xca, 0xcb, 0xcd, 0xce, 0xcf, 0xd0, 0xd2, 0xd3, 0xd4,
159     0xd6, 0xd7, 0xd8, 0xd9, 0xdb, 0xdc, 0xdd, 0xde, 0xe0, 0xe1, 0xe2, 0xe4, 0xe5, 0xe6, 0xe8, 0xe9,
160     0xea, 0xec, 0xed, 0xee, 0xf0, 0xf1, 0xf2, 0xf4, 0xf5, 0xf6, 0xf8, 0xf9, 0xfa, 0xfc, 0xfd, 0xff
161 };
162
163 static const uint8_t wp_log2_table [] = {
164     0x00, 0x01, 0x03, 0x04, 0x06, 0x07, 0x09, 0x0a, 0x0b, 0x0d, 0x0e, 0x10, 0x11, 0x12, 0x14, 0x15,
165     0x16, 0x18, 0x19, 0x1a, 0x1c, 0x1d, 0x1e, 0x20, 0x21, 0x22, 0x24, 0x25, 0x26, 0x28, 0x29, 0x2a,
166     0x2c, 0x2d, 0x2e, 0x2f, 0x31, 0x32, 0x33, 0x34, 0x36, 0x37, 0x38, 0x39, 0x3b, 0x3c, 0x3d, 0x3e,
167     0x3f, 0x41, 0x42, 0x43, 0x44, 0x45, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4d, 0x4e, 0x4f, 0x50, 0x51,
168     0x52, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63,
169     0x64, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x74, 0x75,
170     0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85,
171     0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95,
172     0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4,
173     0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, 0xb0, 0xb1, 0xb2, 0xb2,
174     0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, 0xc0, 0xc0,
175     0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcb, 0xcc, 0xcd, 0xce,
176     0xcf, 0xd0, 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd8, 0xd9, 0xda, 0xdb,
177     0xdc, 0xdc, 0xdd, 0xde, 0xdf, 0xe0, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe4, 0xe5, 0xe6, 0xe7, 0xe7,
178     0xe8, 0xe9, 0xea, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xee, 0xef, 0xf0, 0xf1, 0xf1, 0xf2, 0xf3, 0xf4,
179     0xf4, 0xf5, 0xf6, 0xf7, 0xf7, 0xf8, 0xf9, 0xf9, 0xfa, 0xfb, 0xfc, 0xfc, 0xfd, 0xfe, 0xff, 0xff
180 };
181
182 static av_always_inline int wp_exp2(int16_t val)
183 {
184     int res, neg = 0;
185
186     if (val < 0) {
187         val = -val;
188         neg = 1;
189     }
190
191     res   = wp_exp2_table[val & 0xFF] | 0x100;
192     val >>= 8;
193     res   = (val > 9) ? (res << (val - 9)) : (res >> (9 - val));
194     return neg ? -res : res;
195 }
196
197 static av_always_inline int wp_log2(int32_t val)
198 {
199     int bits;
200
201     if (!val)
202         return 0;
203     if (val == 1)
204         return 256;
205     val += val >> 9;
206     bits = av_log2(val) + 1;
207     if (bits < 9)
208         return (bits << 8) + wp_log2_table[(val << (9 - bits)) & 0xFF];
209     else
210         return (bits << 8) + wp_log2_table[(val >> (bits - 9)) & 0xFF];
211 }
212
213 #define LEVEL_DECAY(a)  ((a + 0x80) >> 8)
214
215 // macros for manipulating median values
216 #define GET_MED(n) ((c->median[n] >> 4) + 1)
217 #define DEC_MED(n) c->median[n] -= ((c->median[n] + (128 >> n) - 2) / (128 >> n)) * 2
218 #define INC_MED(n) c->median[n] += ((c->median[n] + (128 >> n)    ) / (128 >> n)) * 5
219
220 // macros for applying weight
221 #define UPDATE_WEIGHT_CLIP(weight, delta, samples, in) \
222     if (samples && in) { \
223         if ((samples ^ in) < 0) { \
224             weight -= delta; \
225             if (weight < -1024) \
226                 weight = -1024; \
227         } else { \
228             weight += delta; \
229             if (weight > 1024) \
230                 weight = 1024; \
231         } \
232     }
233
234 static av_always_inline int get_tail(GetBitContext *gb, int k)
235 {
236     int p, e, res;
237
238     if (k < 1)
239         return 0;
240     p   = av_log2(k);
241     e   = (1 << (p + 1)) - k - 1;
242     res = p ? get_bits(gb, p) : 0;
243     if (res >= e)
244         res = (res << 1) - e + get_bits1(gb);
245     return res;
246 }
247
248 static void update_error_limit(WavpackFrameContext *ctx)
249 {
250     int i, br[2], sl[2];
251
252     for (i = 0; i <= ctx->stereo_in; i++) {
253         ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
254         br[i]                   = ctx->ch[i].bitrate_acc >> 16;
255         sl[i]                   = LEVEL_DECAY(ctx->ch[i].slow_level);
256     }
257     if (ctx->stereo_in && ctx->hybrid_bitrate) {
258         int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
259         if (balance > br[0]) {
260             br[1] = br[0] << 1;
261             br[0] = 0;
262         } else if (-balance > br[0]) {
263             br[0] <<= 1;
264             br[1]   = 0;
265         } else {
266             br[1] = br[0] + balance;
267             br[0] = br[0] - balance;
268         }
269     }
270     for (i = 0; i <= ctx->stereo_in; i++) {
271         if (ctx->hybrid_bitrate) {
272             if (sl[i] - br[i] > -0x100)
273                 ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
274             else
275                 ctx->ch[i].error_limit = 0;
276         } else {
277             ctx->ch[i].error_limit = wp_exp2(br[i]);
278         }
279     }
280 }
281
282 static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
283                         int channel, int *last)
284 {
285     int t, t2;
286     int sign, base, add, ret;
287     WvChannel *c = &ctx->ch[channel];
288
289     *last = 0;
290
291     if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
292         !ctx->zero && !ctx->one) {
293         if (ctx->zeroes) {
294             ctx->zeroes--;
295             if (ctx->zeroes) {
296                 c->slow_level -= LEVEL_DECAY(c->slow_level);
297                 return 0;
298             }
299         } else {
300             t = get_unary_0_33(gb);
301             if (t >= 2) {
302                 if (get_bits_left(gb) < t - 1)
303                     goto error;
304                 t = get_bits(gb, t - 1) | (1 << (t - 1));
305             } else {
306                 if (get_bits_left(gb) < 0)
307                     goto error;
308             }
309             ctx->zeroes = t;
310             if (ctx->zeroes) {
311                 memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
312                 memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
313                 c->slow_level -= LEVEL_DECAY(c->slow_level);
314                 return 0;
315             }
316         }
317     }
318
319     if (ctx->zero) {
320         t         = 0;
321         ctx->zero = 0;
322     } else {
323         t = get_unary_0_33(gb);
324         if (get_bits_left(gb) < 0)
325             goto error;
326         if (t == 16) {
327             t2 = get_unary_0_33(gb);
328             if (t2 < 2) {
329                 if (get_bits_left(gb) < 0)
330                     goto error;
331                 t += t2;
332             } else {
333                 if (get_bits_left(gb) < t2 - 1)
334                     goto error;
335                 t += get_bits(gb, t2 - 1) | (1 << (t2 - 1));
336             }
337         }
338
339         if (ctx->one) {
340             ctx->one = t & 1;
341             t        = (t >> 1) + 1;
342         } else {
343             ctx->one = t & 1;
344             t      >>= 1;
345         }
346         ctx->zero = !ctx->one;
347     }
348
349     if (ctx->hybrid && !channel)
350         update_error_limit(ctx);
351
352     if (!t) {
353         base = 0;
354         add  = GET_MED(0) - 1;
355         DEC_MED(0);
356     } else if (t == 1) {
357         base = GET_MED(0);
358         add  = GET_MED(1) - 1;
359         INC_MED(0);
360         DEC_MED(1);
361     } else if (t == 2) {
362         base = GET_MED(0) + GET_MED(1);
363         add  = GET_MED(2) - 1;
364         INC_MED(0);
365         INC_MED(1);
366         DEC_MED(2);
367     } else {
368         base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2);
369         add  = GET_MED(2) - 1;
370         INC_MED(0);
371         INC_MED(1);
372         INC_MED(2);
373     }
374     if (!c->error_limit) {
375         if (add >= 0x2000000U) {
376             av_log(ctx->avctx, AV_LOG_ERROR, "k %d is too large\n", add);
377             goto error;
378         }
379         ret = base + get_tail(gb, add);
380         if (get_bits_left(gb) <= 0)
381             goto error;
382     } else {
383         int mid = (base * 2 + add + 1) >> 1;
384         while (add > c->error_limit) {
385             if (get_bits_left(gb) <= 0)
386                 goto error;
387             if (get_bits1(gb)) {
388                 add -= (mid - base);
389                 base = mid;
390             } else
391                 add = mid - base - 1;
392             mid = (base * 2 + add + 1) >> 1;
393         }
394         ret = mid;
395     }
396     sign = get_bits1(gb);
397     if (ctx->hybrid_bitrate)
398         c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
399     return sign ? ~ret : ret;
400
401 error:
402     *last = 1;
403     return 0;
404 }
405
406 static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
407                                        int S)
408 {
409     int bit;
410
411     if (s->extra_bits) {
412         S <<= s->extra_bits;
413
414         if (s->got_extra_bits &&
415             get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
416             S   |= get_bits(&s->gb_extra_bits, s->extra_bits);
417             *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
418         }
419     }
420
421     bit = (S & s->and) | s->or;
422     bit = ((S + bit) << s->shift) - bit;
423
424     if (s->hybrid)
425         bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
426
427     return bit << s->post_shift;
428 }
429
430 static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
431 {
432     union {
433         float    f;
434         uint32_t u;
435     } value;
436
437     unsigned int sign;
438     int exp = s->float_max_exp;
439
440     if (s->got_extra_bits) {
441         const int max_bits  = 1 + 23 + 8 + 1;
442         const int left_bits = get_bits_left(&s->gb_extra_bits);
443
444         if (left_bits + 8 * FF_INPUT_BUFFER_PADDING_SIZE < max_bits)
445             return 0.0;
446     }
447
448     if (S) {
449         S  <<= s->float_shift;
450         sign = S < 0;
451         if (sign)
452             S = -S;
453         if (S >= 0x1000000) {
454             if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
455                 S = get_bits(&s->gb_extra_bits, 23);
456             else
457                 S = 0;
458             exp = 255;
459         } else if (exp) {
460             int shift = 23 - av_log2(S);
461             exp = s->float_max_exp;
462             if (exp <= shift)
463                 shift = --exp;
464             exp -= shift;
465
466             if (shift) {
467                 S <<= shift;
468                 if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
469                     (s->got_extra_bits &&
470                      (s->float_flag & WV_FLT_SHIFT_SAME) &&
471                      get_bits1(&s->gb_extra_bits))) {
472                     S |= (1 << shift) - 1;
473                 } else if (s->got_extra_bits &&
474                            (s->float_flag & WV_FLT_SHIFT_SENT)) {
475                     S |= get_bits(&s->gb_extra_bits, shift);
476                 }
477             }
478         } else {
479             exp = s->float_max_exp;
480         }
481         S &= 0x7fffff;
482     } else {
483         sign = 0;
484         exp  = 0;
485         if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
486             if (get_bits1(&s->gb_extra_bits)) {
487                 S = get_bits(&s->gb_extra_bits, 23);
488                 if (s->float_max_exp >= 25)
489                     exp = get_bits(&s->gb_extra_bits, 8);
490                 sign = get_bits1(&s->gb_extra_bits);
491             } else {
492                 if (s->float_flag & WV_FLT_ZERO_SIGN)
493                     sign = get_bits1(&s->gb_extra_bits);
494             }
495         }
496     }
497
498     *crc = *crc * 27 + S * 9 + exp * 3 + sign;
499
500     value.u = (sign << 31) | (exp << 23) | S;
501     return value.f;
502 }
503
504 static void wv_reset_saved_context(WavpackFrameContext *s)
505 {
506     s->pos    = 0;
507     s->sc.crc = s->extra_sc.crc = 0xFFFFFFFF;
508 }
509
510 static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
511                                uint32_t crc_extra_bits)
512 {
513     if (crc != s->CRC) {
514         av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
515         return AVERROR_INVALIDDATA;
516     }
517     if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
518         av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
519         return AVERROR_INVALIDDATA;
520     }
521
522     return 0;
523 }
524
525 static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
526                                    void *dst, const int type)
527 {
528     int i, j, count = 0;
529     int last, t;
530     int A, B, L, L2, R, R2;
531     int pos                 = s->pos;
532     uint32_t crc            = s->sc.crc;
533     uint32_t crc_extra_bits = s->extra_sc.crc;
534     int16_t *dst16          = dst;
535     int32_t *dst32          = dst;
536     float *dstfl            = dst;
537     const int channel_pad   = s->avctx->channels - 2;
538
539     s->one = s->zero = s->zeroes = 0;
540     do {
541         L = wv_get_value(s, gb, 0, &last);
542         if (last)
543             break;
544         R = wv_get_value(s, gb, 1, &last);
545         if (last)
546             break;
547         for (i = 0; i < s->terms; i++) {
548             t = s->decorr[i].value;
549             if (t > 0) {
550                 if (t > 8) {
551                     if (t & 1) {
552                         A = 2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
553                         B = 2 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
554                     } else {
555                         A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
556                         B = (3 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
557                     }
558                     s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
559                     s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
560                     j                        = 0;
561                 } else {
562                     A = s->decorr[i].samplesA[pos];
563                     B = s->decorr[i].samplesB[pos];
564                     j = (pos + t) & 7;
565                 }
566                 if (type != AV_SAMPLE_FMT_S16) {
567                     L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
568                     R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
569                 } else {
570                     L2 = L + ((s->decorr[i].weightA * A + 512) >> 10);
571                     R2 = R + ((s->decorr[i].weightB * B + 512) >> 10);
572                 }
573                 if (A && L)
574                     s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
575                 if (B && R)
576                     s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
577                 s->decorr[i].samplesA[j] = L = L2;
578                 s->decorr[i].samplesB[j] = R = R2;
579             } else if (t == -1) {
580                 if (type != AV_SAMPLE_FMT_S16)
581                     L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
582                 else
583                     L2 = L + ((s->decorr[i].weightA * s->decorr[i].samplesA[0] + 512) >> 10);
584                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
585                 L = L2;
586                 if (type != AV_SAMPLE_FMT_S16)
587                     R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
588                 else
589                     R2 = R + ((s->decorr[i].weightB * L2 + 512) >> 10);
590                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
591                 R                        = R2;
592                 s->decorr[i].samplesA[0] = R;
593             } else {
594                 if (type != AV_SAMPLE_FMT_S16)
595                     R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
596                 else
597                     R2 = R + ((s->decorr[i].weightB * s->decorr[i].samplesB[0] + 512) >> 10);
598                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
599                 R = R2;
600
601                 if (t == -3) {
602                     R2                       = s->decorr[i].samplesA[0];
603                     s->decorr[i].samplesA[0] = R;
604                 }
605
606                 if (type != AV_SAMPLE_FMT_S16)
607                     L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
608                 else
609                     L2 = L + ((s->decorr[i].weightA * R2 + 512) >> 10);
610                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
611                 L                        = L2;
612                 s->decorr[i].samplesB[0] = L;
613             }
614         }
615         pos = (pos + 1) & 7;
616         if (s->joint)
617             L += (R -= (L >> 1));
618         crc = (crc * 3 + L) * 3 + R;
619
620         if (type == AV_SAMPLE_FMT_FLT) {
621             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, L);
622             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, R);
623             dstfl   += channel_pad;
624         } else if (type == AV_SAMPLE_FMT_S32) {
625             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, L);
626             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, R);
627             dst32   += channel_pad;
628         } else {
629             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, L);
630             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, R);
631             dst16   += channel_pad;
632         }
633         count++;
634     } while (!last && count < s->samples);
635
636     wv_reset_saved_context(s);
637     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
638         wv_check_crc(s, crc, crc_extra_bits))
639         return AVERROR_INVALIDDATA;
640
641     return count * 2;
642 }
643
644 static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
645                                  void *dst, const int type)
646 {
647     int i, j, count = 0;
648     int last, t;
649     int A, S, T;
650     int pos                  = s->pos;
651     uint32_t crc             = s->sc.crc;
652     uint32_t crc_extra_bits  = s->extra_sc.crc;
653     int16_t *dst16           = dst;
654     int32_t *dst32           = dst;
655     float *dstfl             = dst;
656     const int channel_stride = s->avctx->channels;
657
658     s->one = s->zero = s->zeroes = 0;
659     do {
660         T = wv_get_value(s, gb, 0, &last);
661         S = 0;
662         if (last)
663             break;
664         for (i = 0; i < s->terms; i++) {
665             t = s->decorr[i].value;
666             if (t > 8) {
667                 if (t & 1)
668                     A =  2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
669                 else
670                     A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
671                 s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
672                 j                        = 0;
673             } else {
674                 A = s->decorr[i].samplesA[pos];
675                 j = (pos + t) & 7;
676             }
677             if (type != AV_SAMPLE_FMT_S16)
678                 S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
679             else
680                 S = T + ((s->decorr[i].weightA * A + 512) >> 10);
681             if (A && T)
682                 s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
683             s->decorr[i].samplesA[j] = T = S;
684         }
685         pos = (pos + 1) & 7;
686         crc = crc * 3 + S;
687
688         if (type == AV_SAMPLE_FMT_FLT) {
689             *dstfl = wv_get_value_float(s, &crc_extra_bits, S);
690             dstfl += channel_stride;
691         } else if (type == AV_SAMPLE_FMT_S32) {
692             *dst32 = wv_get_value_integer(s, &crc_extra_bits, S);
693             dst32 += channel_stride;
694         } else {
695             *dst16 = wv_get_value_integer(s, &crc_extra_bits, S);
696             dst16 += channel_stride;
697         }
698         count++;
699     } while (!last && count < s->samples);
700
701     wv_reset_saved_context(s);
702     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
703         wv_check_crc(s, crc, crc_extra_bits))
704         return AVERROR_INVALIDDATA;
705
706     return count;
707 }
708
709 static av_cold int wv_alloc_frame_context(WavpackContext *c)
710 {
711     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
712         return -1;
713
714     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
715     if (!c->fdec[c->fdec_num])
716         return -1;
717     c->fdec_num++;
718     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
719     wv_reset_saved_context(c->fdec[c->fdec_num - 1]);
720
721     return 0;
722 }
723
724 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
725 {
726     WavpackContext *s = avctx->priv_data;
727
728     s->avctx = avctx;
729     if (avctx->bits_per_coded_sample <= 16)
730         avctx->sample_fmt = AV_SAMPLE_FMT_S16;
731     else
732         avctx->sample_fmt = AV_SAMPLE_FMT_S32;
733     if (avctx->channels <= 2 && !avctx->channel_layout)
734         avctx->channel_layout = (avctx->channels == 2) ? AV_CH_LAYOUT_STEREO
735                                                        : AV_CH_LAYOUT_MONO;
736
737     s->multichannel = avctx->channels > 2;
738     /* lavf demuxer does not provide extradata, Matroska stores 0x403
739      * there, use this to detect decoding mode for multichannel */
740     s->mkv_mode = 0;
741     if (s->multichannel && avctx->extradata && avctx->extradata_size == 2) {
742         int ver = AV_RL16(avctx->extradata);
743         if (ver >= 0x402 && ver <= 0x410)
744             s->mkv_mode = 1;
745     }
746
747     s->fdec_num = 0;
748
749     return 0;
750 }
751
752 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
753 {
754     WavpackContext *s = avctx->priv_data;
755     int i;
756
757     for (i = 0; i < s->fdec_num; i++)
758         av_freep(&s->fdec[i]);
759     s->fdec_num = 0;
760
761     return 0;
762 }
763
764 static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
765                                 void *data, int *got_frame_ptr,
766                                 const uint8_t *buf, int buf_size)
767 {
768     WavpackContext *wc = avctx->priv_data;
769     WavpackFrameContext *s;
770     void *samples = data;
771     int samplecount;
772     int got_terms   = 0, got_weights = 0, got_samples = 0,
773         got_entropy = 0, got_bs      = 0, got_float   = 0, got_hybrid = 0;
774     const uint8_t *orig_buf = buf;
775     const uint8_t *buf_end  = buf + buf_size;
776     int i, j, id, size, ssize, weights, t;
777     int bpp, chan, chmask, orig_bpp;
778
779     if (buf_size == 0) {
780         *got_frame_ptr = 0;
781         return 0;
782     }
783
784     if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
785         av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
786         return AVERROR_INVALIDDATA;
787     }
788
789     s = wc->fdec[block_no];
790     if (!s) {
791         av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
792                block_no);
793         return AVERROR_INVALIDDATA;
794     }
795
796     if (wc->ch_offset >= avctx->channels) {
797         av_log(avctx, AV_LOG_ERROR, "too many channels\n");
798         return -1;
799     }
800
801     memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
802     memset(s->ch, 0, sizeof(s->ch));
803     s->extra_bits     = 0;
804     s->and            = s->or = s->shift = 0;
805     s->got_extra_bits = 0;
806
807     if (!wc->mkv_mode) {
808         s->samples = AV_RL32(buf);
809         buf       += 4;
810         if (!s->samples) {
811             *got_frame_ptr = 0;
812             return 0;
813         }
814         if (s->samples > wc->samples) {
815             av_log(avctx, AV_LOG_ERROR, "too many samples in block");
816             return -1;
817         }
818     } else {
819         s->samples = wc->samples;
820     }
821     s->frame_flags = AV_RL32(buf);
822     buf           += 4;
823     bpp            = av_get_bytes_per_sample(avctx->sample_fmt);
824     samples        = (uint8_t *)samples + bpp * wc->ch_offset;
825     orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
826
827     s->stereo         = !(s->frame_flags & WV_MONO);
828     s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
829     s->joint          =   s->frame_flags & WV_JOINT_STEREO;
830     s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
831     s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
832     s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
833     s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
834     s->hybrid_minclip = ((-1LL << (orig_bpp - 1)));
835     s->CRC            = AV_RL32(buf);
836     buf              += 4;
837     if (wc->mkv_mode)
838         buf += 4;  // skip block size;
839
840     wc->ch_offset += 1 + s->stereo;
841
842     // parse metadata blocks
843     while (buf < buf_end) {
844         id   = *buf++;
845         size = *buf++;
846         if (id & WP_IDF_LONG) {
847             size |= (*buf++) << 8;
848             size |= (*buf++) << 16;
849         }
850         size <<= 1; // size is specified in words
851         ssize  = size;
852         if (id & WP_IDF_ODD)
853             size--;
854         if (size < 0) {
855             av_log(avctx, AV_LOG_ERROR,
856                    "Got incorrect block %02X with size %i\n", id, size);
857             break;
858         }
859         if (buf + ssize > buf_end) {
860             av_log(avctx, AV_LOG_ERROR,
861                    "Block size %i is out of bounds\n", size);
862             break;
863         }
864         if (id & WP_IDF_IGNORE) {
865             buf += ssize;
866             continue;
867         }
868         switch (id & WP_IDF_MASK) {
869         case WP_ID_DECTERMS:
870             if (size > MAX_TERMS) {
871                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
872                 s->terms = 0;
873                 buf     += ssize;
874                 continue;
875             }
876             s->terms = size;
877             for (i = 0; i < s->terms; i++) {
878                 s->decorr[s->terms - i - 1].value = (*buf & 0x1F) - 5;
879                 s->decorr[s->terms - i - 1].delta = *buf >> 5;
880                 buf++;
881             }
882             got_terms = 1;
883             break;
884         case WP_ID_DECWEIGHTS:
885             if (!got_terms) {
886                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
887                 continue;
888             }
889             weights = size >> s->stereo_in;
890             if (weights > MAX_TERMS || weights > s->terms) {
891                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
892                 buf += ssize;
893                 continue;
894             }
895             for (i = 0; i < weights; i++) {
896                 t                                   = (int8_t)(*buf++);
897                 s->decorr[s->terms - i - 1].weightA = t << 3;
898                 if (s->decorr[s->terms - i - 1].weightA > 0)
899                     s->decorr[s->terms - i - 1].weightA +=
900                         (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
901                 if (s->stereo_in) {
902                     t                                   = (int8_t)(*buf++);
903                     s->decorr[s->terms - i - 1].weightB = t << 3;
904                     if (s->decorr[s->terms - i - 1].weightB > 0)
905                         s->decorr[s->terms - i - 1].weightB +=
906                             (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
907                 }
908             }
909             got_weights = 1;
910             break;
911         case WP_ID_DECSAMPLES:
912             if (!got_terms) {
913                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
914                 continue;
915             }
916             t = 0;
917             for (i = s->terms - 1; (i >= 0) && (t < size) && buf <= buf_end; i--) {
918                 if (s->decorr[i].value > 8) {
919                     s->decorr[i].samplesA[0] = wp_exp2(AV_RL16(buf));
920                     buf                     += 2;
921                     s->decorr[i].samplesA[1] = wp_exp2(AV_RL16(buf));
922                     buf                     += 2;
923                     if (s->stereo_in) {
924                         s->decorr[i].samplesB[0] = wp_exp2(AV_RL16(buf));
925                         buf                     += 2;
926                         s->decorr[i].samplesB[1] = wp_exp2(AV_RL16(buf));
927                         buf                     += 2;
928                         t                       += 4;
929                     }
930                     t += 4;
931                 } else if (s->decorr[i].value < 0) {
932                     s->decorr[i].samplesA[0] = wp_exp2(AV_RL16(buf));
933                     buf                     += 2;
934                     s->decorr[i].samplesB[0] = wp_exp2(AV_RL16(buf));
935                     buf                     += 2;
936                     t                       += 4;
937                 } else {
938                     for (j = 0; j < s->decorr[i].value && buf+1<buf_end; j++) {
939                         s->decorr[i].samplesA[j] = wp_exp2(AV_RL16(buf));
940                         buf                     += 2;
941                         if (s->stereo_in) {
942                             s->decorr[i].samplesB[j] = wp_exp2(AV_RL16(buf));
943                             buf                     += 2;
944                         }
945                     }
946                     t += s->decorr[i].value * 2 * (s->stereo_in + 1);
947                 }
948             }
949             got_samples = 1;
950             break;
951         case WP_ID_ENTROPY:
952             if (size != 6 * (s->stereo_in + 1)) {
953                 av_log(avctx, AV_LOG_ERROR,
954                        "Entropy vars size should be %i, got %i",
955                        6 * (s->stereo_in + 1), size);
956                 buf += ssize;
957                 continue;
958             }
959             for (j = 0; j <= s->stereo_in; j++)
960                 for (i = 0; i < 3; i++) {
961                     s->ch[j].median[i] = wp_exp2(AV_RL16(buf));
962                     buf               += 2;
963                 }
964             got_entropy = 1;
965             break;
966         case WP_ID_HYBRID:
967             if (s->hybrid_bitrate) {
968                 for (i = 0; i <= s->stereo_in; i++) {
969                     s->ch[i].slow_level = wp_exp2(AV_RL16(buf));
970                     buf                += 2;
971                     size               -= 2;
972                 }
973             }
974             for (i = 0; i < (s->stereo_in + 1); i++) {
975                 s->ch[i].bitrate_acc = AV_RL16(buf) << 16;
976                 buf                 += 2;
977                 size                -= 2;
978             }
979             if (size > 0) {
980                 for (i = 0; i < (s->stereo_in + 1); i++) {
981                     s->ch[i].bitrate_delta = wp_exp2((int16_t)AV_RL16(buf));
982                     buf                   += 2;
983                 }
984             } else {
985                 for (i = 0; i < (s->stereo_in + 1); i++)
986                     s->ch[i].bitrate_delta = 0;
987             }
988             got_hybrid = 1;
989             break;
990         case WP_ID_INT32INFO:
991             if (size != 4) {
992                 av_log(avctx, AV_LOG_ERROR,
993                        "Invalid INT32INFO, size = %i, sent_bits = %i\n",
994                        size, *buf);
995                 buf += ssize;
996                 continue;
997             }
998             if (buf[0])
999                 s->extra_bits = buf[0];
1000             else if (buf[1])
1001                 s->shift = buf[1];
1002             else if (buf[2]) {
1003                 s->and   = s->or = 1;
1004                 s->shift = buf[2];
1005             } else if (buf[3]) {
1006                 s->and   = 1;
1007                 s->shift = buf[3];
1008             }
1009             /* original WavPack decoder forces 32-bit lossy sound to be treated
1010              * as 24-bit one in order to have proper clipping */
1011             if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
1012                 s->post_shift      += 8;
1013                 s->shift           -= 8;
1014                 s->hybrid_maxclip >>= 8;
1015                 s->hybrid_minclip >>= 8;
1016             }
1017             buf += 4;
1018             break;
1019         case WP_ID_FLOATINFO:
1020             if (size != 4) {
1021                 av_log(avctx, AV_LOG_ERROR,
1022                        "Invalid FLOATINFO, size = %i\n", size);
1023                 buf += ssize;
1024                 continue;
1025             }
1026             s->float_flag    = buf[0];
1027             s->float_shift   = buf[1];
1028             s->float_max_exp = buf[2];
1029             buf             += 4;
1030             got_float        = 1;
1031             break;
1032         case WP_ID_DATA:
1033             s->sc.offset = buf - orig_buf;
1034             s->sc.size   = size * 8;
1035             init_get_bits(&s->gb, buf, size * 8);
1036             s->data_size = size * 8;
1037             buf         += size;
1038             got_bs       = 1;
1039             break;
1040         case WP_ID_EXTRABITS:
1041             if (size <= 4) {
1042                 av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
1043                        size);
1044                 buf += size;
1045                 continue;
1046             }
1047             s->extra_sc.offset = buf - orig_buf;
1048             s->extra_sc.size   = size * 8;
1049             init_get_bits(&s->gb_extra_bits, buf, size * 8);
1050             s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
1051             buf               += size;
1052             s->got_extra_bits  = 1;
1053             break;
1054         case WP_ID_CHANINFO:
1055             if (size <= 1) {
1056                 av_log(avctx, AV_LOG_ERROR,
1057                        "Insufficient channel information\n");
1058                 return AVERROR_INVALIDDATA;
1059             }
1060             chan = *buf++;
1061             switch (size - 2) {
1062             case 0:
1063                 chmask = *buf;
1064                 break;
1065             case 1:
1066                 chmask = AV_RL16(buf);
1067                 break;
1068             case 2:
1069                 chmask = AV_RL24(buf);
1070                 break;
1071             case 3:
1072                 chmask = AV_RL32(buf);
1073                 break;
1074             case 5:
1075                 chan  |= (buf[1] & 0xF) << 8;
1076                 chmask = AV_RL24(buf + 2);
1077                 break;
1078             default:
1079                 av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
1080                        size);
1081                 chan   = avctx->channels;
1082                 chmask = avctx->channel_layout;
1083             }
1084             if (chan != avctx->channels) {
1085                 av_log(avctx, AV_LOG_ERROR,
1086                        "Block reports total %d channels, "
1087                        "decoder believes it's %d channels\n",
1088                        chan, avctx->channels);
1089                 return AVERROR_INVALIDDATA;
1090             }
1091             if (!avctx->channel_layout)
1092                 avctx->channel_layout = chmask;
1093             buf += size - 1;
1094             break;
1095         default:
1096             buf += size;
1097         }
1098         if (id & WP_IDF_ODD)
1099             buf++;
1100     }
1101
1102     if (!got_terms) {
1103         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
1104         return AVERROR_INVALIDDATA;
1105     }
1106     if (!got_weights) {
1107         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
1108         return AVERROR_INVALIDDATA;
1109     }
1110     if (!got_samples) {
1111         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
1112         return AVERROR_INVALIDDATA;
1113     }
1114     if (!got_entropy) {
1115         av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
1116         return AVERROR_INVALIDDATA;
1117     }
1118     if (s->hybrid && !got_hybrid) {
1119         av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
1120         return AVERROR_INVALIDDATA;
1121     }
1122     if (!got_bs) {
1123         av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
1124         return AVERROR_INVALIDDATA;
1125     }
1126     if (!got_float && avctx->sample_fmt == AV_SAMPLE_FMT_FLT) {
1127         av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
1128         return AVERROR_INVALIDDATA;
1129     }
1130     if (s->got_extra_bits && avctx->sample_fmt != AV_SAMPLE_FMT_FLT) {
1131         const int size   = get_bits_left(&s->gb_extra_bits);
1132         const int wanted = s->samples * s->extra_bits << s->stereo_in;
1133         if (size < wanted) {
1134             av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
1135             s->got_extra_bits = 0;
1136         }
1137     }
1138
1139     if (s->stereo_in) {
1140         if (avctx->sample_fmt == AV_SAMPLE_FMT_S16)
1141             samplecount = wv_unpack_stereo(s, &s->gb, samples, AV_SAMPLE_FMT_S16);
1142         else if (avctx->sample_fmt == AV_SAMPLE_FMT_S32)
1143             samplecount = wv_unpack_stereo(s, &s->gb, samples, AV_SAMPLE_FMT_S32);
1144         else
1145             samplecount = wv_unpack_stereo(s, &s->gb, samples, AV_SAMPLE_FMT_FLT);
1146
1147         if (samplecount < 0)
1148             return samplecount;
1149
1150         samplecount >>= 1;
1151     } else {
1152         const int channel_stride = avctx->channels;
1153
1154         if (avctx->sample_fmt == AV_SAMPLE_FMT_S16)
1155             samplecount = wv_unpack_mono(s, &s->gb, samples, AV_SAMPLE_FMT_S16);
1156         else if (avctx->sample_fmt == AV_SAMPLE_FMT_S32)
1157             samplecount = wv_unpack_mono(s, &s->gb, samples, AV_SAMPLE_FMT_S32);
1158         else
1159             samplecount = wv_unpack_mono(s, &s->gb, samples, AV_SAMPLE_FMT_FLT);
1160
1161         if (samplecount < 0)
1162             return samplecount;
1163
1164         if (s->stereo && avctx->sample_fmt == AV_SAMPLE_FMT_S16) {
1165             int16_t *dst = (int16_t *)samples + 1;
1166             int16_t *src = (int16_t *)samples;
1167             int cnt      = samplecount;
1168             while (cnt--) {
1169                 *dst = *src;
1170                 src += channel_stride;
1171                 dst += channel_stride;
1172             }
1173         } else if (s->stereo && avctx->sample_fmt == AV_SAMPLE_FMT_S32) {
1174             int32_t *dst = (int32_t *)samples + 1;
1175             int32_t *src = (int32_t *)samples;
1176             int cnt      = samplecount;
1177             while (cnt--) {
1178                 *dst = *src;
1179                 src += channel_stride;
1180                 dst += channel_stride;
1181             }
1182         } else if (s->stereo) {
1183             float *dst = (float *)samples + 1;
1184             float *src = (float *)samples;
1185             int cnt    = samplecount;
1186             while (cnt--) {
1187                 *dst = *src;
1188                 src += channel_stride;
1189                 dst += channel_stride;
1190             }
1191         }
1192     }
1193
1194     *got_frame_ptr = 1;
1195
1196     return samplecount * bpp;
1197 }
1198
1199 static void wavpack_decode_flush(AVCodecContext *avctx)
1200 {
1201     WavpackContext *s = avctx->priv_data;
1202     int i;
1203
1204     for (i = 0; i < s->fdec_num; i++)
1205         wv_reset_saved_context(s->fdec[i]);
1206 }
1207
1208 static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
1209                                 int *got_frame_ptr, AVPacket *avpkt)
1210 {
1211     WavpackContext *s  = avctx->priv_data;
1212     const uint8_t *buf = avpkt->data;
1213     int buf_size       = avpkt->size;
1214     AVFrame *frame     = data;
1215     int frame_size, ret, frame_flags;
1216     int samplecount = 0;
1217
1218     s->block     = 0;
1219     s->ch_offset = 0;
1220
1221     /* determine number of samples */
1222     if (s->mkv_mode) {
1223         s->samples  = AV_RL32(buf);
1224         buf        += 4;
1225         frame_flags = AV_RL32(buf);
1226     } else {
1227         if (s->multichannel) {
1228             s->samples  = AV_RL32(buf + 4);
1229             frame_flags = AV_RL32(buf + 8);
1230         } else {
1231             s->samples  = AV_RL32(buf);
1232             frame_flags = AV_RL32(buf + 4);
1233         }
1234     }
1235     if (s->samples <= 0 || s->samples > WV_MAX_SAMPLES) {
1236         av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1237                s->samples);
1238         return AVERROR_INVALIDDATA;
1239     }
1240
1241     if (frame_flags & 0x80) {
1242         avctx->sample_fmt = AV_SAMPLE_FMT_FLT;
1243     } else if ((frame_flags & 0x03) <= 1) {
1244         avctx->sample_fmt = AV_SAMPLE_FMT_S16;
1245     } else {
1246         avctx->sample_fmt          = AV_SAMPLE_FMT_S32;
1247         avctx->bits_per_raw_sample = ((frame_flags & 0x03) + 1) << 3;
1248     }
1249
1250     /* get output buffer */
1251     frame->nb_samples = s->samples + 1;
1252     if ((ret = ff_get_buffer(avctx, frame, 0)) < 0)
1253         return ret;
1254     frame->nb_samples = s->samples;
1255
1256     while (buf_size > 0) {
1257         if (!s->multichannel) {
1258             frame_size = buf_size;
1259         } else {
1260             if (!s->mkv_mode) {
1261                 frame_size = AV_RL32(buf) - 12;
1262                 buf       += 4;
1263                 buf_size  -= 4;
1264             } else {
1265                 if (buf_size < 12) // MKV files can have zero flags after last block
1266                     break;
1267                 frame_size = AV_RL32(buf + 8) + 12;
1268             }
1269         }
1270         if (frame_size < 0 || frame_size > buf_size) {
1271             av_log(avctx, AV_LOG_ERROR,
1272                    "Block %d has invalid size (size %d vs. %d bytes left)\n",
1273                    s->block, frame_size, buf_size);
1274             wavpack_decode_flush(avctx);
1275             return AVERROR_INVALIDDATA;
1276         }
1277         if ((samplecount = wavpack_decode_block(avctx, s->block,
1278                                                 frame->data[0], got_frame_ptr,
1279                                                 buf, frame_size)) < 0) {
1280             wavpack_decode_flush(avctx);
1281             return samplecount;
1282         }
1283         s->block++;
1284         buf      += frame_size;
1285         buf_size -= frame_size;
1286     }
1287
1288     return avpkt->size;
1289 }
1290
1291 AVCodec ff_wavpack_decoder = {
1292     .name           = "wavpack",
1293     .type           = AVMEDIA_TYPE_AUDIO,
1294     .id             = AV_CODEC_ID_WAVPACK,
1295     .priv_data_size = sizeof(WavpackContext),
1296     .init           = wavpack_decode_init,
1297     .close          = wavpack_decode_end,
1298     .decode         = wavpack_decode_frame,
1299     .flush          = wavpack_decode_flush,
1300     .capabilities   = CODEC_CAP_SUBFRAMES | CODEC_CAP_DR1,
1301     .long_name      = NULL_IF_CONFIG_SMALL("WavPack"),
1302 };