]> git.sesse.net Git - ffmpeg/blob - libavcodec/wavpack.c
dnxhddec: better support for 4:4:4
[ffmpeg] / libavcodec / wavpack.c
1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  *
5  * This file is part of FFmpeg.
6  *
7  * FFmpeg is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * FFmpeg is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with FFmpeg; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  */
21
22 #define BITSTREAM_READER_LE
23
24 #include "libavutil/channel_layout.h"
25 #include "avcodec.h"
26 #include "get_bits.h"
27 #include "internal.h"
28 #include "thread.h"
29 #include "unary.h"
30 #include "bytestream.h"
31 #include "wavpack.h"
32
33 /**
34  * @file
35  * WavPack lossless audio decoder
36  */
37
38 typedef struct SavedContext {
39     int offset;
40     int size;
41     int bits_used;
42     uint32_t crc;
43 } SavedContext;
44
45 typedef struct WavpackFrameContext {
46     AVCodecContext *avctx;
47     int frame_flags;
48     int stereo, stereo_in;
49     int joint;
50     uint32_t CRC;
51     GetBitContext gb;
52     int got_extra_bits;
53     uint32_t crc_extra_bits;
54     GetBitContext gb_extra_bits;
55     int data_size; // in bits
56     int samples;
57     int terms;
58     Decorr decorr[MAX_TERMS];
59     int zero, one, zeroes;
60     int extra_bits;
61     int and, or, shift;
62     int post_shift;
63     int hybrid, hybrid_bitrate;
64     int hybrid_maxclip, hybrid_minclip;
65     int float_flag;
66     int float_shift;
67     int float_max_exp;
68     WvChannel ch[2];
69     int pos;
70     SavedContext sc, extra_sc;
71 } WavpackFrameContext;
72
73 #define WV_MAX_FRAME_DECODERS 14
74
75 typedef struct WavpackContext {
76     AVCodecContext *avctx;
77
78     WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
79     int fdec_num;
80
81     int block;
82     int samples;
83     int ch_offset;
84 } WavpackContext;
85
86 #define LEVEL_DECAY(a)  (((a) + 0x80) >> 8)
87
88 static av_always_inline int get_tail(GetBitContext *gb, int k)
89 {
90     int p, e, res;
91
92     if (k < 1)
93         return 0;
94     p   = av_log2(k);
95     e   = (1 << (p + 1)) - k - 1;
96     res = p ? get_bits(gb, p) : 0;
97     if (res >= e)
98         res = (res << 1) - e + get_bits1(gb);
99     return res;
100 }
101
102 static void update_error_limit(WavpackFrameContext *ctx)
103 {
104     int i, br[2], sl[2];
105
106     for (i = 0; i <= ctx->stereo_in; i++) {
107         ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
108         br[i]                   = ctx->ch[i].bitrate_acc >> 16;
109         sl[i]                   = LEVEL_DECAY(ctx->ch[i].slow_level);
110     }
111     if (ctx->stereo_in && ctx->hybrid_bitrate) {
112         int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
113         if (balance > br[0]) {
114             br[1] = br[0] << 1;
115             br[0] = 0;
116         } else if (-balance > br[0]) {
117             br[0] <<= 1;
118             br[1]   = 0;
119         } else {
120             br[1] = br[0] + balance;
121             br[0] = br[0] - balance;
122         }
123     }
124     for (i = 0; i <= ctx->stereo_in; i++) {
125         if (ctx->hybrid_bitrate) {
126             if (sl[i] - br[i] > -0x100)
127                 ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
128             else
129                 ctx->ch[i].error_limit = 0;
130         } else {
131             ctx->ch[i].error_limit = wp_exp2(br[i]);
132         }
133     }
134 }
135
136 static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
137                         int channel, int *last)
138 {
139     int t, t2;
140     int sign, base, add, ret;
141     WvChannel *c = &ctx->ch[channel];
142
143     *last = 0;
144
145     if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
146         !ctx->zero && !ctx->one) {
147         if (ctx->zeroes) {
148             ctx->zeroes--;
149             if (ctx->zeroes) {
150                 c->slow_level -= LEVEL_DECAY(c->slow_level);
151                 return 0;
152             }
153         } else {
154             t = get_unary_0_33(gb);
155             if (t >= 2) {
156                 if (get_bits_left(gb) < t - 1)
157                     goto error;
158                 t = get_bits_long(gb, t - 1) | (1 << (t - 1));
159             } else {
160                 if (get_bits_left(gb) < 0)
161                     goto error;
162             }
163             ctx->zeroes = t;
164             if (ctx->zeroes) {
165                 memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
166                 memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
167                 c->slow_level -= LEVEL_DECAY(c->slow_level);
168                 return 0;
169             }
170         }
171     }
172
173     if (ctx->zero) {
174         t         = 0;
175         ctx->zero = 0;
176     } else {
177         t = get_unary_0_33(gb);
178         if (get_bits_left(gb) < 0)
179             goto error;
180         if (t == 16) {
181             t2 = get_unary_0_33(gb);
182             if (t2 < 2) {
183                 if (get_bits_left(gb) < 0)
184                     goto error;
185                 t += t2;
186             } else {
187                 if (get_bits_left(gb) < t2 - 1)
188                     goto error;
189                 t += get_bits_long(gb, t2 - 1) | (1 << (t2 - 1));
190             }
191         }
192
193         if (ctx->one) {
194             ctx->one = t & 1;
195             t        = (t >> 1) + 1;
196         } else {
197             ctx->one = t & 1;
198             t      >>= 1;
199         }
200         ctx->zero = !ctx->one;
201     }
202
203     if (ctx->hybrid && !channel)
204         update_error_limit(ctx);
205
206     if (!t) {
207         base = 0;
208         add  = GET_MED(0) - 1;
209         DEC_MED(0);
210     } else if (t == 1) {
211         base = GET_MED(0);
212         add  = GET_MED(1) - 1;
213         INC_MED(0);
214         DEC_MED(1);
215     } else if (t == 2) {
216         base = GET_MED(0) + GET_MED(1);
217         add  = GET_MED(2) - 1;
218         INC_MED(0);
219         INC_MED(1);
220         DEC_MED(2);
221     } else {
222         base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2);
223         add  = GET_MED(2) - 1;
224         INC_MED(0);
225         INC_MED(1);
226         INC_MED(2);
227     }
228     if (!c->error_limit) {
229         if (add >= 0x2000000U) {
230             av_log(ctx->avctx, AV_LOG_ERROR, "k %d is too large\n", add);
231             goto error;
232         }
233         ret = base + get_tail(gb, add);
234         if (get_bits_left(gb) <= 0)
235             goto error;
236     } else {
237         int mid = (base * 2 + add + 1) >> 1;
238         while (add > c->error_limit) {
239             if (get_bits_left(gb) <= 0)
240                 goto error;
241             if (get_bits1(gb)) {
242                 add -= (mid - base);
243                 base = mid;
244             } else
245                 add = mid - base - 1;
246             mid = (base * 2 + add + 1) >> 1;
247         }
248         ret = mid;
249     }
250     sign = get_bits1(gb);
251     if (ctx->hybrid_bitrate)
252         c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
253     return sign ? ~ret : ret;
254
255 error:
256     ret = get_bits_left(gb);
257     if (ret <= 0) {
258         av_log(ctx->avctx, AV_LOG_ERROR, "Too few bits (%d) left\n", ret);
259     }
260     *last = 1;
261     return 0;
262 }
263
264 static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
265                                        int S)
266 {
267     int bit;
268
269     if (s->extra_bits) {
270         S <<= s->extra_bits;
271
272         if (s->got_extra_bits &&
273             get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
274             S   |= get_bits_long(&s->gb_extra_bits, s->extra_bits);
275             *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
276         }
277     }
278
279     bit = (S & s->and) | s->or;
280     bit = ((S + bit) << s->shift) - bit;
281
282     if (s->hybrid)
283         bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
284
285     return bit << s->post_shift;
286 }
287
288 static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
289 {
290     union {
291         float    f;
292         uint32_t u;
293     } value;
294
295     unsigned int sign;
296     int exp = s->float_max_exp;
297
298     if (s->got_extra_bits) {
299         const int max_bits  = 1 + 23 + 8 + 1;
300         const int left_bits = get_bits_left(&s->gb_extra_bits);
301
302         if (left_bits + 8 * AV_INPUT_BUFFER_PADDING_SIZE < max_bits)
303             return 0.0;
304     }
305
306     if (S) {
307         S  <<= s->float_shift;
308         sign = S < 0;
309         if (sign)
310             S = -S;
311         if (S >= 0x1000000) {
312             if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
313                 S = get_bits(&s->gb_extra_bits, 23);
314             else
315                 S = 0;
316             exp = 255;
317         } else if (exp) {
318             int shift = 23 - av_log2(S);
319             exp = s->float_max_exp;
320             if (exp <= shift)
321                 shift = --exp;
322             exp -= shift;
323
324             if (shift) {
325                 S <<= shift;
326                 if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
327                     (s->got_extra_bits &&
328                      (s->float_flag & WV_FLT_SHIFT_SAME) &&
329                      get_bits1(&s->gb_extra_bits))) {
330                     S |= (1 << shift) - 1;
331                 } else if (s->got_extra_bits &&
332                            (s->float_flag & WV_FLT_SHIFT_SENT)) {
333                     S |= get_bits(&s->gb_extra_bits, shift);
334                 }
335             }
336         } else {
337             exp = s->float_max_exp;
338         }
339         S &= 0x7fffff;
340     } else {
341         sign = 0;
342         exp  = 0;
343         if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
344             if (get_bits1(&s->gb_extra_bits)) {
345                 S = get_bits(&s->gb_extra_bits, 23);
346                 if (s->float_max_exp >= 25)
347                     exp = get_bits(&s->gb_extra_bits, 8);
348                 sign = get_bits1(&s->gb_extra_bits);
349             } else {
350                 if (s->float_flag & WV_FLT_ZERO_SIGN)
351                     sign = get_bits1(&s->gb_extra_bits);
352             }
353         }
354     }
355
356     *crc = *crc * 27 + S * 9 + exp * 3 + sign;
357
358     value.u = (sign << 31) | (exp << 23) | S;
359     return value.f;
360 }
361
362 static void wv_reset_saved_context(WavpackFrameContext *s)
363 {
364     s->pos    = 0;
365     s->sc.crc = s->extra_sc.crc = 0xFFFFFFFF;
366 }
367
368 static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
369                                uint32_t crc_extra_bits)
370 {
371     if (crc != s->CRC) {
372         av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
373         return AVERROR_INVALIDDATA;
374     }
375     if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
376         av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
377         return AVERROR_INVALIDDATA;
378     }
379
380     return 0;
381 }
382
383 static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
384                                    void *dst_l, void *dst_r, const int type)
385 {
386     int i, j, count = 0;
387     int last, t;
388     int A, B, L, L2, R, R2;
389     int pos                 = s->pos;
390     uint32_t crc            = s->sc.crc;
391     uint32_t crc_extra_bits = s->extra_sc.crc;
392     int16_t *dst16_l        = dst_l;
393     int16_t *dst16_r        = dst_r;
394     int32_t *dst32_l        = dst_l;
395     int32_t *dst32_r        = dst_r;
396     float *dstfl_l          = dst_l;
397     float *dstfl_r          = dst_r;
398
399     s->one = s->zero = s->zeroes = 0;
400     do {
401         L = wv_get_value(s, gb, 0, &last);
402         if (last)
403             break;
404         R = wv_get_value(s, gb, 1, &last);
405         if (last)
406             break;
407         for (i = 0; i < s->terms; i++) {
408             t = s->decorr[i].value;
409             if (t > 0) {
410                 if (t > 8) {
411                     if (t & 1) {
412                         A = 2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
413                         B = 2 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
414                     } else {
415                         A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
416                         B = (3 * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
417                     }
418                     s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
419                     s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
420                     j                        = 0;
421                 } else {
422                     A = s->decorr[i].samplesA[pos];
423                     B = s->decorr[i].samplesB[pos];
424                     j = (pos + t) & 7;
425                 }
426                 if (type != AV_SAMPLE_FMT_S16P) {
427                     L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
428                     R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
429                 } else {
430                     L2 = L + ((s->decorr[i].weightA * A + 512) >> 10);
431                     R2 = R + ((s->decorr[i].weightB * B + 512) >> 10);
432                 }
433                 if (A && L)
434                     s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
435                 if (B && R)
436                     s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
437                 s->decorr[i].samplesA[j] = L = L2;
438                 s->decorr[i].samplesB[j] = R = R2;
439             } else if (t == -1) {
440                 if (type != AV_SAMPLE_FMT_S16P)
441                     L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
442                 else
443                     L2 = L + ((s->decorr[i].weightA * s->decorr[i].samplesA[0] + 512) >> 10);
444                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
445                 L = L2;
446                 if (type != AV_SAMPLE_FMT_S16P)
447                     R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
448                 else
449                     R2 = R + ((s->decorr[i].weightB * L2 + 512) >> 10);
450                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
451                 R                        = R2;
452                 s->decorr[i].samplesA[0] = R;
453             } else {
454                 if (type != AV_SAMPLE_FMT_S16P)
455                     R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
456                 else
457                     R2 = R + ((s->decorr[i].weightB * s->decorr[i].samplesB[0] + 512) >> 10);
458                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
459                 R = R2;
460
461                 if (t == -3) {
462                     R2                       = s->decorr[i].samplesA[0];
463                     s->decorr[i].samplesA[0] = R;
464                 }
465
466                 if (type != AV_SAMPLE_FMT_S16P)
467                     L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
468                 else
469                     L2 = L + ((s->decorr[i].weightA * R2 + 512) >> 10);
470                 UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
471                 L                        = L2;
472                 s->decorr[i].samplesB[0] = L;
473             }
474         }
475
476         if (type == AV_SAMPLE_FMT_S16P) {
477             if (FFABS(L) + FFABS(R) > (1<<19)) {
478                 av_log(s->avctx, AV_LOG_ERROR, "sample %d %d too large\n", L, R);
479                 return AVERROR_INVALIDDATA;
480             }
481         }
482
483         pos = (pos + 1) & 7;
484         if (s->joint)
485             L += (R -= (L >> 1));
486         crc = (crc * 3 + L) * 3 + R;
487
488         if (type == AV_SAMPLE_FMT_FLTP) {
489             *dstfl_l++ = wv_get_value_float(s, &crc_extra_bits, L);
490             *dstfl_r++ = wv_get_value_float(s, &crc_extra_bits, R);
491         } else if (type == AV_SAMPLE_FMT_S32P) {
492             *dst32_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
493             *dst32_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
494         } else {
495             *dst16_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
496             *dst16_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
497         }
498         count++;
499     } while (!last && count < s->samples);
500
501     wv_reset_saved_context(s);
502
503     if (last && count < s->samples) {
504         int size = av_get_bytes_per_sample(type);
505         memset((uint8_t*)dst_l + count*size, 0, (s->samples-count)*size);
506         memset((uint8_t*)dst_r + count*size, 0, (s->samples-count)*size);
507     }
508
509     if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
510         wv_check_crc(s, crc, crc_extra_bits))
511         return AVERROR_INVALIDDATA;
512
513     return 0;
514 }
515
516 static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
517                                  void *dst, const int type)
518 {
519     int i, j, count = 0;
520     int last, t;
521     int A, S, T;
522     int pos                  = s->pos;
523     uint32_t crc             = s->sc.crc;
524     uint32_t crc_extra_bits  = s->extra_sc.crc;
525     int16_t *dst16           = dst;
526     int32_t *dst32           = dst;
527     float *dstfl             = dst;
528
529     s->one = s->zero = s->zeroes = 0;
530     do {
531         T = wv_get_value(s, gb, 0, &last);
532         S = 0;
533         if (last)
534             break;
535         for (i = 0; i < s->terms; i++) {
536             t = s->decorr[i].value;
537             if (t > 8) {
538                 if (t & 1)
539                     A =  2 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
540                 else
541                     A = (3 * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
542                 s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
543                 j                        = 0;
544             } else {
545                 A = s->decorr[i].samplesA[pos];
546                 j = (pos + t) & 7;
547             }
548             if (type != AV_SAMPLE_FMT_S16P)
549                 S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
550             else
551                 S = T + ((s->decorr[i].weightA * A + 512) >> 10);
552             if (A && T)
553                 s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
554             s->decorr[i].samplesA[j] = T = S;
555         }
556         pos = (pos + 1) & 7;
557         crc = crc * 3 + S;
558
559         if (type == AV_SAMPLE_FMT_FLTP) {
560             *dstfl++ = wv_get_value_float(s, &crc_extra_bits, S);
561         } else if (type == AV_SAMPLE_FMT_S32P) {
562             *dst32++ = wv_get_value_integer(s, &crc_extra_bits, S);
563         } else {
564             *dst16++ = wv_get_value_integer(s, &crc_extra_bits, S);
565         }
566         count++;
567     } while (!last && count < s->samples);
568
569     wv_reset_saved_context(s);
570
571     if (last && count < s->samples) {
572         int size = av_get_bytes_per_sample(type);
573         memset((uint8_t*)dst + count*size, 0, (s->samples-count)*size);
574     }
575
576     if (s->avctx->err_recognition & AV_EF_CRCCHECK) {
577         int ret = wv_check_crc(s, crc, crc_extra_bits);
578         if (ret < 0 && s->avctx->err_recognition & AV_EF_EXPLODE)
579             return ret;
580     }
581
582     return 0;
583 }
584
585 static av_cold int wv_alloc_frame_context(WavpackContext *c)
586 {
587     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
588         return -1;
589
590     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
591     if (!c->fdec[c->fdec_num])
592         return -1;
593     c->fdec_num++;
594     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
595     wv_reset_saved_context(c->fdec[c->fdec_num - 1]);
596
597     return 0;
598 }
599
600 #if HAVE_THREADS
601 static int init_thread_copy(AVCodecContext *avctx)
602 {
603     WavpackContext *s = avctx->priv_data;
604     s->avctx = avctx;
605     return 0;
606 }
607 #endif
608
609 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
610 {
611     WavpackContext *s = avctx->priv_data;
612
613     s->avctx = avctx;
614
615     s->fdec_num = 0;
616
617     return 0;
618 }
619
620 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
621 {
622     WavpackContext *s = avctx->priv_data;
623     int i;
624
625     for (i = 0; i < s->fdec_num; i++)
626         av_freep(&s->fdec[i]);
627     s->fdec_num = 0;
628
629     return 0;
630 }
631
632 static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
633                                 AVFrame *frame, const uint8_t *buf, int buf_size)
634 {
635     WavpackContext *wc = avctx->priv_data;
636     ThreadFrame tframe = { .f = frame };
637     WavpackFrameContext *s;
638     GetByteContext gb;
639     void *samples_l = NULL, *samples_r = NULL;
640     int ret;
641     int got_terms   = 0, got_weights = 0, got_samples = 0,
642         got_entropy = 0, got_bs      = 0, got_float   = 0, got_hybrid = 0;
643     int i, j, id, size, ssize, weights, t;
644     int bpp, chan = 0, chmask = 0, orig_bpp, sample_rate = 0;
645     int multiblock;
646
647     if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
648         av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
649         return AVERROR_INVALIDDATA;
650     }
651
652     s = wc->fdec[block_no];
653     if (!s) {
654         av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
655                block_no);
656         return AVERROR_INVALIDDATA;
657     }
658
659     memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
660     memset(s->ch, 0, sizeof(s->ch));
661     s->extra_bits     = 0;
662     s->and            = s->or = s->shift = 0;
663     s->got_extra_bits = 0;
664
665     bytestream2_init(&gb, buf, buf_size);
666
667     s->samples = bytestream2_get_le32(&gb);
668     if (s->samples != wc->samples) {
669         av_log(avctx, AV_LOG_ERROR, "Mismatching number of samples in "
670                "a sequence: %d and %d\n", wc->samples, s->samples);
671         return AVERROR_INVALIDDATA;
672     }
673     s->frame_flags = bytestream2_get_le32(&gb);
674     bpp            = av_get_bytes_per_sample(avctx->sample_fmt);
675     orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
676     multiblock     = (s->frame_flags & WV_SINGLE_BLOCK) != WV_SINGLE_BLOCK;
677
678     s->stereo         = !(s->frame_flags & WV_MONO);
679     s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
680     s->joint          =   s->frame_flags & WV_JOINT_STEREO;
681     s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
682     s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
683     s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
684     s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
685     s->hybrid_minclip = ((-1LL << (orig_bpp - 1)));
686     s->CRC            = bytestream2_get_le32(&gb);
687
688     // parse metadata blocks
689     while (bytestream2_get_bytes_left(&gb)) {
690         id   = bytestream2_get_byte(&gb);
691         size = bytestream2_get_byte(&gb);
692         if (id & WP_IDF_LONG) {
693             size |= (bytestream2_get_byte(&gb)) << 8;
694             size |= (bytestream2_get_byte(&gb)) << 16;
695         }
696         size <<= 1; // size is specified in words
697         ssize  = size;
698         if (id & WP_IDF_ODD)
699             size--;
700         if (size < 0) {
701             av_log(avctx, AV_LOG_ERROR,
702                    "Got incorrect block %02X with size %i\n", id, size);
703             break;
704         }
705         if (bytestream2_get_bytes_left(&gb) < ssize) {
706             av_log(avctx, AV_LOG_ERROR,
707                    "Block size %i is out of bounds\n", size);
708             break;
709         }
710         switch (id & WP_IDF_MASK) {
711         case WP_ID_DECTERMS:
712             if (size > MAX_TERMS) {
713                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
714                 s->terms = 0;
715                 bytestream2_skip(&gb, ssize);
716                 continue;
717             }
718             s->terms = size;
719             for (i = 0; i < s->terms; i++) {
720                 uint8_t val = bytestream2_get_byte(&gb);
721                 s->decorr[s->terms - i - 1].value = (val & 0x1F) - 5;
722                 s->decorr[s->terms - i - 1].delta =  val >> 5;
723             }
724             got_terms = 1;
725             break;
726         case WP_ID_DECWEIGHTS:
727             if (!got_terms) {
728                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
729                 continue;
730             }
731             weights = size >> s->stereo_in;
732             if (weights > MAX_TERMS || weights > s->terms) {
733                 av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
734                 bytestream2_skip(&gb, ssize);
735                 continue;
736             }
737             for (i = 0; i < weights; i++) {
738                 t = (int8_t)bytestream2_get_byte(&gb);
739                 s->decorr[s->terms - i - 1].weightA = t << 3;
740                 if (s->decorr[s->terms - i - 1].weightA > 0)
741                     s->decorr[s->terms - i - 1].weightA +=
742                         (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
743                 if (s->stereo_in) {
744                     t = (int8_t)bytestream2_get_byte(&gb);
745                     s->decorr[s->terms - i - 1].weightB = t << 3;
746                     if (s->decorr[s->terms - i - 1].weightB > 0)
747                         s->decorr[s->terms - i - 1].weightB +=
748                             (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
749                 }
750             }
751             got_weights = 1;
752             break;
753         case WP_ID_DECSAMPLES:
754             if (!got_terms) {
755                 av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
756                 continue;
757             }
758             t = 0;
759             for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
760                 if (s->decorr[i].value > 8) {
761                     s->decorr[i].samplesA[0] =
762                         wp_exp2(bytestream2_get_le16(&gb));
763                     s->decorr[i].samplesA[1] =
764                         wp_exp2(bytestream2_get_le16(&gb));
765
766                     if (s->stereo_in) {
767                         s->decorr[i].samplesB[0] =
768                             wp_exp2(bytestream2_get_le16(&gb));
769                         s->decorr[i].samplesB[1] =
770                             wp_exp2(bytestream2_get_le16(&gb));
771                         t                       += 4;
772                     }
773                     t += 4;
774                 } else if (s->decorr[i].value < 0) {
775                     s->decorr[i].samplesA[0] =
776                         wp_exp2(bytestream2_get_le16(&gb));
777                     s->decorr[i].samplesB[0] =
778                         wp_exp2(bytestream2_get_le16(&gb));
779                     t                       += 4;
780                 } else {
781                     for (j = 0; j < s->decorr[i].value; j++) {
782                         s->decorr[i].samplesA[j] =
783                             wp_exp2(bytestream2_get_le16(&gb));
784                         if (s->stereo_in) {
785                             s->decorr[i].samplesB[j] =
786                                 wp_exp2(bytestream2_get_le16(&gb));
787                         }
788                     }
789                     t += s->decorr[i].value * 2 * (s->stereo_in + 1);
790                 }
791             }
792             got_samples = 1;
793             break;
794         case WP_ID_ENTROPY:
795             if (size != 6 * (s->stereo_in + 1)) {
796                 av_log(avctx, AV_LOG_ERROR,
797                        "Entropy vars size should be %i, got %i.\n",
798                        6 * (s->stereo_in + 1), size);
799                 bytestream2_skip(&gb, ssize);
800                 continue;
801             }
802             for (j = 0; j <= s->stereo_in; j++)
803                 for (i = 0; i < 3; i++) {
804                     s->ch[j].median[i] = wp_exp2(bytestream2_get_le16(&gb));
805                 }
806             got_entropy = 1;
807             break;
808         case WP_ID_HYBRID:
809             if (s->hybrid_bitrate) {
810                 for (i = 0; i <= s->stereo_in; i++) {
811                     s->ch[i].slow_level = wp_exp2(bytestream2_get_le16(&gb));
812                     size               -= 2;
813                 }
814             }
815             for (i = 0; i < (s->stereo_in + 1); i++) {
816                 s->ch[i].bitrate_acc = bytestream2_get_le16(&gb) << 16;
817                 size                -= 2;
818             }
819             if (size > 0) {
820                 for (i = 0; i < (s->stereo_in + 1); i++) {
821                     s->ch[i].bitrate_delta =
822                         wp_exp2((int16_t)bytestream2_get_le16(&gb));
823                 }
824             } else {
825                 for (i = 0; i < (s->stereo_in + 1); i++)
826                     s->ch[i].bitrate_delta = 0;
827             }
828             got_hybrid = 1;
829             break;
830         case WP_ID_INT32INFO: {
831             uint8_t val[4];
832             if (size != 4) {
833                 av_log(avctx, AV_LOG_ERROR,
834                        "Invalid INT32INFO, size = %i\n",
835                        size);
836                 bytestream2_skip(&gb, ssize - 4);
837                 continue;
838             }
839             bytestream2_get_buffer(&gb, val, 4);
840             if (val[0] > 32) {
841                 av_log(avctx, AV_LOG_ERROR,
842                        "Invalid INT32INFO, extra_bits = %d (> 32)\n", val[0]);
843                 continue;
844             } else if (val[0]) {
845                 s->extra_bits = val[0];
846             } else if (val[1]) {
847                 s->shift = val[1];
848             } else if (val[2]) {
849                 s->and   = s->or = 1;
850                 s->shift = val[2];
851             } else if (val[3]) {
852                 s->and   = 1;
853                 s->shift = val[3];
854             }
855             /* original WavPack decoder forces 32-bit lossy sound to be treated
856              * as 24-bit one in order to have proper clipping */
857             if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
858                 s->post_shift      += 8;
859                 s->shift           -= 8;
860                 s->hybrid_maxclip >>= 8;
861                 s->hybrid_minclip >>= 8;
862             }
863             break;
864         }
865         case WP_ID_FLOATINFO:
866             if (size != 4) {
867                 av_log(avctx, AV_LOG_ERROR,
868                        "Invalid FLOATINFO, size = %i\n", size);
869                 bytestream2_skip(&gb, ssize);
870                 continue;
871             }
872             s->float_flag    = bytestream2_get_byte(&gb);
873             s->float_shift   = bytestream2_get_byte(&gb);
874             s->float_max_exp = bytestream2_get_byte(&gb);
875             got_float        = 1;
876             bytestream2_skip(&gb, 1);
877             break;
878         case WP_ID_DATA:
879             s->sc.offset = bytestream2_tell(&gb);
880             s->sc.size   = size * 8;
881             if ((ret = init_get_bits8(&s->gb, gb.buffer, size)) < 0)
882                 return ret;
883             s->data_size = size * 8;
884             bytestream2_skip(&gb, size);
885             got_bs       = 1;
886             break;
887         case WP_ID_EXTRABITS:
888             if (size <= 4) {
889                 av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
890                        size);
891                 bytestream2_skip(&gb, size);
892                 continue;
893             }
894             s->extra_sc.offset = bytestream2_tell(&gb);
895             s->extra_sc.size   = size * 8;
896             if ((ret = init_get_bits8(&s->gb_extra_bits, gb.buffer, size)) < 0)
897                 return ret;
898             s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
899             bytestream2_skip(&gb, size);
900             s->got_extra_bits  = 1;
901             break;
902         case WP_ID_CHANINFO:
903             if (size <= 1) {
904                 av_log(avctx, AV_LOG_ERROR,
905                        "Insufficient channel information\n");
906                 return AVERROR_INVALIDDATA;
907             }
908             chan = bytestream2_get_byte(&gb);
909             switch (size - 2) {
910             case 0:
911                 chmask = bytestream2_get_byte(&gb);
912                 break;
913             case 1:
914                 chmask = bytestream2_get_le16(&gb);
915                 break;
916             case 2:
917                 chmask = bytestream2_get_le24(&gb);
918                 break;
919             case 3:
920                 chmask = bytestream2_get_le32(&gb);
921                 break;
922             case 5:
923                 size = bytestream2_get_byte(&gb);
924                 if (avctx->channels != size)
925                     av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
926                            " instead of %i.\n", size, avctx->channels);
927                 chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
928                 chmask = bytestream2_get_le16(&gb);
929                 break;
930             default:
931                 av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
932                        size);
933                 chan   = avctx->channels;
934                 chmask = avctx->channel_layout;
935             }
936             break;
937         case WP_ID_SAMPLE_RATE:
938             if (size != 3) {
939                 av_log(avctx, AV_LOG_ERROR, "Invalid custom sample rate.\n");
940                 return AVERROR_INVALIDDATA;
941             }
942             sample_rate = bytestream2_get_le24(&gb);
943             break;
944         default:
945             bytestream2_skip(&gb, size);
946         }
947         if (id & WP_IDF_ODD)
948             bytestream2_skip(&gb, 1);
949     }
950
951     if (!got_terms) {
952         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
953         return AVERROR_INVALIDDATA;
954     }
955     if (!got_weights) {
956         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
957         return AVERROR_INVALIDDATA;
958     }
959     if (!got_samples) {
960         av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
961         return AVERROR_INVALIDDATA;
962     }
963     if (!got_entropy) {
964         av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
965         return AVERROR_INVALIDDATA;
966     }
967     if (s->hybrid && !got_hybrid) {
968         av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
969         return AVERROR_INVALIDDATA;
970     }
971     if (!got_bs) {
972         av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
973         return AVERROR_INVALIDDATA;
974     }
975     if (!got_float && avctx->sample_fmt == AV_SAMPLE_FMT_FLTP) {
976         av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
977         return AVERROR_INVALIDDATA;
978     }
979     if (s->got_extra_bits && avctx->sample_fmt != AV_SAMPLE_FMT_FLTP) {
980         const int size   = get_bits_left(&s->gb_extra_bits);
981         const int wanted = s->samples * s->extra_bits << s->stereo_in;
982         if (size < wanted) {
983             av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
984             s->got_extra_bits = 0;
985         }
986     }
987
988     if (!wc->ch_offset) {
989         int sr = (s->frame_flags >> 23) & 0xf;
990         if (sr == 0xf) {
991             if (!sample_rate) {
992                 av_log(avctx, AV_LOG_ERROR, "Custom sample rate missing.\n");
993                 return AVERROR_INVALIDDATA;
994             }
995             avctx->sample_rate = sample_rate;
996         } else
997             avctx->sample_rate = wv_rates[sr];
998
999         if (multiblock) {
1000             if (chan)
1001                 avctx->channels = chan;
1002             if (chmask)
1003                 avctx->channel_layout = chmask;
1004         } else {
1005             avctx->channels       = s->stereo ? 2 : 1;
1006             avctx->channel_layout = s->stereo ? AV_CH_LAYOUT_STEREO :
1007                                                 AV_CH_LAYOUT_MONO;
1008         }
1009
1010         /* get output buffer */
1011         frame->nb_samples = s->samples + 1;
1012         if ((ret = ff_thread_get_buffer(avctx, &tframe, 0)) < 0)
1013             return ret;
1014         frame->nb_samples = s->samples;
1015     }
1016
1017     if (wc->ch_offset + s->stereo >= avctx->channels) {
1018         av_log(avctx, AV_LOG_WARNING, "Too many channels coded in a packet.\n");
1019         return (avctx->err_recognition & AV_EF_EXPLODE) ? AVERROR_INVALIDDATA : 0;
1020     }
1021
1022     samples_l = frame->extended_data[wc->ch_offset];
1023     if (s->stereo)
1024         samples_r = frame->extended_data[wc->ch_offset + 1];
1025
1026     wc->ch_offset += 1 + s->stereo;
1027
1028     if (s->stereo_in) {
1029         ret = wv_unpack_stereo(s, &s->gb, samples_l, samples_r, avctx->sample_fmt);
1030         if (ret < 0)
1031             return ret;
1032     } else {
1033         ret = wv_unpack_mono(s, &s->gb, samples_l, avctx->sample_fmt);
1034         if (ret < 0)
1035             return ret;
1036
1037         if (s->stereo)
1038             memcpy(samples_r, samples_l, bpp * s->samples);
1039     }
1040
1041     return 0;
1042 }
1043
1044 static void wavpack_decode_flush(AVCodecContext *avctx)
1045 {
1046     WavpackContext *s = avctx->priv_data;
1047     int i;
1048
1049     for (i = 0; i < s->fdec_num; i++)
1050         wv_reset_saved_context(s->fdec[i]);
1051 }
1052
1053 static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
1054                                 int *got_frame_ptr, AVPacket *avpkt)
1055 {
1056     WavpackContext *s  = avctx->priv_data;
1057     const uint8_t *buf = avpkt->data;
1058     int buf_size       = avpkt->size;
1059     AVFrame *frame     = data;
1060     int frame_size, ret, frame_flags;
1061
1062     if (avpkt->size <= WV_HEADER_SIZE)
1063         return AVERROR_INVALIDDATA;
1064
1065     s->block     = 0;
1066     s->ch_offset = 0;
1067
1068     /* determine number of samples */
1069     s->samples  = AV_RL32(buf + 20);
1070     frame_flags = AV_RL32(buf + 24);
1071     if (s->samples <= 0 || s->samples > WV_MAX_SAMPLES) {
1072         av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1073                s->samples);
1074         return AVERROR_INVALIDDATA;
1075     }
1076
1077     if (frame_flags & 0x80) {
1078         avctx->sample_fmt = AV_SAMPLE_FMT_FLTP;
1079     } else if ((frame_flags & 0x03) <= 1) {
1080         avctx->sample_fmt = AV_SAMPLE_FMT_S16P;
1081     } else {
1082         avctx->sample_fmt          = AV_SAMPLE_FMT_S32P;
1083         avctx->bits_per_raw_sample = ((frame_flags & 0x03) + 1) << 3;
1084     }
1085
1086     while (buf_size > 0) {
1087         if (buf_size <= WV_HEADER_SIZE)
1088             break;
1089         frame_size = AV_RL32(buf + 4) - 12;
1090         buf       += 20;
1091         buf_size  -= 20;
1092         if (frame_size <= 0 || frame_size > buf_size) {
1093             av_log(avctx, AV_LOG_ERROR,
1094                    "Block %d has invalid size (size %d vs. %d bytes left)\n",
1095                    s->block, frame_size, buf_size);
1096             wavpack_decode_flush(avctx);
1097             return AVERROR_INVALIDDATA;
1098         }
1099         if ((ret = wavpack_decode_block(avctx, s->block,
1100                                         frame, buf, frame_size)) < 0) {
1101             wavpack_decode_flush(avctx);
1102             return ret;
1103         }
1104         s->block++;
1105         buf      += frame_size;
1106         buf_size -= frame_size;
1107     }
1108
1109     if (s->ch_offset != avctx->channels) {
1110         av_log(avctx, AV_LOG_ERROR, "Not enough channels coded in a packet.\n");
1111         return AVERROR_INVALIDDATA;
1112     }
1113
1114     *got_frame_ptr = 1;
1115
1116     return avpkt->size;
1117 }
1118
1119 AVCodec ff_wavpack_decoder = {
1120     .name           = "wavpack",
1121     .long_name      = NULL_IF_CONFIG_SMALL("WavPack"),
1122     .type           = AVMEDIA_TYPE_AUDIO,
1123     .id             = AV_CODEC_ID_WAVPACK,
1124     .priv_data_size = sizeof(WavpackContext),
1125     .init           = wavpack_decode_init,
1126     .close          = wavpack_decode_end,
1127     .decode         = wavpack_decode_frame,
1128     .flush          = wavpack_decode_flush,
1129     .init_thread_copy = ONLY_IF_THREADS_ENABLED(init_thread_copy),
1130     .capabilities   = AV_CODEC_CAP_DR1 | AV_CODEC_CAP_FRAME_THREADS,
1131 };