]> git.sesse.net Git - ffmpeg/blob - libavcodec/wavpack.c
png: check bit depth for PAL8/Y400A pixel formats.
[ffmpeg] / libavcodec / wavpack.c
1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  *
5  * This file is part of Libav.
6  *
7  * Libav is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * Libav is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with Libav; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  */
21
22 #define BITSTREAM_READER_LE
23
24 #include "libavutil/audioconvert.h"
25 #include "avcodec.h"
26 #include "get_bits.h"
27 #include "unary.h"
28
29 /**
30  * @file
31  * WavPack lossless audio decoder
32  */
33
34 #define WV_MONO           0x00000004
35 #define WV_JOINT_STEREO   0x00000010
36 #define WV_FALSE_STEREO   0x40000000
37
38 #define WV_HYBRID_MODE    0x00000008
39 #define WV_HYBRID_SHAPE   0x00000008
40 #define WV_HYBRID_BITRATE 0x00000200
41 #define WV_HYBRID_BALANCE 0x00000400
42
43 #define WV_FLT_SHIFT_ONES 0x01
44 #define WV_FLT_SHIFT_SAME 0x02
45 #define WV_FLT_SHIFT_SENT 0x04
46 #define WV_FLT_ZERO_SENT  0x08
47 #define WV_FLT_ZERO_SIGN  0x10
48
49 enum WP_ID_Flags {
50     WP_IDF_MASK   = 0x1F,
51     WP_IDF_IGNORE = 0x20,
52     WP_IDF_ODD    = 0x40,
53     WP_IDF_LONG   = 0x80
54 };
55
56 enum WP_ID {
57     WP_ID_DUMMY = 0,
58     WP_ID_ENCINFO,
59     WP_ID_DECTERMS,
60     WP_ID_DECWEIGHTS,
61     WP_ID_DECSAMPLES,
62     WP_ID_ENTROPY,
63     WP_ID_HYBRID,
64     WP_ID_SHAPING,
65     WP_ID_FLOATINFO,
66     WP_ID_INT32INFO,
67     WP_ID_DATA,
68     WP_ID_CORR,
69     WP_ID_EXTRABITS,
70     WP_ID_CHANINFO
71 };
72
73 typedef struct SavedContext {
74     int offset;
75     int size;
76     int bits_used;
77     uint32_t crc;
78 } SavedContext;
79
80 #define MAX_TERMS 16
81
82 typedef struct Decorr {
83     int delta;
84     int value;
85     int weightA;
86     int weightB;
87     int samplesA[8];
88     int samplesB[8];
89 } Decorr;
90
91 typedef struct WvChannel {
92     int median[3];
93     int slow_level, error_limit;
94     int bitrate_acc, bitrate_delta;
95 } WvChannel;
96
97 typedef struct WavpackFrameContext {
98     AVCodecContext *avctx;
99     int frame_flags;
100     int stereo, stereo_in;
101     int joint;
102     uint32_t CRC;
103     GetBitContext gb;
104     int got_extra_bits;
105     uint32_t crc_extra_bits;
106     GetBitContext gb_extra_bits;
107     int data_size; // in bits
108     int samples;
109     int terms;
110     Decorr decorr[MAX_TERMS];
111     int zero, one, zeroes;
112     int extra_bits;
113     int and, or, shift;
114     int post_shift;
115     int hybrid, hybrid_bitrate;
116     int hybrid_maxclip, hybrid_minclip;
117     int float_flag;
118     int float_shift;
119     int float_max_exp;
120     WvChannel ch[2];
121     int pos;
122     SavedContext sc, extra_sc;
123 } WavpackFrameContext;
124
125 #define WV_MAX_FRAME_DECODERS 14
126
127 typedef struct WavpackContext {
128     AVCodecContext *avctx;
129     AVFrame frame;
130
131     WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
132     int fdec_num;
133
134     int multichannel;
135     int mkv_mode;
136     int block;
137     int samples;
138     int ch_offset;
139 } WavpackContext;
140
141 // exponent table copied from WavPack source
142 static const uint8_t wp_exp2_table [256] = {
143     0x00, 0x01, 0x01, 0x02, 0x03, 0x03, 0x04, 0x05, 0x06, 0x06, 0x07, 0x08, 0x08, 0x09, 0x0a, 0x0b,
144     0x0b, 0x0c, 0x0d, 0x0e, 0x0e, 0x0f, 0x10, 0x10, 0x11, 0x12, 0x13, 0x13, 0x14, 0x15, 0x16, 0x16,
145     0x17, 0x18, 0x19, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1d, 0x1e, 0x1f, 0x20, 0x20, 0x21, 0x22, 0x23,
146     0x24, 0x24, 0x25, 0x26, 0x27, 0x28, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
147     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3a, 0x3b, 0x3c, 0x3d,
148     0x3e, 0x3f, 0x40, 0x41, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x48, 0x49, 0x4a, 0x4b,
149     0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a,
150     0x5b, 0x5c, 0x5d, 0x5e, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69,
151     0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
152     0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x87, 0x88, 0x89, 0x8a,
153     0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b,
154     0x9c, 0x9d, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad,
155     0xaf, 0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbc, 0xbd, 0xbe, 0xbf, 0xc0,
156     0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc8, 0xc9, 0xca, 0xcb, 0xcd, 0xce, 0xcf, 0xd0, 0xd2, 0xd3, 0xd4,
157     0xd6, 0xd7, 0xd8, 0xd9, 0xdb, 0xdc, 0xdd, 0xde, 0xe0, 0xe1, 0xe2, 0xe4, 0xe5, 0xe6, 0xe8, 0xe9,
158     0xea, 0xec, 0xed, 0xee, 0xf0, 0xf1, 0xf2, 0xf4, 0xf5, 0xf6, 0xf8, 0xf9, 0xfa, 0xfc, 0xfd, 0xff
159 };
160
161 static const uint8_t wp_log2_table [] = {
162     0x00, 0x01, 0x03, 0x04, 0x06, 0x07, 0x09, 0x0a, 0x0b, 0x0d, 0x0e, 0x10, 0x11, 0x12, 0x14, 0x15,
163     0x16, 0x18, 0x19, 0x1a, 0x1c, 0x1d, 0x1e, 0x20, 0x21, 0x22, 0x24, 0x25, 0x26, 0x28, 0x29, 0x2a,
164     0x2c, 0x2d, 0x2e, 0x2f, 0x31, 0x32, 0x33, 0x34, 0x36, 0x37, 0x38, 0x39, 0x3b, 0x3c, 0x3d, 0x3e,
165     0x3f, 0x41, 0x42, 0x43, 0x44, 0x45, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4d, 0x4e, 0x4f, 0x50, 0x51,
166     0x52, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63,
167     0x64, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x74, 0x75,
168     0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85,
169     0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95,
170     0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4,
171     0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, 0xb0, 0xb1, 0xb2, 0xb2,
172     0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, 0xc0, 0xc0,
173     0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcb, 0xcc, 0xcd, 0xce,
174     0xcf, 0xd0, 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd8, 0xd9, 0xda, 0xdb,
175     0xdc, 0xdc, 0xdd, 0xde, 0xdf, 0xe0, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe4, 0xe5, 0xe6, 0xe7, 0xe7,
176     0xe8, 0xe9, 0xea, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xee, 0xef, 0xf0, 0xf1, 0xf1, 0xf2, 0xf3, 0xf4,
177     0xf4, 0xf5, 0xf6, 0xf7, 0xf7, 0xf8, 0xf9, 0xf9, 0xfa, 0xfb, 0xfc, 0xfc, 0xfd, 0xfe, 0xff, 0xff
178 };
179
180 static av_always_inline int wp_exp2(int16_t val)
181 {
182     int res, neg = 0;
183
184     if (val < 0) {
185         val = -val;
186         neg = 1;
187     }
188
189     res = wp_exp2_table[val & 0xFF] | 0x100;
190     val >>= 8;
191     res = (val > 9) ? (res << (val - 9)) : (res >> (9 - val));
192     return neg ? -res : res;
193 }
194
195 static av_always_inline int wp_log2(int32_t val)
196 {
197     int bits;
198
199     if (!val)
200         return 0;
201     if (val == 1)
202         return 256;
203     val += val >> 9;
204     bits = av_log2(val) + 1;
205     if (bits < 9)
206         return (bits << 8) + wp_log2_table[(val << (9 - bits)) & 0xFF];
207     else
208         return (bits << 8) + wp_log2_table[(val >> (bits - 9)) & 0xFF];
209 }
210
211 #define LEVEL_DECAY(a)  ((a + 0x80) >> 8)
212
213 // macros for manipulating median values
214 #define GET_MED(n) ((c->median[n] >> 4) + 1)
215 #define DEC_MED(n) c->median[n] -= ((c->median[n] + (128 >> n) - 2) / (128 >> n)) * 2
216 #define INC_MED(n) c->median[n] += ((c->median[n] + (128 >> n)    ) / (128 >> n)) * 5
217
218 // macros for applying weight
219 #define UPDATE_WEIGHT_CLIP(weight, delta, samples, in) \
220     if (samples && in) { \
221         if ((samples ^ in) < 0) { \
222             weight -= delta; \
223             if (weight < -1024) \
224                 weight = -1024; \
225         } else { \
226             weight += delta; \
227             if (weight > 1024) \
228                 weight = 1024; \
229         } \
230     }
231
232
233 static av_always_inline int get_tail(GetBitContext *gb, int k)
234 {
235     int p, e, res;
236
237     if (k < 1)
238         return 0;
239     p = av_log2(k);
240     e = (1 << (p + 1)) - k - 1;
241     res = p ? get_bits(gb, p) : 0;
242     if (res >= e)
243         res = (res << 1) - e + get_bits1(gb);
244     return res;
245 }
246
247 static void update_error_limit(WavpackFrameContext *ctx)
248 {
249     int i, br[2], sl[2];
250
251     for (i = 0; i <= ctx->stereo_in; i++) {
252         ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
253         br[i] = ctx->ch[i].bitrate_acc >> 16;
254         sl[i] = LEVEL_DECAY(ctx->ch[i].slow_level);
255     }
256     if (ctx->stereo_in && ctx->hybrid_bitrate) {
257         int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
258         if (balance > br[0]) {
259             br[1] = br[0] << 1;
260             br[0] = 0;
261         } else if (-balance > br[0]) {
262             br[0] <<= 1;
263             br[1] = 0;
264         } else {
265             br[1] = br[0] + balance;
266             br[0] = br[0] - balance;
267         }
268     }
269     for (i = 0; i <= ctx->stereo_in; i++) {
270         if (ctx->hybrid_bitrate) {
271             if (sl[i] - br[i] > -0x100)
272                 ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
273             else
274                 ctx->ch[i].error_limit = 0;
275         } else {
276             ctx->ch[i].error_limit = wp_exp2(br[i]);
277         }
278     }
279 }
280
281 static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
282                         int channel, int *last)
283 {
284     int t, t2;
285     int sign, base, add, ret;
286     WvChannel *c = &ctx->ch[channel];
287
288     *last = 0;
289
290     if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
291         !ctx->zero && !ctx->one) {
292         if (ctx->zeroes) {
293             ctx->zeroes--;
294             if (ctx->zeroes) {
295                 c->slow_level -= LEVEL_DECAY(c->slow_level);
296                 return 0;
297             }
298         } else {
299             t = get_unary_0_33(gb);
300             if (t >= 2) {
301                 if (get_bits_left(gb) < t - 1)
302                     goto error;
303                 t = get_bits(gb, t - 1) | (1 << (t-1));
304             } else {
305                 if (get_bits_left(gb) < 0)
306                     goto error;
307             }
308             ctx->zeroes = t;
309             if (ctx->zeroes) {
310                 memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
311                 memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
312                 c->slow_level -= LEVEL_DECAY(c->slow_level);
313                 return 0;
314             }
315         }
316     }
317
318     if (ctx->zero) {
319         t = 0;
320         ctx->zero = 0;
321     } else {
322         t = get_unary_0_33(gb);
323         if (get_bits_left(gb) < 0)
324             goto error;
325         if (t == 16) {
326             t2 = get_unary_0_33(gb);
327             if (t2 < 2) {
328                 if (get_bits_left(gb) < 0)
329                     goto error;
330                 t += t2;
331             } else {
332                 if (get_bits_left(gb) < t2 - 1)
333                     goto error;
334                 t += get_bits(gb, t2 - 1) | (1 << (t2 - 1));
335             }
336         }
337
338         if (ctx->one) {
339             ctx->one = t & 1;
340             t = (t >> 1) + 1;
341         } else {
342             ctx->one = t & 1;
343             t >>= 1;
344         }
345         ctx->zero = !ctx->one;
346     }
347
348     if (ctx->hybrid && !channel)
349         update_error_limit(ctx);
350
351     if (!t) {
352         base = 0;
353         add  = GET_MED(0) - 1;
354         DEC_MED(0);
355     } else if (t == 1) {
356         base = GET_MED(0);
357         add  = GET_MED(1) - 1;
358         INC_MED(0);
359         DEC_MED(1);
360     } else if (t == 2) {
361         base = GET_MED(0) + GET_MED(1);
362         add  = GET_MED(2) - 1;
363         INC_MED(0);
364         INC_MED(1);
365         DEC_MED(2);
366     } else {
367         base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2);
368         add  = GET_MED(2) - 1;
369         INC_MED(0);
370         INC_MED(1);
371         INC_MED(2);
372     }
373     if (!c->error_limit) {
374         ret = base + get_tail(gb, add);
375         if (get_bits_left(gb) <= 0)
376             goto error;
377     } else {
378         int mid = (base * 2 + add + 1) >> 1;
379         while (add > c->error_limit) {
380             if (get_bits_left(gb) <= 0)
381                 goto error;
382             if (get_bits1(gb)) {
383                 add -= (mid - base);
384                 base = mid;
385             } else
386                 add = mid - base - 1;
387             mid = (base * 2 + add + 1) >> 1;
388         }
389         ret = mid;
390     }
391     sign = get_bits1(gb);
392     if (ctx->hybrid_bitrate)
393         c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
394     return sign ? ~ret : ret;
395
396 error:
397     *last = 1;
398     return 0;
399 }
400
401 static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
402                                        int S)
403 {
404     int bit;
405
406     if (s->extra_bits){
407         S <<= s->extra_bits;
408
409         if (s->got_extra_bits && get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
410             S |= get_bits(&s->gb_extra_bits, s->extra_bits);
411             *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
412         }
413     }
414
415     bit = (S & s->and) | s->or;
416     bit = ((S + bit) << s->shift) - bit;
417
418     if (s->hybrid)
419         bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
420
421     return bit << s->post_shift;
422 }
423
424 static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
425 {
426     union {
427         float    f;
428         uint32_t u;
429     } value;
430
431     unsigned int sign;
432     int exp = s->float_max_exp;
433
434     if (s->got_extra_bits) {
435         const int max_bits  = 1 + 23 + 8 + 1;
436         const int left_bits = get_bits_left(&s->gb_extra_bits);
437
438         if (left_bits + 8 * FF_INPUT_BUFFER_PADDING_SIZE < max_bits)
439             return 0.0;
440     }
441
442     if (S) {
443         S <<= s->float_shift;
444         sign = S < 0;
445         if (sign)
446             S = -S;
447         if (S >= 0x1000000) {
448             if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
449                 S = get_bits(&s->gb_extra_bits, 23);
450             else
451                 S = 0;
452             exp = 255;
453         } else if (exp) {
454             int shift = 23 - av_log2(S);
455             exp = s->float_max_exp;
456             if (exp <= shift)
457                 shift = --exp;
458             exp -= shift;
459
460             if (shift) {
461                 S <<= shift;
462                 if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
463                     (s->got_extra_bits && (s->float_flag & WV_FLT_SHIFT_SAME) &&
464                      get_bits1(&s->gb_extra_bits))) {
465                     S |= (1 << shift) - 1;
466                 } else if (s->got_extra_bits &&
467                            (s->float_flag & WV_FLT_SHIFT_SENT)) {
468                     S |= get_bits(&s->gb_extra_bits, shift);
469                 }
470             }
471         } else {
472             exp = s->float_max_exp;
473         }
474         S &= 0x7fffff;
475     } else {
476         sign = 0;
477         exp = 0;
478         if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
479             if (get_bits1(&s->gb_extra_bits)) {
480                 S = get_bits(&s->gb_extra_bits, 23);
481                 if (s->float_max_exp >= 25)
482                     exp = get_bits(&s->gb_extra_bits, 8);
483                 sign = get_bits1(&s->gb_extra_bits);
484             } else {
485                 if (s->float_flag & WV_FLT_ZERO_SIGN)
486                     sign = get_bits1(&s->gb_extra_bits);
487             }
488         }
489     }
490
491     *crc = *crc * 27 + S * 9 + exp * 3 + sign;
492
493     value.u = (sign << 31) | (exp << 23) | S;
494     return value.f;
495 }
496
497 static void wv_reset_saved_context(WavpackFrameContext *s)
498 {
499     s->pos = 0;
500     s->sc.crc = s->extra_sc.crc = 0xFFFFFFFF;
501 }
502
503 static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
504                                uint32_t crc_extra_bits)
505 {
506     if (crc != s->CRC) {
507         av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
508         return AVERROR_INVALIDDATA;
509     }
510     if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
511         av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
512         return AVERROR_INVALIDDATA;
513     }
514
515     return 0;
516 }
517
518 static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
519                                    void *dst, const int type)
520 {
521     int i, j, count = 0;
522     int last, t;
523     int A, B, L, L2, R, R2;
524     int pos = s->pos;
525     uint32_t crc = s->sc.crc;
526     uint32_t crc_extra_bits = s->extra_sc.crc;
527     int16_t *dst16 = dst;
528     int32_t *dst32 = dst;
529     float   *dstfl = dst;
530     const int channel_pad = s->avctx->channels - 2;
531
532     s->one = s->zero = s->zeroes = 0;
533     do {
534         L = wv_get_value(s, gb, 0, &last);
535         if (last)
536             break;
537         R = wv_get_value(s, gb, 1, &last);
538         if (last)
539             break;
540         for (i = 0; i < s->terms; i++) {
541             t = s->decorr[i].value;
542             if (t > 0) {
543                 if (t > 8) {
544                     if (t & 1) {
545                         A = 2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
546                         B = 2 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
547                     } else {
548                         A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
549                         B = (3 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
550                     }
551                     s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
552                     s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
553                     j = 0;
554                 } else {
555                     A = s->decorr[i].samplesA[pos];
556                     B = s->decorr[i].samplesB[pos];
557                     j = (pos + t) & 7;
558                 }
559                 if (type != AV_SAMPLE_FMT_S16) {
560                     L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
561                     R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
562                 } else {
563                     L2 = L + ((s->decorr[i].weightA * A + 512) >> 10);
564                     R2 = R + ((s->decorr[i].weightB * B + 512) >> 10);
565                 }
566                 if (A && L) s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
567                 if (B && R) s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
568                 s->decorr[i].samplesA[j] = L = L2;
569                 s->decorr[i].samplesB[j] = R = R2;
570             } else if (t == -1) {
571                 if (type != AV_SAMPLE_FMT_S16)
572                     L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
573                 else
574                     L2 = L + ((s->decorr[i].weightA * s->decorr[i].samplesA[0] + 512) >> 10);
575                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
576                 L = L2;
577                 if (type != AV_SAMPLE_FMT_S16)
578                     R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
579                 else
580                     R2 = R + ((s->decorr[i].weightB * L2 + 512) >> 10);
581                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
582                 R = R2;
583                 s->decorr[i].samplesA[0] = R;
584             } else {
585                 if (type != AV_SAMPLE_FMT_S16)
586                     R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
587                 else
588                     R2 = R + ((s->decorr[i].weightB * s->decorr[i].samplesB[0] + 512) >> 10);
589                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
590                 R = R2;
591
592                 if (t == -3) {
593                     R2 = s->decorr[i].samplesA[0];
594                     s->decorr[i].samplesA[0] = R;
595                 }
596
597                 if (type != AV_SAMPLE_FMT_S16)
598                     L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
599                 else
600                     L2 = L + ((s->decorr[i].weightA * R2 + 512) >> 10);
601                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
602                 L = L2;
603                 s->decorr[i].samplesB[0] = L;
604             }
605         }
606         pos = (pos + 1) & 7;
607         if (s->joint)
608             L += (R -= (L >> 1));
609         crc = (crc * 3 + L) * 3 + R;
610
611         if (type == AV_SAMPLE_FMT_FLT) {
612             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, L);
613             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, R);
614             dstfl += channel_pad;
615         } else if (type == AV_SAMPLE_FMT_S32) {
616             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, L);
617             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, R);
618             dst32 += channel_pad;
619         } else {
620             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, L);
621             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, R);
622             dst16 += channel_pad;
623         }
624         count++;
625     } while (!last && count < s->samples);
626
627     wv_reset_saved_context(s);
628     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
629         wv_check_crc(s, crc, crc_extra_bits))
630         return AVERROR_INVALIDDATA;
631
632     return count * 2;
633 }
634
635 static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
636                                  void *dst, const int type)
637 {
638     int i, j, count = 0;
639     int last, t;
640     int A, S, T;
641     int pos = s->pos;
642     uint32_t crc = s->sc.crc;
643     uint32_t crc_extra_bits = s->extra_sc.crc;
644     int16_t *dst16 = dst;
645     int32_t *dst32 = dst;
646     float   *dstfl = dst;
647     const int channel_stride = s->avctx->channels;
648
649     s->one = s->zero = s->zeroes = 0;
650     do {
651         T = wv_get_value(s, gb, 0, &last);
652         S = 0;
653         if (last)
654             break;
655         for (i = 0; i < s->terms; i++) {
656             t = s->decorr[i].value;
657             if (t > 8) {
658                 if (t & 1)
659                     A =  2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
660                 else
661                     A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
662                 s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
663                 j = 0;
664             } else {
665                 A = s->decorr[i].samplesA[pos];
666                 j = (pos + t) & 7;
667             }
668             if (type != AV_SAMPLE_FMT_S16)
669                 S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
670             else
671                 S = T + ((s->decorr[i].weightA * A + 512) >> 10);
672             if (A && T)
673                 s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
674             s->decorr[i].samplesA[j] = T = S;
675         }
676         pos = (pos + 1) & 7;
677         crc = crc * 3 + S;
678
679         if (type == AV_SAMPLE_FMT_FLT) {
680             *dstfl = wv_get_value_float(s, &crc_extra_bits, S);
681             dstfl += channel_stride;
682         } else if (type == AV_SAMPLE_FMT_S32) {
683             *dst32 = wv_get_value_integer(s, &crc_extra_bits, S);
684             dst32 += channel_stride;
685         } else {
686             *dst16 = wv_get_value_integer(s, &crc_extra_bits, S);
687             dst16 += channel_stride;
688         }
689         count++;
690     } while (!last && count < s->samples);
691
692     wv_reset_saved_context(s);
693     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
694         wv_check_crc(s, crc, crc_extra_bits))
695         return AVERROR_INVALIDDATA;
696
697     return count;
698 }
699
700 static av_cold int wv_alloc_frame_context(WavpackContext *c)
701 {
702
703     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
704         return -1;
705
706     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
707     if (!c->fdec[c->fdec_num])
708         return -1;
709     c->fdec_num++;
710     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
711     wv_reset_saved_context(c->fdec[c->fdec_num - 1]);
712
713     return 0;
714 }
715
716 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
717 {
718     WavpackContext *s = avctx->priv_data;
719
720     s->avctx = avctx;
721     if (avctx->bits_per_coded_sample <= 16)
722         avctx->sample_fmt = AV_SAMPLE_FMT_S16;
723     else
724         avctx->sample_fmt = AV_SAMPLE_FMT_S32;
725     if (avctx->channels <= 2 && !avctx->channel_layout)
726         avctx->channel_layout = (avctx->channels == 2) ? AV_CH_LAYOUT_STEREO :
727                                                          AV_CH_LAYOUT_MONO;
728
729     s->multichannel = avctx->channels > 2;
730     /* lavf demuxer does not provide extradata, Matroska stores 0x403
731        there, use this to detect decoding mode for multichannel */
732     s->mkv_mode = 0;
733     if (s->multichannel && avctx->extradata && avctx->extradata_size == 2) {
734         int ver = AV_RL16(avctx->extradata);
735         if (ver >= 0x402 && ver <= 0x410)
736             s->mkv_mode = 1;
737     }
738
739     s->fdec_num = 0;
740
741     avcodec_get_frame_defaults(&s->frame);
742     avctx->coded_frame = &s->frame;
743
744     return 0;
745 }
746
747 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
748 {
749     WavpackContext *s = avctx->priv_data;
750     int i;
751
752     for (i = 0; i < s->fdec_num; i++)
753         av_freep(&s->fdec[i]);
754     s->fdec_num = 0;
755
756     return 0;
757 }
758
759 static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
760                                 void *data, int *got_frame_ptr,
761                                 const uint8_t *buf, int buf_size)
762 {
763     WavpackContext *wc = avctx->priv_data;
764     WavpackFrameContext *s;
765     void *samples = data;
766     int samplecount;
767     int got_terms   = 0, got_weights = 0, got_samples = 0,
768         got_entropy = 0, got_bs      = 0, got_float   = 0, got_hybrid = 0;
769     const uint8_t *orig_buf = buf;
770     const uint8_t *buf_end  = buf + buf_size;
771     int i, j, id, size, ssize, weights, t;
772     int bpp, chan, chmask, orig_bpp;
773
774     if (buf_size == 0) {
775         *got_frame_ptr = 0;
776         return 0;
777     }
778
779     if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
780         av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
781         return -1;
782     }
783
784     s = wc->fdec[block_no];
785     if (!s) {
786         av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n", block_no);
787         return -1;
788     }
789
790     memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
791     memset(s->ch, 0, sizeof(s->ch));
792     s->extra_bits = 0;
793     s->and = s->or = s->shift = 0;
794     s->got_extra_bits = 0;
795
796     if (!wc->mkv_mode) {
797         s->samples = AV_RL32(buf); buf += 4;
798         if (!s->samples) {
799             *got_frame_ptr = 0;
800             return 0;
801         }
802     } else {
803         s->samples = wc->samples;
804     }
805     s->frame_flags = AV_RL32(buf); buf += 4;
806     bpp = av_get_bytes_per_sample(avctx->sample_fmt);
807     samples = (uint8_t*)samples + bpp * wc->ch_offset;
808     orig_bpp = ((s->frame_flags & 0x03) + 1) << 3;
809
810     s->stereo         = !(s->frame_flags & WV_MONO);
811     s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
812     s->joint          =   s->frame_flags & WV_JOINT_STEREO;
813     s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
814     s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
815     s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
816     s->hybrid_maxclip = (( 1LL << (orig_bpp - 1)) - 1);
817     s->hybrid_minclip = ((-1LL << (orig_bpp - 1)));
818     s->CRC            = AV_RL32(buf); buf += 4;
819     if (wc->mkv_mode)
820         buf += 4; //skip block size;
821
822     wc->ch_offset += 1 + s->stereo;
823
824     // parse metadata blocks
825     while (buf < buf_end) {
826         id   = *buf++;
827         size = *buf++;
828         if (id & WP_IDF_LONG) {
829             size |= (*buf++) << 8;
830             size |= (*buf++) << 16;
831         }
832         size <<= 1; // size is specified in words
833         ssize = size;
834         if (id & WP_IDF_ODD)
835             size--;
836         if (size < 0) {
837             av_log(avctx, AV_LOG_ERROR, "Got incorrect block %02X with size %i\n", id, size);
838             break;
839         }
840         if (buf + ssize > buf_end) {
841             av_log(avctx, AV_LOG_ERROR, "Block size %i is out of bounds\n", size);
842             break;
843         }
844         if (id & WP_IDF_IGNORE) {
845             buf += ssize;
846             continue;
847         }
848         switch (id & WP_IDF_MASK) {
849         case WP_ID_DECTERMS:
850             if (size > MAX_TERMS) {
851                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
852                 s->terms = 0;
853                 buf += ssize;
854                 continue;
855             }
856             s->terms = size;
857             for (i = 0; i < s->terms; i++) {
858                 s->decorr[s->terms - i - 1].value = (*buf & 0x1F) - 5;
859                 s->decorr[s->terms - i - 1].delta = *buf >> 5;
860                 buf++;
861             }
862             got_terms = 1;
863             break;
864         case WP_ID_DECWEIGHTS:
865             if (!got_terms) {
866                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
867                 continue;
868             }
869             weights = size >> s->stereo_in;
870             if (weights > MAX_TERMS || weights > s->terms) {
871                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
872                 buf += ssize;
873                 continue;
874             }
875             for (i = 0; i < weights; i++) {
876                 t = (int8_t)(*buf++);
877                 s->decorr[s->terms - i - 1].weightA = t << 3;
878                 if (s->decorr[s->terms - i - 1].weightA > 0)
879                     s->decorr[s->terms - i - 1].weightA +=
880                             (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
881                 if (s->stereo_in) {
882                     t = (int8_t)(*buf++);
883                     s->decorr[s->terms - i - 1].weightB = t << 3;
884                     if (s->decorr[s->terms - i - 1].weightB > 0)
885                         s->decorr[s->terms - i - 1].weightB +=
886                                 (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
887                 }
888             }
889             got_weights = 1;
890             break;
891         case WP_ID_DECSAMPLES:
892             if (!got_terms) {
893                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
894                 continue;
895             }
896             t = 0;
897             for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
898                 if (s->decorr[i].value > 8) {
899                     s->decorr[i].samplesA[0] = wp_exp2(AV_RL16(buf)); buf += 2;
900                     s->decorr[i].samplesA[1] = wp_exp2(AV_RL16(buf)); buf += 2;
901                     if (s->stereo_in) {
902                         s->decorr[i].samplesB[0] = wp_exp2(AV_RL16(buf)); buf += 2;
903                         s->decorr[i].samplesB[1] = wp_exp2(AV_RL16(buf)); buf += 2;
904                         t += 4;
905                     }
906                     t += 4;
907                 } else if (s->decorr[i].value < 0) {
908                     s->decorr[i].samplesA[0] = wp_exp2(AV_RL16(buf)); buf += 2;
909                     s->decorr[i].samplesB[0] = wp_exp2(AV_RL16(buf)); buf += 2;
910                     t += 4;
911                 } else {
912                     for (j = 0; j < s->decorr[i].value; j++) {
913                         s->decorr[i].samplesA[j] = wp_exp2(AV_RL16(buf)); buf += 2;
914                         if (s->stereo_in) {
915                             s->decorr[i].samplesB[j] = wp_exp2(AV_RL16(buf)); buf += 2;
916                         }
917                     }
918                     t += s->decorr[i].value * 2 * (s->stereo_in + 1);
919                 }
920             }
921             got_samples = 1;
922             break;
923         case WP_ID_ENTROPY:
924             if (size != 6 * (s->stereo_in + 1)) {
925                 av_log(avctx, AV_LOG_ERROR, "Entropy vars size should be %i, "
926                        "got %i", 6 * (s->stereo_in + 1), size);
927                 buf += ssize;
928                 continue;
929             }
930             for (j = 0; j <= s->stereo_in; j++) {
931                 for (i = 0; i < 3; i++) {
932                     s->ch[j].median[i] = wp_exp2(AV_RL16(buf));
933                     buf += 2;
934                 }
935             }
936             got_entropy = 1;
937             break;
938         case WP_ID_HYBRID:
939             if (s->hybrid_bitrate) {
940                 for (i = 0; i <= s->stereo_in; i++) {
941                     s->ch[i].slow_level = wp_exp2(AV_RL16(buf));
942                     buf += 2;
943                     size -= 2;
944                 }
945             }
946             for (i = 0; i < (s->stereo_in + 1); i++) {
947                 s->ch[i].bitrate_acc = AV_RL16(buf) << 16;
948                 buf += 2;
949                 size -= 2;
950             }
951             if (size > 0) {
952                 for (i = 0; i < (s->stereo_in + 1); i++) {
953                     s->ch[i].bitrate_delta = wp_exp2((int16_t)AV_RL16(buf));
954                     buf += 2;
955                 }
956             } else {
957                 for (i = 0; i < (s->stereo_in + 1); i++)
958                     s->ch[i].bitrate_delta = 0;
959             }
960             got_hybrid = 1;
961             break;
962         case WP_ID_INT32INFO:
963             if (size != 4) {
964                 av_log(avctx, AV_LOG_ERROR, "Invalid INT32INFO, size = %i, sent_bits = %i\n", size, *buf);
965                 buf += ssize;
966                 continue;
967             }
968             if (buf[0])
969                 s->extra_bits = buf[0];
970             else if (buf[1])
971                 s->shift = buf[1];
972             else if (buf[2]){
973                 s->and = s->or = 1;
974                 s->shift = buf[2];
975             } else if(buf[3]) {
976                 s->and   = 1;
977                 s->shift = buf[3];
978             }
979             /* original WavPack decoder forces 32-bit lossy sound to be treated
980              * as 24-bit one in order to have proper clipping
981              */
982             if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
983                 s->post_shift += 8;
984                 s->shift      -= 8;
985                 s->hybrid_maxclip >>= 8;
986                 s->hybrid_minclip >>= 8;
987             }
988             buf += 4;
989             break;
990         case WP_ID_FLOATINFO:
991             if (size != 4) {
992                 av_log(avctx, AV_LOG_ERROR, "Invalid FLOATINFO, size = %i\n", size);
993                 buf += ssize;
994                 continue;
995             }
996             s->float_flag    = buf[0];
997             s->float_shift   = buf[1];
998             s->float_max_exp = buf[2];
999             buf += 4;
1000             got_float = 1;
1001             break;
1002         case WP_ID_DATA:
1003             s->sc.offset = buf - orig_buf;
1004             s->sc.size   = size * 8;
1005             init_get_bits(&s->gb, buf, size * 8);
1006             s->data_size = size * 8;
1007             buf += size;
1008             got_bs = 1;
1009             break;
1010         case WP_ID_EXTRABITS:
1011             if (size <= 4) {
1012                 av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
1013                        size);
1014                 buf += size;
1015                 continue;
1016             }
1017             s->extra_sc.offset = buf - orig_buf;
1018             s->extra_sc.size   = size * 8;
1019             init_get_bits(&s->gb_extra_bits, buf, size * 8);
1020             s->crc_extra_bits = get_bits_long(&s->gb_extra_bits, 32);
1021             buf += size;
1022             s->got_extra_bits = 1;
1023             break;
1024         case WP_ID_CHANINFO:
1025             if (size <= 1) {
1026                 av_log(avctx, AV_LOG_ERROR, "Insufficient channel information\n");
1027                 return -1;
1028             }
1029             chan = *buf++;
1030             switch (size - 2) {
1031             case 0: chmask = *buf;         break;
1032             case 1: chmask = AV_RL16(buf); break;
1033             case 2: chmask = AV_RL24(buf); break;
1034             case 3: chmask = AV_RL32(buf); break;
1035             case 5:
1036                 chan |= (buf[1] & 0xF) << 8;
1037                 chmask = AV_RL24(buf + 2);
1038                 break;
1039             default:
1040                 av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
1041                        size);
1042                 chan   = avctx->channels;
1043                 chmask = avctx->channel_layout;
1044             }
1045             if (chan != avctx->channels) {
1046                 av_log(avctx, AV_LOG_ERROR, "Block reports total %d channels, "
1047                        "decoder believes it's %d channels\n", chan,
1048                        avctx->channels);
1049                 return -1;
1050             }
1051             if (!avctx->channel_layout)
1052                 avctx->channel_layout = chmask;
1053             buf += size - 1;
1054             break;
1055         default:
1056             buf += size;
1057         }
1058         if (id & WP_IDF_ODD)
1059             buf++;
1060     }
1061
1062     if (!got_terms) {
1063         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
1064         return -1;
1065     }
1066     if (!got_weights) {
1067         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
1068         return -1;
1069     }
1070     if (!got_samples) {
1071         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
1072         return -1;
1073     }
1074     if (!got_entropy) {
1075         av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
1076         return -1;
1077     }
1078     if (s->hybrid && !got_hybrid) {
1079         av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
1080         return -1;
1081     }
1082     if (!got_bs) {
1083         av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
1084         return -1;
1085     }
1086     if (!got_float && avctx->sample_fmt == AV_SAMPLE_FMT_FLT) {
1087         av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
1088         return -1;
1089     }
1090     if (s->got_extra_bits && avctx->sample_fmt != AV_SAMPLE_FMT_FLT) {
1091         const int size   = get_bits_left(&s->gb_extra_bits);
1092         const int wanted = s->samples * s->extra_bits << s->stereo_in;
1093         if (size < wanted) {
1094             av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
1095             s->got_extra_bits = 0;
1096         }
1097     }
1098
1099     if (s->stereo_in) {
1100         if (avctx->sample_fmt == AV_SAMPLE_FMT_S16)
1101             samplecount = wv_unpack_stereo(s, &s->gb, samples, AV_SAMPLE_FMT_S16);
1102         else if (avctx->sample_fmt == AV_SAMPLE_FMT_S32)
1103             samplecount = wv_unpack_stereo(s, &s->gb, samples, AV_SAMPLE_FMT_S32);
1104         else
1105             samplecount = wv_unpack_stereo(s, &s->gb, samples, AV_SAMPLE_FMT_FLT);
1106
1107         if (samplecount < 0)
1108             return -1;
1109
1110         samplecount >>= 1;
1111     } else {
1112         const int channel_stride = avctx->channels;
1113
1114         if (avctx->sample_fmt == AV_SAMPLE_FMT_S16)
1115             samplecount = wv_unpack_mono(s, &s->gb, samples, AV_SAMPLE_FMT_S16);
1116         else if (avctx->sample_fmt == AV_SAMPLE_FMT_S32)
1117             samplecount = wv_unpack_mono(s, &s->gb, samples, AV_SAMPLE_FMT_S32);
1118         else
1119             samplecount = wv_unpack_mono(s, &s->gb, samples, AV_SAMPLE_FMT_FLT);
1120
1121         if (samplecount < 0)
1122             return -1;
1123
1124         if (s->stereo && avctx->sample_fmt == AV_SAMPLE_FMT_S16) {
1125             int16_t *dst = (int16_t*)samples + 1;
1126             int16_t *src = (int16_t*)samples;
1127             int cnt = samplecount;
1128             while (cnt--) {
1129                 *dst = *src;
1130                 src += channel_stride;
1131                 dst += channel_stride;
1132             }
1133         } else if (s->stereo && avctx->sample_fmt == AV_SAMPLE_FMT_S32) {
1134             int32_t *dst = (int32_t*)samples + 1;
1135             int32_t *src = (int32_t*)samples;
1136             int cnt = samplecount;
1137             while (cnt--) {
1138                 *dst = *src;
1139                 src += channel_stride;
1140                 dst += channel_stride;
1141             }
1142         } else if (s->stereo) {
1143             float *dst = (float*)samples + 1;
1144             float *src = (float*)samples;
1145             int cnt = samplecount;
1146             while (cnt--) {
1147                 *dst = *src;
1148                 src += channel_stride;
1149                 dst += channel_stride;
1150             }
1151         }
1152     }
1153
1154     *got_frame_ptr = 1;
1155
1156     return samplecount * bpp;
1157 }
1158
1159 static void wavpack_decode_flush(AVCodecContext *avctx)
1160 {
1161     WavpackContext *s = avctx->priv_data;
1162     int i;
1163
1164     for (i = 0; i < s->fdec_num; i++)
1165         wv_reset_saved_context(s->fdec[i]);
1166 }
1167
1168 static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
1169                                 int *got_frame_ptr, AVPacket *avpkt)
1170 {
1171     WavpackContext *s  = avctx->priv_data;
1172     const uint8_t *buf = avpkt->data;
1173     int buf_size       = avpkt->size;
1174     int frame_size, ret, frame_flags;
1175     int samplecount = 0;
1176
1177     s->block     = 0;
1178     s->ch_offset = 0;
1179
1180     /* determine number of samples */
1181     if (s->mkv_mode) {
1182         s->samples  = AV_RL32(buf); buf += 4;
1183         frame_flags = AV_RL32(buf);
1184     } else {
1185         if (s->multichannel) {
1186             s->samples  = AV_RL32(buf + 4);
1187             frame_flags = AV_RL32(buf + 8);
1188         } else {
1189             s->samples  = AV_RL32(buf);
1190             frame_flags = AV_RL32(buf + 4);
1191         }
1192     }
1193     if (s->samples <= 0) {
1194         av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1195                s->samples);
1196         return AVERROR(EINVAL);
1197     }
1198
1199     if (frame_flags & 0x80) {
1200         avctx->sample_fmt = AV_SAMPLE_FMT_FLT;
1201     } else if ((frame_flags & 0x03) <= 1) {
1202         avctx->sample_fmt = AV_SAMPLE_FMT_S16;
1203     } else {
1204         avctx->sample_fmt = AV_SAMPLE_FMT_S32;
1205     }
1206
1207     /* get output buffer */
1208     s->frame.nb_samples = s->samples;
1209     if ((ret = avctx->get_buffer(avctx, &s->frame)) < 0) {
1210         av_log(avctx, AV_LOG_ERROR, "get_buffer() failed\n");
1211         return ret;
1212     }
1213
1214     while (buf_size > 0) {
1215         if (!s->multichannel) {
1216             frame_size = buf_size;
1217         } else {
1218             if (!s->mkv_mode) {
1219                 frame_size = AV_RL32(buf) - 12; buf += 4; buf_size -= 4;
1220             } else {
1221                 if (buf_size < 12) //MKV files can have zero flags after last block
1222                     break;
1223                 frame_size = AV_RL32(buf + 8) + 12;
1224             }
1225         }
1226         if (frame_size < 0 || frame_size > buf_size) {
1227             av_log(avctx, AV_LOG_ERROR, "Block %d has invalid size (size %d "
1228                    "vs. %d bytes left)\n", s->block, frame_size, buf_size);
1229             wavpack_decode_flush(avctx);
1230             return -1;
1231         }
1232         if ((samplecount = wavpack_decode_block(avctx, s->block,
1233                                                 s->frame.data[0], got_frame_ptr,
1234                                                 buf, frame_size)) < 0) {
1235             wavpack_decode_flush(avctx);
1236             return -1;
1237         }
1238         s->block++;
1239         buf += frame_size; buf_size -= frame_size;
1240     }
1241
1242     if (*got_frame_ptr)
1243         *(AVFrame *)data = s->frame;
1244
1245     return avpkt->size;
1246 }
1247
1248 AVCodec ff_wavpack_decoder = {
1249     .name           = "wavpack",
1250     .type           = AVMEDIA_TYPE_AUDIO,
1251     .id             = CODEC_ID_WAVPACK,
1252     .priv_data_size = sizeof(WavpackContext),
1253     .init           = wavpack_decode_init,
1254     .close          = wavpack_decode_end,
1255     .decode         = wavpack_decode_frame,
1256     .flush          = wavpack_decode_flush,
1257     .capabilities   = CODEC_CAP_SUBFRAMES | CODEC_CAP_DR1,
1258     .long_name      = NULL_IF_CONFIG_SMALL("WavPack"),
1259 };