]> git.sesse.net Git - ffmpeg/blob - libavcodec/wmalosslessdec.c
avcodec/wmalosslessdec: Fix 2 overflows in mclms
[ffmpeg] / libavcodec / wmalosslessdec.c
1 /*
2  * Windows Media Audio Lossless decoder
3  * Copyright (c) 2007 Baptiste Coudurier, Benjamin Larsson, Ulion
4  * Copyright (c) 2008 - 2011 Sascha Sommer, Benjamin Larsson
5  * Copyright (c) 2011 Andreas Ă–man
6  * Copyright (c) 2011 - 2012 Mashiat Sarker Shakkhar
7  *
8  * This file is part of FFmpeg.
9  *
10  * FFmpeg is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public
12  * License as published by the Free Software Foundation; either
13  * version 2.1 of the License, or (at your option) any later version.
14  *
15  * FFmpeg is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with FFmpeg; if not, write to the Free Software
22  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
23  */
24
25 #include <inttypes.h>
26
27 #include "libavutil/attributes.h"
28 #include "libavutil/avassert.h"
29
30 #include "avcodec.h"
31 #include "internal.h"
32 #include "get_bits.h"
33 #include "put_bits.h"
34 #include "lossless_audiodsp.h"
35 #include "wma.h"
36 #include "wma_common.h"
37
38 /** current decoder limitations */
39 #define WMALL_MAX_CHANNELS      8                       ///< max number of handled channels
40 #define MAX_SUBFRAMES          32                       ///< max number of subframes per channel
41 #define MAX_BANDS              29                       ///< max number of scale factor bands
42 #define MAX_FRAMESIZE       32768                       ///< maximum compressed frame size
43 #define MAX_ORDER             256
44
45 #define WMALL_BLOCK_MIN_BITS    6                       ///< log2 of min block size
46 #define WMALL_BLOCK_MAX_BITS   14                       ///< log2 of max block size
47 #define WMALL_BLOCK_MAX_SIZE (1 << WMALL_BLOCK_MAX_BITS)    ///< maximum block size
48 #define WMALL_BLOCK_SIZES    (WMALL_BLOCK_MAX_BITS - WMALL_BLOCK_MIN_BITS + 1) ///< possible block sizes
49
50 #define WMALL_COEFF_PAD_SIZE   16                       ///< pad coef buffers with 0 for use with SIMD
51
52 /**
53  * @brief frame-specific decoder context for a single channel
54  */
55 typedef struct WmallChannelCtx {
56     int16_t     prev_block_len;                         ///< length of the previous block
57     uint8_t     transmit_coefs;
58     uint8_t     num_subframes;
59     uint16_t    subframe_len[MAX_SUBFRAMES];            ///< subframe length in samples
60     uint16_t    subframe_offsets[MAX_SUBFRAMES];        ///< subframe positions in the current frame
61     uint8_t     cur_subframe;                           ///< current subframe number
62     uint16_t    decoded_samples;                        ///< number of already processed samples
63     int         quant_step;                             ///< quantization step for the current subframe
64     int         transient_counter;                      ///< number of transient samples from the beginning of the transient zone
65 } WmallChannelCtx;
66
67 /**
68  * @brief main decoder context
69  */
70 typedef struct WmallDecodeCtx {
71     /* generic decoder variables */
72     AVCodecContext  *avctx;
73     AVFrame         *frame;
74     LLAudDSPContext dsp;                           ///< accelerated DSP functions
75     uint8_t         *frame_data;                    ///< compressed frame data
76     int             max_frame_size;                 ///< max bitstream size
77     PutBitContext   pb;                             ///< context for filling the frame_data buffer
78
79     /* frame size dependent frame information (set during initialization) */
80     uint32_t        decode_flags;                   ///< used compression features
81     int             len_prefix;                     ///< frame is prefixed with its length
82     int             dynamic_range_compression;      ///< frame contains DRC data
83     uint8_t         bits_per_sample;                ///< integer audio sample size for the unscaled IMDCT output (used to scale to [-1.0, 1.0])
84     uint16_t        samples_per_frame;              ///< number of samples to output
85     uint16_t        log2_frame_size;
86     int8_t          num_channels;                   ///< number of channels in the stream (same as AVCodecContext.num_channels)
87     int8_t          lfe_channel;                    ///< lfe channel index
88     uint8_t         max_num_subframes;
89     uint8_t         subframe_len_bits;              ///< number of bits used for the subframe length
90     uint8_t         max_subframe_len_bit;           ///< flag indicating that the subframe is of maximum size when the first subframe length bit is 1
91     uint16_t        min_samples_per_subframe;
92
93     /* packet decode state */
94     GetBitContext   pgb;                            ///< bitstream reader context for the packet
95     int             next_packet_start;              ///< start offset of the next WMA packet in the demuxer packet
96     uint8_t         packet_offset;                  ///< offset to the frame in the packet
97     uint8_t         packet_sequence_number;         ///< current packet number
98     int             num_saved_bits;                 ///< saved number of bits
99     int             frame_offset;                   ///< frame offset in the bit reservoir
100     int             subframe_offset;                ///< subframe offset in the bit reservoir
101     uint8_t         packet_loss;                    ///< set in case of bitstream error
102     uint8_t         packet_done;                    ///< set when a packet is fully decoded
103
104     /* frame decode state */
105     uint32_t        frame_num;                      ///< current frame number (not used for decoding)
106     GetBitContext   gb;                             ///< bitstream reader context
107     int             buf_bit_size;                   ///< buffer size in bits
108     int16_t         *samples_16[WMALL_MAX_CHANNELS]; ///< current sample buffer pointer (16-bit)
109     int32_t         *samples_32[WMALL_MAX_CHANNELS]; ///< current sample buffer pointer (24-bit)
110     uint8_t         drc_gain;                       ///< gain for the DRC tool
111     int8_t          skip_frame;                     ///< skip output step
112     int8_t          parsed_all_subframes;           ///< all subframes decoded?
113
114     /* subframe/block decode state */
115     int16_t         subframe_len;                   ///< current subframe length
116     int8_t          channels_for_cur_subframe;      ///< number of channels that contain the subframe
117     int8_t          channel_indexes_for_cur_subframe[WMALL_MAX_CHANNELS];
118
119     WmallChannelCtx channel[WMALL_MAX_CHANNELS];    ///< per channel data
120
121     // WMA Lossless-specific
122
123     uint8_t do_arith_coding;
124     uint8_t do_ac_filter;
125     uint8_t do_inter_ch_decorr;
126     uint8_t do_mclms;
127     uint8_t do_lpc;
128
129     int8_t  acfilter_order;
130     int8_t  acfilter_scaling;
131     int16_t acfilter_coeffs[16];
132     int     acfilter_prevvalues[WMALL_MAX_CHANNELS][16];
133
134     int8_t  mclms_order;
135     int8_t  mclms_scaling;
136     int16_t mclms_coeffs[WMALL_MAX_CHANNELS * WMALL_MAX_CHANNELS * 32];
137     int16_t mclms_coeffs_cur[WMALL_MAX_CHANNELS * WMALL_MAX_CHANNELS];
138     int32_t mclms_prevvalues[WMALL_MAX_CHANNELS * 2 * 32];
139     int32_t mclms_updates[WMALL_MAX_CHANNELS * 2 * 32];
140     int     mclms_recent;
141
142     int     movave_scaling;
143     int     quant_stepsize;
144
145     struct {
146         int order;
147         int scaling;
148         int coefsend;
149         int bitsend;
150         DECLARE_ALIGNED(16, int16_t, coefs)[MAX_ORDER + WMALL_COEFF_PAD_SIZE/sizeof(int16_t)];
151         DECLARE_ALIGNED(16, int32_t, lms_prevvalues)[MAX_ORDER * 2 + WMALL_COEFF_PAD_SIZE/sizeof(int16_t)];
152         DECLARE_ALIGNED(16, int16_t, lms_updates)[MAX_ORDER * 2 + WMALL_COEFF_PAD_SIZE/sizeof(int16_t)];
153         int recent;
154     } cdlms[WMALL_MAX_CHANNELS][9];
155
156     int cdlms_ttl[WMALL_MAX_CHANNELS];
157
158     int bV3RTM;
159
160     int is_channel_coded[WMALL_MAX_CHANNELS];
161     int update_speed[WMALL_MAX_CHANNELS];
162
163     int transient[WMALL_MAX_CHANNELS];
164     int transient_pos[WMALL_MAX_CHANNELS];
165     int seekable_tile;
166
167     int ave_sum[WMALL_MAX_CHANNELS];
168
169     int channel_residues[WMALL_MAX_CHANNELS][WMALL_BLOCK_MAX_SIZE];
170
171     int lpc_coefs[WMALL_MAX_CHANNELS][40];
172     int lpc_order;
173     int lpc_scaling;
174     int lpc_intbits;
175 } WmallDecodeCtx;
176
177 /** Get sign of integer (1 for positive, -1 for negative and 0 for zero) */
178 #define WMASIGN(x) (((x) > 0) - ((x) < 0))
179
180 static av_cold int decode_init(AVCodecContext *avctx)
181 {
182     WmallDecodeCtx *s  = avctx->priv_data;
183     uint8_t *edata_ptr = avctx->extradata;
184     unsigned int channel_mask;
185     int i, log2_max_num_subframes;
186
187     if (avctx->block_align <= 0) {
188         av_log(avctx, AV_LOG_ERROR, "block_align is not set or invalid\n");
189         return AVERROR(EINVAL);
190     }
191
192     s->max_frame_size = MAX_FRAMESIZE * avctx->channels;
193     s->frame_data = av_mallocz(s->max_frame_size + AV_INPUT_BUFFER_PADDING_SIZE);
194     if (!s->frame_data)
195         return AVERROR(ENOMEM);
196
197     s->avctx = avctx;
198     ff_llauddsp_init(&s->dsp);
199     init_put_bits(&s->pb, s->frame_data, s->max_frame_size);
200
201     if (avctx->extradata_size >= 18) {
202         s->decode_flags    = AV_RL16(edata_ptr + 14);
203         channel_mask       = AV_RL32(edata_ptr +  2);
204         s->bits_per_sample = AV_RL16(edata_ptr);
205         if (s->bits_per_sample == 16)
206             avctx->sample_fmt = AV_SAMPLE_FMT_S16P;
207         else if (s->bits_per_sample == 24) {
208             avctx->sample_fmt = AV_SAMPLE_FMT_S32P;
209             avctx->bits_per_raw_sample = 24;
210         } else {
211             av_log(avctx, AV_LOG_ERROR, "Unknown bit-depth: %"PRIu8"\n",
212                    s->bits_per_sample);
213             return AVERROR_INVALIDDATA;
214         }
215         /* dump the extradata */
216         for (i = 0; i < avctx->extradata_size; i++)
217             ff_dlog(avctx, "[%x] ", avctx->extradata[i]);
218         ff_dlog(avctx, "\n");
219
220     } else {
221         avpriv_request_sample(avctx, "Unsupported extradata size");
222         return AVERROR_PATCHWELCOME;
223     }
224
225     /* generic init */
226     s->log2_frame_size = av_log2(avctx->block_align) + 4;
227
228     /* frame info */
229     s->skip_frame  = 1; /* skip first frame */
230     s->packet_loss = 1;
231     s->len_prefix  = s->decode_flags & 0x40;
232
233     /* get frame len */
234     s->samples_per_frame = 1 << ff_wma_get_frame_len_bits(avctx->sample_rate,
235                                                           3, s->decode_flags);
236     av_assert0(s->samples_per_frame <= WMALL_BLOCK_MAX_SIZE);
237
238     /* init previous block len */
239     for (i = 0; i < avctx->channels; i++)
240         s->channel[i].prev_block_len = s->samples_per_frame;
241
242     /* subframe info */
243     log2_max_num_subframes  = (s->decode_flags & 0x38) >> 3;
244     s->max_num_subframes    = 1 << log2_max_num_subframes;
245     s->max_subframe_len_bit = 0;
246     s->subframe_len_bits    = av_log2(log2_max_num_subframes) + 1;
247
248     s->min_samples_per_subframe  = s->samples_per_frame / s->max_num_subframes;
249     s->dynamic_range_compression = s->decode_flags & 0x80;
250     s->bV3RTM                    = s->decode_flags & 0x100;
251
252     if (s->max_num_subframes > MAX_SUBFRAMES) {
253         av_log(avctx, AV_LOG_ERROR, "invalid number of subframes %"PRIu8"\n",
254                s->max_num_subframes);
255         return AVERROR_INVALIDDATA;
256     }
257
258     s->num_channels = avctx->channels;
259
260     /* extract lfe channel position */
261     s->lfe_channel = -1;
262
263     if (channel_mask & 8) {
264         unsigned int mask;
265         for (mask = 1; mask < 16; mask <<= 1)
266             if (channel_mask & mask)
267                 ++s->lfe_channel;
268     }
269
270     if (s->num_channels < 0) {
271         av_log(avctx, AV_LOG_ERROR, "invalid number of channels %"PRId8"\n",
272                s->num_channels);
273         return AVERROR_INVALIDDATA;
274     } else if (s->num_channels > WMALL_MAX_CHANNELS) {
275         avpriv_request_sample(avctx,
276                               "More than %d channels", WMALL_MAX_CHANNELS);
277         return AVERROR_PATCHWELCOME;
278     }
279
280     s->frame = av_frame_alloc();
281     if (!s->frame)
282         return AVERROR(ENOMEM);
283
284     avctx->channel_layout = channel_mask;
285     return 0;
286 }
287
288 /**
289  * @brief Decode the subframe length.
290  * @param s      context
291  * @param offset sample offset in the frame
292  * @return decoded subframe length on success, < 0 in case of an error
293  */
294 static int decode_subframe_length(WmallDecodeCtx *s, int offset)
295 {
296     int frame_len_ratio, subframe_len, len;
297
298     /* no need to read from the bitstream when only one length is possible */
299     if (offset == s->samples_per_frame - s->min_samples_per_subframe)
300         return s->min_samples_per_subframe;
301
302     len             = av_log2(s->max_num_subframes - 1) + 1;
303     frame_len_ratio = get_bits(&s->gb, len);
304     subframe_len    = s->min_samples_per_subframe * (frame_len_ratio + 1);
305
306     /* sanity check the length */
307     if (subframe_len < s->min_samples_per_subframe ||
308         subframe_len > s->samples_per_frame) {
309         av_log(s->avctx, AV_LOG_ERROR, "broken frame: subframe_len %i\n",
310                subframe_len);
311         return AVERROR_INVALIDDATA;
312     }
313     return subframe_len;
314 }
315
316 /**
317  * @brief Decode how the data in the frame is split into subframes.
318  *       Every WMA frame contains the encoded data for a fixed number of
319  *       samples per channel. The data for every channel might be split
320  *       into several subframes. This function will reconstruct the list of
321  *       subframes for every channel.
322  *
323  *       If the subframes are not evenly split, the algorithm estimates the
324  *       channels with the lowest number of total samples.
325  *       Afterwards, for each of these channels a bit is read from the
326  *       bitstream that indicates if the channel contains a subframe with the
327  *       next subframe size that is going to be read from the bitstream or not.
328  *       If a channel contains such a subframe, the subframe size gets added to
329  *       the channel's subframe list.
330  *       The algorithm repeats these steps until the frame is properly divided
331  *       between the individual channels.
332  *
333  * @param s context
334  * @return 0 on success, < 0 in case of an error
335  */
336 static int decode_tilehdr(WmallDecodeCtx *s)
337 {
338     uint16_t num_samples[WMALL_MAX_CHANNELS] = { 0 }; /* sum of samples for all currently known subframes of a channel */
339     uint8_t  contains_subframe[WMALL_MAX_CHANNELS];   /* flag indicating if a channel contains the current subframe */
340     int channels_for_cur_subframe = s->num_channels;  /* number of channels that contain the current subframe */
341     int fixed_channel_layout = 0;                     /* flag indicating that all channels use the same subfra2me offsets and sizes */
342     int min_channel_len = 0;                          /* smallest sum of samples (channels with this length will be processed first) */
343     int c, tile_aligned;
344
345     /* reset tiling information */
346     for (c = 0; c < s->num_channels; c++)
347         s->channel[c].num_subframes = 0;
348
349     tile_aligned = get_bits1(&s->gb);
350     if (s->max_num_subframes == 1 || tile_aligned)
351         fixed_channel_layout = 1;
352
353     /* loop until the frame data is split between the subframes */
354     do {
355         int subframe_len, in_use = 0;
356
357         /* check which channels contain the subframe */
358         for (c = 0; c < s->num_channels; c++) {
359             if (num_samples[c] == min_channel_len) {
360                 if (fixed_channel_layout || channels_for_cur_subframe == 1 ||
361                    (min_channel_len == s->samples_per_frame - s->min_samples_per_subframe)) {
362                     contains_subframe[c] = 1;
363                 } else {
364                     contains_subframe[c] = get_bits1(&s->gb);
365                 }
366                 in_use |= contains_subframe[c];
367             } else
368                 contains_subframe[c] = 0;
369         }
370
371         if (!in_use) {
372             av_log(s->avctx, AV_LOG_ERROR,
373                    "Found empty subframe\n");
374             return AVERROR_INVALIDDATA;
375         }
376
377         /* get subframe length, subframe_len == 0 is not allowed */
378         if ((subframe_len = decode_subframe_length(s, min_channel_len)) <= 0)
379             return AVERROR_INVALIDDATA;
380         /* add subframes to the individual channels and find new min_channel_len */
381         min_channel_len += subframe_len;
382         for (c = 0; c < s->num_channels; c++) {
383             WmallChannelCtx *chan = &s->channel[c];
384
385             if (contains_subframe[c]) {
386                 if (chan->num_subframes >= MAX_SUBFRAMES) {
387                     av_log(s->avctx, AV_LOG_ERROR,
388                            "broken frame: num subframes > 31\n");
389                     return AVERROR_INVALIDDATA;
390                 }
391                 chan->subframe_len[chan->num_subframes] = subframe_len;
392                 num_samples[c] += subframe_len;
393                 ++chan->num_subframes;
394                 if (num_samples[c] > s->samples_per_frame) {
395                     av_log(s->avctx, AV_LOG_ERROR, "broken frame: "
396                            "channel len(%"PRIu16") > samples_per_frame(%"PRIu16")\n",
397                            num_samples[c], s->samples_per_frame);
398                     return AVERROR_INVALIDDATA;
399                 }
400             } else if (num_samples[c] <= min_channel_len) {
401                 if (num_samples[c] < min_channel_len) {
402                     channels_for_cur_subframe = 0;
403                     min_channel_len = num_samples[c];
404                 }
405                 ++channels_for_cur_subframe;
406             }
407         }
408     } while (min_channel_len < s->samples_per_frame);
409
410     for (c = 0; c < s->num_channels; c++) {
411         int i, offset = 0;
412         for (i = 0; i < s->channel[c].num_subframes; i++) {
413             s->channel[c].subframe_offsets[i] = offset;
414             offset += s->channel[c].subframe_len[i];
415         }
416     }
417
418     return 0;
419 }
420
421 static void decode_ac_filter(WmallDecodeCtx *s)
422 {
423     int i;
424     s->acfilter_order   = get_bits(&s->gb, 4) + 1;
425     s->acfilter_scaling = get_bits(&s->gb, 4);
426
427     for (i = 0; i < s->acfilter_order; i++)
428         s->acfilter_coeffs[i] = get_bitsz(&s->gb, s->acfilter_scaling) + 1;
429 }
430
431 static void decode_mclms(WmallDecodeCtx *s)
432 {
433     s->mclms_order   = (get_bits(&s->gb, 4) + 1) * 2;
434     s->mclms_scaling = get_bits(&s->gb, 4);
435     if (get_bits1(&s->gb)) {
436         int i, send_coef_bits;
437         int cbits = av_log2(s->mclms_scaling + 1);
438         if (1 << cbits < s->mclms_scaling + 1)
439             cbits++;
440
441         send_coef_bits = get_bitsz(&s->gb, cbits) + 2;
442
443         for (i = 0; i < s->mclms_order * s->num_channels * s->num_channels; i++)
444             s->mclms_coeffs[i] = get_bits(&s->gb, send_coef_bits);
445
446         for (i = 0; i < s->num_channels; i++) {
447             int c;
448             for (c = 0; c < i; c++)
449                 s->mclms_coeffs_cur[i * s->num_channels + c] = get_bits(&s->gb, send_coef_bits);
450         }
451     }
452 }
453
454 static int decode_cdlms(WmallDecodeCtx *s)
455 {
456     int c, i;
457     int cdlms_send_coef = get_bits1(&s->gb);
458
459     for (c = 0; c < s->num_channels; c++) {
460         s->cdlms_ttl[c] = get_bits(&s->gb, 3) + 1;
461         for (i = 0; i < s->cdlms_ttl[c]; i++) {
462             s->cdlms[c][i].order = (get_bits(&s->gb, 7) + 1) * 8;
463             if (s->cdlms[c][i].order > MAX_ORDER) {
464                 av_log(s->avctx, AV_LOG_ERROR,
465                        "Order[%d][%d] %d > max (%d), not supported\n",
466                        c, i, s->cdlms[c][i].order, MAX_ORDER);
467                 s->cdlms[0][0].order = 0;
468                 return AVERROR_INVALIDDATA;
469             }
470             if(s->cdlms[c][i].order & 8 && s->bits_per_sample == 16) {
471                 static int warned;
472                 if(!warned)
473                     avpriv_request_sample(s->avctx, "CDLMS of order %d",
474                                           s->cdlms[c][i].order);
475                 warned = 1;
476             }
477         }
478
479         for (i = 0; i < s->cdlms_ttl[c]; i++)
480             s->cdlms[c][i].scaling = get_bits(&s->gb, 4);
481
482         if (cdlms_send_coef) {
483             for (i = 0; i < s->cdlms_ttl[c]; i++) {
484                 int cbits, shift_l, shift_r, j;
485                 cbits = av_log2(s->cdlms[c][i].order);
486                 if ((1 << cbits) < s->cdlms[c][i].order)
487                     cbits++;
488                 s->cdlms[c][i].coefsend = get_bits(&s->gb, cbits) + 1;
489
490                 cbits = av_log2(s->cdlms[c][i].scaling + 1);
491                 if ((1 << cbits) < s->cdlms[c][i].scaling + 1)
492                     cbits++;
493
494                 s->cdlms[c][i].bitsend = get_bitsz(&s->gb, cbits) + 2;
495                 shift_l = 32 - s->cdlms[c][i].bitsend;
496                 shift_r = 32 - s->cdlms[c][i].scaling - 2;
497                 for (j = 0; j < s->cdlms[c][i].coefsend; j++)
498                     s->cdlms[c][i].coefs[j] =
499                         (get_bits(&s->gb, s->cdlms[c][i].bitsend) << shift_l) >> shift_r;
500             }
501         }
502
503         for (i = 0; i < s->cdlms_ttl[c]; i++)
504             memset(s->cdlms[c][i].coefs + s->cdlms[c][i].order,
505                    0, WMALL_COEFF_PAD_SIZE);
506     }
507
508     return 0;
509 }
510
511 static int decode_channel_residues(WmallDecodeCtx *s, int ch, int tile_size)
512 {
513     int i = 0;
514     unsigned int ave_mean;
515     s->transient[ch] = get_bits1(&s->gb);
516     if (s->transient[ch]) {
517         s->transient_pos[ch] = get_bits(&s->gb, av_log2(tile_size));
518         if (s->transient_pos[ch])
519             s->transient[ch] = 0;
520         s->channel[ch].transient_counter =
521             FFMAX(s->channel[ch].transient_counter, s->samples_per_frame / 2);
522     } else if (s->channel[ch].transient_counter)
523         s->transient[ch] = 1;
524
525     if (s->seekable_tile) {
526         ave_mean = get_bits(&s->gb, s->bits_per_sample);
527         s->ave_sum[ch] = ave_mean << (s->movave_scaling + 1);
528     }
529
530     if (s->seekable_tile) {
531         if (s->do_inter_ch_decorr)
532             s->channel_residues[ch][0] = get_sbits_long(&s->gb, s->bits_per_sample + 1);
533         else
534             s->channel_residues[ch][0] = get_sbits_long(&s->gb, s->bits_per_sample);
535         i++;
536     }
537     for (; i < tile_size; i++) {
538         int quo = 0, rem, rem_bits, residue;
539         while(get_bits1(&s->gb)) {
540             quo++;
541             if (get_bits_left(&s->gb) <= 0)
542                 return -1;
543         }
544         if (quo >= 32)
545             quo += get_bits_long(&s->gb, get_bits(&s->gb, 5) + 1);
546
547         ave_mean = (s->ave_sum[ch] + (1 << s->movave_scaling)) >> (s->movave_scaling + 1);
548         if (ave_mean <= 1)
549             residue = quo;
550         else {
551             rem_bits = av_ceil_log2(ave_mean);
552             rem      = get_bits_long(&s->gb, rem_bits);
553             residue  = (quo << rem_bits) + rem;
554         }
555
556         s->ave_sum[ch] = residue + s->ave_sum[ch] -
557                          (s->ave_sum[ch] >> s->movave_scaling);
558
559         residue = (residue >> 1) ^ -(residue & 1);
560         s->channel_residues[ch][i] = residue;
561     }
562
563     return 0;
564
565 }
566
567 static void decode_lpc(WmallDecodeCtx *s)
568 {
569     int ch, i, cbits;
570     s->lpc_order   = get_bits(&s->gb, 5) + 1;
571     s->lpc_scaling = get_bits(&s->gb, 4);
572     s->lpc_intbits = get_bits(&s->gb, 3) + 1;
573     cbits = s->lpc_scaling + s->lpc_intbits;
574     for (ch = 0; ch < s->num_channels; ch++)
575         for (i = 0; i < s->lpc_order; i++)
576             s->lpc_coefs[ch][i] = get_sbits(&s->gb, cbits);
577 }
578
579 static void clear_codec_buffers(WmallDecodeCtx *s)
580 {
581     int ich, ilms;
582
583     memset(s->acfilter_coeffs,     0, sizeof(s->acfilter_coeffs));
584     memset(s->acfilter_prevvalues, 0, sizeof(s->acfilter_prevvalues));
585     memset(s->lpc_coefs,           0, sizeof(s->lpc_coefs));
586
587     memset(s->mclms_coeffs,     0, sizeof(s->mclms_coeffs));
588     memset(s->mclms_coeffs_cur, 0, sizeof(s->mclms_coeffs_cur));
589     memset(s->mclms_prevvalues, 0, sizeof(s->mclms_prevvalues));
590     memset(s->mclms_updates,    0, sizeof(s->mclms_updates));
591
592     for (ich = 0; ich < s->num_channels; ich++) {
593         for (ilms = 0; ilms < s->cdlms_ttl[ich]; ilms++) {
594             memset(s->cdlms[ich][ilms].coefs, 0,
595                    sizeof(s->cdlms[ich][ilms].coefs));
596             memset(s->cdlms[ich][ilms].lms_prevvalues, 0,
597                    sizeof(s->cdlms[ich][ilms].lms_prevvalues));
598             memset(s->cdlms[ich][ilms].lms_updates, 0,
599                    sizeof(s->cdlms[ich][ilms].lms_updates));
600         }
601         s->ave_sum[ich] = 0;
602     }
603 }
604
605 /**
606  * @brief Reset filter parameters and transient area at new seekable tile.
607  */
608 static void reset_codec(WmallDecodeCtx *s)
609 {
610     int ich, ilms;
611     s->mclms_recent = s->mclms_order * s->num_channels;
612     for (ich = 0; ich < s->num_channels; ich++) {
613         for (ilms = 0; ilms < s->cdlms_ttl[ich]; ilms++)
614             s->cdlms[ich][ilms].recent = s->cdlms[ich][ilms].order;
615         /* first sample of a seekable subframe is considered as the starting of
616             a transient area which is samples_per_frame samples long */
617         s->channel[ich].transient_counter = s->samples_per_frame;
618         s->transient[ich]     = 1;
619         s->transient_pos[ich] = 0;
620     }
621 }
622
623 static void mclms_update(WmallDecodeCtx *s, int icoef, int *pred)
624 {
625     int i, j, ich, pred_error;
626     int order        = s->mclms_order;
627     int num_channels = s->num_channels;
628     int range        = 1 << (s->bits_per_sample - 1);
629
630     for (ich = 0; ich < num_channels; ich++) {
631         pred_error = s->channel_residues[ich][icoef] - (unsigned)pred[ich];
632         if (pred_error > 0) {
633             for (i = 0; i < order * num_channels; i++)
634                 s->mclms_coeffs[i + ich * order * num_channels] +=
635                     s->mclms_updates[s->mclms_recent + i];
636             for (j = 0; j < ich; j++)
637                 s->mclms_coeffs_cur[ich * num_channels + j] += WMASIGN(s->channel_residues[j][icoef]);
638         } else if (pred_error < 0) {
639             for (i = 0; i < order * num_channels; i++)
640                 s->mclms_coeffs[i + ich * order * num_channels] -=
641                     s->mclms_updates[s->mclms_recent + i];
642             for (j = 0; j < ich; j++)
643                 s->mclms_coeffs_cur[ich * num_channels + j] -= WMASIGN(s->channel_residues[j][icoef]);
644         }
645     }
646
647     for (ich = num_channels - 1; ich >= 0; ich--) {
648         s->mclms_recent--;
649         s->mclms_prevvalues[s->mclms_recent] = av_clip(s->channel_residues[ich][icoef],
650             -range, range - 1);
651         s->mclms_updates[s->mclms_recent] = WMASIGN(s->channel_residues[ich][icoef]);
652     }
653
654     if (s->mclms_recent == 0) {
655         memcpy(&s->mclms_prevvalues[order * num_channels],
656                s->mclms_prevvalues,
657                sizeof(int32_t) * order * num_channels);
658         memcpy(&s->mclms_updates[order * num_channels],
659                s->mclms_updates,
660                sizeof(int32_t) * order * num_channels);
661         s->mclms_recent = num_channels * order;
662     }
663 }
664
665 static void mclms_predict(WmallDecodeCtx *s, int icoef, int *pred)
666 {
667     int ich, i;
668     int order        = s->mclms_order;
669     int num_channels = s->num_channels;
670
671     for (ich = 0; ich < num_channels; ich++) {
672         pred[ich] = 0;
673         if (!s->is_channel_coded[ich])
674             continue;
675         for (i = 0; i < order * num_channels; i++)
676             pred[ich] += (uint32_t)s->mclms_prevvalues[i + s->mclms_recent] *
677                          s->mclms_coeffs[i + order * num_channels * ich];
678         for (i = 0; i < ich; i++)
679             pred[ich] += (uint32_t)s->channel_residues[i][icoef] *
680                          s->mclms_coeffs_cur[i + num_channels * ich];
681         pred[ich] += (1 << s->mclms_scaling) >> 1;
682         pred[ich] >>= s->mclms_scaling;
683         s->channel_residues[ich][icoef] += (unsigned)pred[ich];
684     }
685 }
686
687 static void revert_mclms(WmallDecodeCtx *s, int tile_size)
688 {
689     int icoef, pred[WMALL_MAX_CHANNELS] = { 0 };
690     for (icoef = 0; icoef < tile_size; icoef++) {
691         mclms_predict(s, icoef, pred);
692         mclms_update(s, icoef, pred);
693     }
694 }
695
696 static void use_high_update_speed(WmallDecodeCtx *s, int ich)
697 {
698     int ilms, recent, icoef;
699     for (ilms = s->cdlms_ttl[ich] - 1; ilms >= 0; ilms--) {
700         recent = s->cdlms[ich][ilms].recent;
701         if (s->update_speed[ich] == 16)
702             continue;
703         if (s->bV3RTM) {
704             for (icoef = 0; icoef < s->cdlms[ich][ilms].order; icoef++)
705                 s->cdlms[ich][ilms].lms_updates[icoef + recent] *= 2;
706         } else {
707             for (icoef = 0; icoef < s->cdlms[ich][ilms].order; icoef++)
708                 s->cdlms[ich][ilms].lms_updates[icoef] *= 2;
709         }
710     }
711     s->update_speed[ich] = 16;
712 }
713
714 static void use_normal_update_speed(WmallDecodeCtx *s, int ich)
715 {
716     int ilms, recent, icoef;
717     for (ilms = s->cdlms_ttl[ich] - 1; ilms >= 0; ilms--) {
718         recent = s->cdlms[ich][ilms].recent;
719         if (s->update_speed[ich] == 8)
720             continue;
721         if (s->bV3RTM)
722             for (icoef = 0; icoef < s->cdlms[ich][ilms].order; icoef++)
723                 s->cdlms[ich][ilms].lms_updates[icoef + recent] /= 2;
724         else
725             for (icoef = 0; icoef < s->cdlms[ich][ilms].order; icoef++)
726                 s->cdlms[ich][ilms].lms_updates[icoef] /= 2;
727     }
728     s->update_speed[ich] = 8;
729 }
730
731 #define CD_LMS(bits, ROUND) \
732 static void lms_update ## bits (WmallDecodeCtx *s, int ich, int ilms, int input) \
733 { \
734     int recent = s->cdlms[ich][ilms].recent; \
735     int range  = 1 << s->bits_per_sample - 1; \
736     int order  = s->cdlms[ich][ilms].order; \
737     int ##bits##_t *prev = (int##bits##_t *)s->cdlms[ich][ilms].lms_prevvalues; \
738  \
739     if (recent) \
740         recent--; \
741     else { \
742         memcpy(prev + order, prev, (bits/8) * order); \
743         memcpy(s->cdlms[ich][ilms].lms_updates + order, \
744                s->cdlms[ich][ilms].lms_updates, \
745                sizeof(*s->cdlms[ich][ilms].lms_updates) * order); \
746         recent = order - 1; \
747     } \
748  \
749     prev[recent] = av_clip(input, -range, range - 1); \
750     s->cdlms[ich][ilms].lms_updates[recent] = WMASIGN(input) * s->update_speed[ich]; \
751  \
752     s->cdlms[ich][ilms].lms_updates[recent + (order >> 4)] >>= 2; \
753     s->cdlms[ich][ilms].lms_updates[recent + (order >> 3)] >>= 1; \
754     s->cdlms[ich][ilms].recent = recent; \
755     memset(s->cdlms[ich][ilms].lms_updates + recent + order, 0, \
756            sizeof(s->cdlms[ich][ilms].lms_updates) - \
757            sizeof(*s->cdlms[ich][ilms].lms_updates)*(recent+order)); \
758 } \
759  \
760 static void revert_cdlms ## bits (WmallDecodeCtx *s, int ch, \
761                                   int coef_begin, int coef_end) \
762 { \
763     int icoef, pred, ilms, num_lms, residue, input; \
764  \
765     num_lms = s->cdlms_ttl[ch]; \
766     for (ilms = num_lms - 1; ilms >= 0; ilms--) { \
767         for (icoef = coef_begin; icoef < coef_end; icoef++) { \
768             int##bits##_t *prevvalues = (int##bits##_t *)s->cdlms[ch][ilms].lms_prevvalues; \
769             pred = (1 << s->cdlms[ch][ilms].scaling) >> 1; \
770             residue = s->channel_residues[ch][icoef]; \
771             pred += s->dsp.scalarproduct_and_madd_int## bits (s->cdlms[ch][ilms].coefs, \
772                                                         prevvalues + s->cdlms[ch][ilms].recent, \
773                                                         s->cdlms[ch][ilms].lms_updates + \
774                                                         s->cdlms[ch][ilms].recent, \
775                                                         FFALIGN(s->cdlms[ch][ilms].order, ROUND), \
776                                                         WMASIGN(residue)); \
777             input = residue + (pred >> s->cdlms[ch][ilms].scaling); \
778             lms_update ## bits(s, ch, ilms, input); \
779             s->channel_residues[ch][icoef] = input; \
780         } \
781     } \
782     if (bits <= 16) emms_c(); \
783 }
784
785 CD_LMS(16, WMALL_COEFF_PAD_SIZE)
786 CD_LMS(32, 8)
787
788 static void revert_inter_ch_decorr(WmallDecodeCtx *s, int tile_size)
789 {
790     if (s->num_channels != 2)
791         return;
792     else if (s->is_channel_coded[0] || s->is_channel_coded[1]) {
793         int icoef;
794         for (icoef = 0; icoef < tile_size; icoef++) {
795             s->channel_residues[0][icoef] -= s->channel_residues[1][icoef] >> 1;
796             s->channel_residues[1][icoef] += s->channel_residues[0][icoef];
797         }
798     }
799 }
800
801 static void revert_acfilter(WmallDecodeCtx *s, int tile_size)
802 {
803     int ich, pred, i, j;
804     int16_t *filter_coeffs = s->acfilter_coeffs;
805     int scaling            = s->acfilter_scaling;
806     int order              = s->acfilter_order;
807
808     for (ich = 0; ich < s->num_channels; ich++) {
809         int *prevvalues = s->acfilter_prevvalues[ich];
810         for (i = 0; i < order; i++) {
811             pred = 0;
812             for (j = 0; j < order; j++) {
813                 if (i <= j)
814                     pred += (uint32_t)filter_coeffs[j] * prevvalues[j - i];
815                 else
816                     pred += (uint32_t)s->channel_residues[ich][i - j - 1] * filter_coeffs[j];
817             }
818             pred >>= scaling;
819             s->channel_residues[ich][i] += (unsigned)pred;
820         }
821         for (i = order; i < tile_size; i++) {
822             pred = 0;
823             for (j = 0; j < order; j++)
824                 pred += (uint32_t)s->channel_residues[ich][i - j - 1] * filter_coeffs[j];
825             pred >>= scaling;
826             s->channel_residues[ich][i] += (unsigned)pred;
827         }
828         for (j = 0; j < order; j++)
829             prevvalues[j] = s->channel_residues[ich][tile_size - j - 1];
830     }
831 }
832
833 static int decode_subframe(WmallDecodeCtx *s)
834 {
835     int offset        = s->samples_per_frame;
836     int subframe_len  = s->samples_per_frame;
837     int total_samples = s->samples_per_frame * s->num_channels;
838     int i, j, rawpcm_tile, padding_zeroes, res;
839
840     s->subframe_offset = get_bits_count(&s->gb);
841
842     /* reset channel context and find the next block offset and size
843         == the next block of the channel with the smallest number of
844         decoded samples */
845     for (i = 0; i < s->num_channels; i++) {
846         if (offset > s->channel[i].decoded_samples) {
847             offset = s->channel[i].decoded_samples;
848             subframe_len =
849                 s->channel[i].subframe_len[s->channel[i].cur_subframe];
850         }
851     }
852
853     /* get a list of all channels that contain the estimated block */
854     s->channels_for_cur_subframe = 0;
855     for (i = 0; i < s->num_channels; i++) {
856         const int cur_subframe = s->channel[i].cur_subframe;
857         /* subtract already processed samples */
858         total_samples -= s->channel[i].decoded_samples;
859
860         /* and count if there are multiple subframes that match our profile */
861         if (offset == s->channel[i].decoded_samples &&
862             subframe_len == s->channel[i].subframe_len[cur_subframe]) {
863             total_samples -= s->channel[i].subframe_len[cur_subframe];
864             s->channel[i].decoded_samples +=
865                 s->channel[i].subframe_len[cur_subframe];
866             s->channel_indexes_for_cur_subframe[s->channels_for_cur_subframe] = i;
867             ++s->channels_for_cur_subframe;
868         }
869     }
870
871     /* check if the frame will be complete after processing the
872         estimated block */
873     if (!total_samples)
874         s->parsed_all_subframes = 1;
875
876
877     s->seekable_tile = get_bits1(&s->gb);
878     if (s->seekable_tile) {
879         clear_codec_buffers(s);
880
881         s->do_arith_coding    = get_bits1(&s->gb);
882         if (s->do_arith_coding) {
883             avpriv_request_sample(s->avctx, "Arithmetic coding");
884             return AVERROR_PATCHWELCOME;
885         }
886         s->do_ac_filter       = get_bits1(&s->gb);
887         s->do_inter_ch_decorr = get_bits1(&s->gb);
888         s->do_mclms           = get_bits1(&s->gb);
889
890         if (s->do_ac_filter)
891             decode_ac_filter(s);
892
893         if (s->do_mclms)
894             decode_mclms(s);
895
896         if ((res = decode_cdlms(s)) < 0)
897             return res;
898         s->movave_scaling = get_bits(&s->gb, 3);
899         s->quant_stepsize = get_bits(&s->gb, 8) + 1;
900
901         reset_codec(s);
902     }
903
904     rawpcm_tile = get_bits1(&s->gb);
905
906     if (!rawpcm_tile && !s->cdlms[0][0].order) {
907         av_log(s->avctx, AV_LOG_DEBUG,
908                "Waiting for seekable tile\n");
909         av_frame_unref(s->frame);
910         return -1;
911     }
912
913
914     for (i = 0; i < s->num_channels; i++)
915         s->is_channel_coded[i] = 1;
916
917     if (!rawpcm_tile) {
918         for (i = 0; i < s->num_channels; i++)
919             s->is_channel_coded[i] = get_bits1(&s->gb);
920
921         if (s->bV3RTM) {
922             // LPC
923             s->do_lpc = get_bits1(&s->gb);
924             if (s->do_lpc) {
925                 decode_lpc(s);
926                 avpriv_request_sample(s->avctx, "Expect wrong output since "
927                                       "inverse LPC filter");
928             }
929         } else
930             s->do_lpc = 0;
931     }
932
933
934     if (get_bits1(&s->gb))
935         padding_zeroes = get_bits(&s->gb, 5);
936     else
937         padding_zeroes = 0;
938
939     if (rawpcm_tile) {
940         int bits = s->bits_per_sample - padding_zeroes;
941         if (bits <= 0) {
942             av_log(s->avctx, AV_LOG_ERROR,
943                    "Invalid number of padding bits in raw PCM tile\n");
944             return AVERROR_INVALIDDATA;
945         }
946         ff_dlog(s->avctx, "RAWPCM %d bits per sample. "
947                 "total %d bits, remain=%d\n", bits,
948                 bits * s->num_channels * subframe_len, get_bits_count(&s->gb));
949         for (i = 0; i < s->num_channels; i++)
950             for (j = 0; j < subframe_len; j++)
951                 s->channel_residues[i][j] = get_sbits_long(&s->gb, bits);
952     } else {
953         if (s->bits_per_sample < padding_zeroes)
954             return AVERROR_INVALIDDATA;
955         for (i = 0; i < s->num_channels; i++) {
956             if (s->is_channel_coded[i]) {
957                 decode_channel_residues(s, i, subframe_len);
958                 if (s->seekable_tile)
959                     use_high_update_speed(s, i);
960                 else
961                     use_normal_update_speed(s, i);
962                 if (s->bits_per_sample > 16)
963                     revert_cdlms32(s, i, 0, subframe_len);
964                 else
965                     revert_cdlms16(s, i, 0, subframe_len);
966             } else {
967                 memset(s->channel_residues[i], 0, sizeof(**s->channel_residues) * subframe_len);
968             }
969         }
970
971         if (s->do_mclms)
972             revert_mclms(s, subframe_len);
973         if (s->do_inter_ch_decorr)
974             revert_inter_ch_decorr(s, subframe_len);
975         if (s->do_ac_filter)
976             revert_acfilter(s, subframe_len);
977
978         /* Dequantize */
979         if (s->quant_stepsize != 1)
980             for (i = 0; i < s->num_channels; i++)
981                 for (j = 0; j < subframe_len; j++)
982                     s->channel_residues[i][j] *= s->quant_stepsize;
983     }
984
985     /* Write to proper output buffer depending on bit-depth */
986     for (i = 0; i < s->channels_for_cur_subframe; i++) {
987         int c = s->channel_indexes_for_cur_subframe[i];
988         int subframe_len = s->channel[c].subframe_len[s->channel[c].cur_subframe];
989
990         for (j = 0; j < subframe_len; j++) {
991             if (s->bits_per_sample == 16) {
992                 *s->samples_16[c]++ = (int16_t) s->channel_residues[c][j] * (1 << padding_zeroes);
993             } else {
994                 *s->samples_32[c]++ = s->channel_residues[c][j] * (256 << padding_zeroes);
995             }
996         }
997     }
998
999     /* handled one subframe */
1000     for (i = 0; i < s->channels_for_cur_subframe; i++) {
1001         int c = s->channel_indexes_for_cur_subframe[i];
1002         if (s->channel[c].cur_subframe >= s->channel[c].num_subframes) {
1003             av_log(s->avctx, AV_LOG_ERROR, "broken subframe\n");
1004             return AVERROR_INVALIDDATA;
1005         }
1006         ++s->channel[c].cur_subframe;
1007     }
1008     return 0;
1009 }
1010
1011 /**
1012  * @brief Decode one WMA frame.
1013  * @param s codec context
1014  * @return 0 if the trailer bit indicates that this is the last frame,
1015  *         1 if there are additional frames
1016  */
1017 static int decode_frame(WmallDecodeCtx *s)
1018 {
1019     GetBitContext* gb = &s->gb;
1020     int more_frames = 0, len = 0, i, ret;
1021
1022     s->frame->nb_samples = s->samples_per_frame;
1023     if ((ret = ff_get_buffer(s->avctx, s->frame, 0)) < 0) {
1024         /* return an error if no frame could be decoded at all */
1025         s->packet_loss = 1;
1026         s->frame->nb_samples = 0;
1027         return ret;
1028     }
1029     for (i = 0; i < s->num_channels; i++) {
1030         s->samples_16[i] = (int16_t *)s->frame->extended_data[i];
1031         s->samples_32[i] = (int32_t *)s->frame->extended_data[i];
1032     }
1033
1034     /* get frame length */
1035     if (s->len_prefix)
1036         len = get_bits(gb, s->log2_frame_size);
1037
1038     /* decode tile information */
1039     if ((ret = decode_tilehdr(s))) {
1040         s->packet_loss = 1;
1041         av_frame_unref(s->frame);
1042         return ret;
1043     }
1044
1045     /* read drc info */
1046     if (s->dynamic_range_compression)
1047         s->drc_gain = get_bits(gb, 8);
1048
1049     /* no idea what these are for, might be the number of samples
1050        that need to be skipped at the beginning or end of a stream */
1051     if (get_bits1(gb)) {
1052         int av_unused skip;
1053
1054         /* usually true for the first frame */
1055         if (get_bits1(gb)) {
1056             skip = get_bits(gb, av_log2(s->samples_per_frame * 2));
1057             ff_dlog(s->avctx, "start skip: %i\n", skip);
1058         }
1059
1060         /* sometimes true for the last frame */
1061         if (get_bits1(gb)) {
1062             skip = get_bits(gb, av_log2(s->samples_per_frame * 2));
1063             ff_dlog(s->avctx, "end skip: %i\n", skip);
1064             s->frame->nb_samples -= skip;
1065             if (s->frame->nb_samples <= 0)
1066                 return AVERROR_INVALIDDATA;
1067         }
1068
1069     }
1070
1071     /* reset subframe states */
1072     s->parsed_all_subframes = 0;
1073     for (i = 0; i < s->num_channels; i++) {
1074         s->channel[i].decoded_samples = 0;
1075         s->channel[i].cur_subframe    = 0;
1076     }
1077
1078     /* decode all subframes */
1079     while (!s->parsed_all_subframes) {
1080         int decoded_samples = s->channel[0].decoded_samples;
1081         if (decode_subframe(s) < 0) {
1082             s->packet_loss = 1;
1083             if (s->frame->nb_samples)
1084                 s->frame->nb_samples = decoded_samples;
1085             return 0;
1086         }
1087     }
1088
1089     ff_dlog(s->avctx, "Frame done\n");
1090
1091     s->skip_frame = 0;
1092
1093     if (s->len_prefix) {
1094         if (len != (get_bits_count(gb) - s->frame_offset) + 2) {
1095             /* FIXME: not sure if this is always an error */
1096             av_log(s->avctx, AV_LOG_ERROR,
1097                    "frame[%"PRIu32"] would have to skip %i bits\n",
1098                    s->frame_num,
1099                    len - (get_bits_count(gb) - s->frame_offset) - 1);
1100             s->packet_loss = 1;
1101             return 0;
1102         }
1103
1104         /* skip the rest of the frame data */
1105         skip_bits_long(gb, len - (get_bits_count(gb) - s->frame_offset) - 1);
1106     }
1107
1108     /* decode trailer bit */
1109     more_frames = get_bits1(gb);
1110     ++s->frame_num;
1111     return more_frames;
1112 }
1113
1114 /**
1115  * @brief Calculate remaining input buffer length.
1116  * @param s  codec context
1117  * @param gb bitstream reader context
1118  * @return remaining size in bits
1119  */
1120 static int remaining_bits(WmallDecodeCtx *s, GetBitContext *gb)
1121 {
1122     return s->buf_bit_size - get_bits_count(gb);
1123 }
1124
1125 /**
1126  * @brief Fill the bit reservoir with a (partial) frame.
1127  * @param s      codec context
1128  * @param gb     bitstream reader context
1129  * @param len    length of the partial frame
1130  * @param append decides whether to reset the buffer or not
1131  */
1132 static void save_bits(WmallDecodeCtx *s, GetBitContext* gb, int len,
1133                       int append)
1134 {
1135     int buflen;
1136     PutBitContext tmp;
1137
1138     /* when the frame data does not need to be concatenated, the input buffer
1139         is reset and additional bits from the previous frame are copied
1140         and skipped later so that a fast byte copy is possible */
1141
1142     if (!append) {
1143         s->frame_offset   = get_bits_count(gb) & 7;
1144         s->num_saved_bits = s->frame_offset;
1145         init_put_bits(&s->pb, s->frame_data, s->max_frame_size);
1146     }
1147
1148     buflen = (s->num_saved_bits + len + 8) >> 3;
1149
1150     if (len <= 0 || buflen > s->max_frame_size) {
1151         avpriv_request_sample(s->avctx, "Too small input buffer");
1152         s->packet_loss = 1;
1153         s->num_saved_bits = 0;
1154         return;
1155     }
1156
1157     s->num_saved_bits += len;
1158     if (!append) {
1159         avpriv_copy_bits(&s->pb, gb->buffer + (get_bits_count(gb) >> 3),
1160                          s->num_saved_bits);
1161     } else {
1162         int align = 8 - (get_bits_count(gb) & 7);
1163         align = FFMIN(align, len);
1164         put_bits(&s->pb, align, get_bits(gb, align));
1165         len -= align;
1166         avpriv_copy_bits(&s->pb, gb->buffer + (get_bits_count(gb) >> 3), len);
1167     }
1168     skip_bits_long(gb, len);
1169
1170     tmp = s->pb;
1171     flush_put_bits(&tmp);
1172
1173     init_get_bits(&s->gb, s->frame_data, s->num_saved_bits);
1174     skip_bits(&s->gb, s->frame_offset);
1175 }
1176
1177 static int decode_packet(AVCodecContext *avctx, void *data, int *got_frame_ptr,
1178                          AVPacket* avpkt)
1179 {
1180     WmallDecodeCtx *s = avctx->priv_data;
1181     GetBitContext* gb  = &s->pgb;
1182     const uint8_t* buf = avpkt->data;
1183     int buf_size       = avpkt->size;
1184     int num_bits_prev_frame, packet_sequence_number, spliced_packet;
1185
1186     s->frame->nb_samples = 0;
1187
1188     if (!buf_size && s->num_saved_bits > get_bits_count(&s->gb)) {
1189         s->packet_done = 0;
1190         if (!decode_frame(s))
1191             s->num_saved_bits = 0;
1192     } else if (s->packet_done || s->packet_loss) {
1193         s->packet_done = 0;
1194
1195         if (!buf_size)
1196             return 0;
1197
1198         s->next_packet_start = buf_size - FFMIN(avctx->block_align, buf_size);
1199         buf_size             = FFMIN(avctx->block_align, buf_size);
1200         s->buf_bit_size      = buf_size << 3;
1201
1202         /* parse packet header */
1203         init_get_bits(gb, buf, s->buf_bit_size);
1204         packet_sequence_number = get_bits(gb, 4);
1205         skip_bits(gb, 1);   // Skip seekable_frame_in_packet, currently unused
1206         spliced_packet = get_bits1(gb);
1207         if (spliced_packet)
1208             avpriv_request_sample(avctx, "Bitstream splicing");
1209
1210         /* get number of bits that need to be added to the previous frame */
1211         num_bits_prev_frame = get_bits(gb, s->log2_frame_size);
1212
1213         /* check for packet loss */
1214         if (!s->packet_loss &&
1215             ((s->packet_sequence_number + 1) & 0xF) != packet_sequence_number) {
1216             s->packet_loss = 1;
1217             av_log(avctx, AV_LOG_ERROR,
1218                    "Packet loss detected! seq %"PRIx8" vs %x\n",
1219                    s->packet_sequence_number, packet_sequence_number);
1220         }
1221         s->packet_sequence_number = packet_sequence_number;
1222
1223         if (num_bits_prev_frame > 0) {
1224             int remaining_packet_bits = s->buf_bit_size - get_bits_count(gb);
1225             if (num_bits_prev_frame >= remaining_packet_bits) {
1226                 num_bits_prev_frame = remaining_packet_bits;
1227                 s->packet_done = 1;
1228             }
1229
1230             /* Append the previous frame data to the remaining data from the
1231              * previous packet to create a full frame. */
1232             save_bits(s, gb, num_bits_prev_frame, 1);
1233
1234             /* decode the cross packet frame if it is valid */
1235             if (num_bits_prev_frame < remaining_packet_bits && !s->packet_loss)
1236                 decode_frame(s);
1237         } else if (s->num_saved_bits - s->frame_offset) {
1238             ff_dlog(avctx, "ignoring %x previously saved bits\n",
1239                     s->num_saved_bits - s->frame_offset);
1240         }
1241
1242         if (s->packet_loss) {
1243             /* Reset number of saved bits so that the decoder does not start
1244              * to decode incomplete frames in the s->len_prefix == 0 case. */
1245             s->num_saved_bits = 0;
1246             s->packet_loss    = 0;
1247             init_put_bits(&s->pb, s->frame_data, s->max_frame_size);
1248         }
1249
1250     } else {
1251         int frame_size;
1252
1253         s->buf_bit_size = (avpkt->size - s->next_packet_start) << 3;
1254         init_get_bits(gb, avpkt->data, s->buf_bit_size);
1255         skip_bits(gb, s->packet_offset);
1256
1257         if (s->len_prefix && remaining_bits(s, gb) > s->log2_frame_size &&
1258             (frame_size = show_bits(gb, s->log2_frame_size)) &&
1259             frame_size <= remaining_bits(s, gb)) {
1260             save_bits(s, gb, frame_size, 0);
1261
1262             if (!s->packet_loss)
1263                 s->packet_done = !decode_frame(s);
1264         } else if (!s->len_prefix
1265                    && s->num_saved_bits > get_bits_count(&s->gb)) {
1266             /* when the frames do not have a length prefix, we don't know the
1267              * compressed length of the individual frames however, we know what
1268              * part of a new packet belongs to the previous frame therefore we
1269              * save the incoming packet first, then we append the "previous
1270              * frame" data from the next packet so that we get a buffer that
1271              * only contains full frames */
1272             s->packet_done = !decode_frame(s);
1273         } else {
1274             s->packet_done = 1;
1275         }
1276     }
1277
1278     if (remaining_bits(s, gb) < 0) {
1279         av_log(avctx, AV_LOG_ERROR, "Overread %d\n", -remaining_bits(s, gb));
1280         s->packet_loss = 1;
1281     }
1282
1283     if (s->packet_done && !s->packet_loss &&
1284         remaining_bits(s, gb) > 0) {
1285         /* save the rest of the data so that it can be decoded
1286          * with the next packet */
1287         save_bits(s, gb, remaining_bits(s, gb), 0);
1288     }
1289
1290     *got_frame_ptr   = s->frame->nb_samples > 0;
1291     av_frame_move_ref(data, s->frame);
1292
1293     s->packet_offset = get_bits_count(gb) & 7;
1294
1295     return (s->packet_loss) ? AVERROR_INVALIDDATA : buf_size ? get_bits_count(gb) >> 3 : 0;
1296 }
1297
1298 static void flush(AVCodecContext *avctx)
1299 {
1300     WmallDecodeCtx *s    = avctx->priv_data;
1301     s->packet_loss       = 1;
1302     s->packet_done       = 0;
1303     s->num_saved_bits    = 0;
1304     s->frame_offset      = 0;
1305     s->next_packet_start = 0;
1306     s->cdlms[0][0].order = 0;
1307     s->frame->nb_samples = 0;
1308     init_put_bits(&s->pb, s->frame_data, s->max_frame_size);
1309 }
1310
1311 static av_cold int decode_close(AVCodecContext *avctx)
1312 {
1313     WmallDecodeCtx *s = avctx->priv_data;
1314
1315     av_frame_free(&s->frame);
1316     av_freep(&s->frame_data);
1317
1318     return 0;
1319 }
1320
1321 AVCodec ff_wmalossless_decoder = {
1322     .name           = "wmalossless",
1323     .long_name      = NULL_IF_CONFIG_SMALL("Windows Media Audio Lossless"),
1324     .type           = AVMEDIA_TYPE_AUDIO,
1325     .id             = AV_CODEC_ID_WMALOSSLESS,
1326     .priv_data_size = sizeof(WmallDecodeCtx),
1327     .init           = decode_init,
1328     .close          = decode_close,
1329     .decode         = decode_packet,
1330     .flush          = flush,
1331     .capabilities   = AV_CODEC_CAP_SUBFRAMES | AV_CODEC_CAP_DR1 | AV_CODEC_CAP_DELAY,
1332     .caps_internal  = FF_CODEC_CAP_INIT_CLEANUP,
1333     .sample_fmts    = (const enum AVSampleFormat[]) { AV_SAMPLE_FMT_S16P,
1334                                                       AV_SAMPLE_FMT_S32P,
1335                                                       AV_SAMPLE_FMT_NONE },
1336 };