]> git.sesse.net Git - ffmpeg/blob - libavcodec/vp9.c
dvbsubdec: check memory allocations and propagate errors
[ffmpeg] / libavcodec / vp9.c
1 /*
2  * VP9 compatible video decoder
3  *
4  * Copyright (C) 2013 Ronald S. Bultje <rsbultje gmail com>
5  * Copyright (C) 2013 Clément Bœsch <u pkh me>
6  *
7  * This file is part of Libav.
8  *
9  * Libav is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 2.1 of the License, or (at your option) any later version.
13  *
14  * Libav is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with Libav; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22  */
23
24 #include "libavutil/avassert.h"
25
26 #include "avcodec.h"
27 #include "get_bits.h"
28 #include "internal.h"
29 #include "videodsp.h"
30 #include "vp56.h"
31 #include "vp9.h"
32 #include "vp9data.h"
33
34 #define VP9_SYNCCODE 0x498342
35 #define MAX_PROB 255
36
37 static void vp9_decode_flush(AVCodecContext *avctx)
38 {
39     VP9Context *s = avctx->priv_data;
40     int i;
41
42     for (i = 0; i < FF_ARRAY_ELEMS(s->refs); i++)
43         av_frame_unref(s->refs[i]);
44 }
45
46 static int update_size(AVCodecContext *avctx, int w, int h)
47 {
48     VP9Context *s = avctx->priv_data;
49     uint8_t *p;
50
51     if (s->above_partition_ctx && w == avctx->width && h == avctx->height)
52         return 0;
53
54     vp9_decode_flush(avctx);
55
56     if (w <= 0 || h <= 0)
57         return AVERROR_INVALIDDATA;
58
59     avctx->width  = w;
60     avctx->height = h;
61     s->sb_cols    = (w + 63) >> 6;
62     s->sb_rows    = (h + 63) >> 6;
63     s->cols       = (w +  7) >> 3;
64     s->rows       = (h +  7) >> 3;
65
66 #define assign(var, type, n) var = (type)p; p += s->sb_cols * n * sizeof(*var)
67     av_free(s->above_partition_ctx);
68     p = av_malloc(s->sb_cols *
69                   (240 + sizeof(*s->lflvl) + 16 * sizeof(*s->above_mv_ctx) +
70                    64 * s->sb_rows * (1 + sizeof(*s->mv[0]) * 2)));
71     if (!p)
72         return AVERROR(ENOMEM);
73     assign(s->above_partition_ctx, uint8_t *,     8);
74     assign(s->above_skip_ctx,      uint8_t *,     8);
75     assign(s->above_txfm_ctx,      uint8_t *,     8);
76     assign(s->above_mode_ctx,      uint8_t *,    16);
77     assign(s->above_y_nnz_ctx,     uint8_t *,    16);
78     assign(s->above_uv_nnz_ctx[0], uint8_t *,     8);
79     assign(s->above_uv_nnz_ctx[1], uint8_t *,     8);
80     assign(s->intra_pred_data[0],  uint8_t *,    64);
81     assign(s->intra_pred_data[1],  uint8_t *,    32);
82     assign(s->intra_pred_data[2],  uint8_t *,    32);
83     assign(s->above_segpred_ctx,   uint8_t *,     8);
84     assign(s->above_intra_ctx,     uint8_t *,     8);
85     assign(s->above_comp_ctx,      uint8_t *,     8);
86     assign(s->above_ref_ctx,       uint8_t *,     8);
87     assign(s->above_filter_ctx,    uint8_t *,     8);
88     assign(s->lflvl,               VP9Filter *,   1);
89     assign(s->above_mv_ctx,        VP56mv(*)[2], 16);
90     assign(s->segmentation_map,    uint8_t *,      64 * s->sb_rows);
91     assign(s->mv[0],               VP9MVRefPair *, 64 * s->sb_rows);
92     assign(s->mv[1],               VP9MVRefPair *, 64 * s->sb_rows);
93 #undef assign
94
95     return 0;
96 }
97
98 // The sign bit is at the end, not the start, of a bit sequence
99 static av_always_inline int get_bits_with_sign(GetBitContext *gb, int n)
100 {
101     int v = get_bits(gb, n);
102     return get_bits1(gb) ? -v : v;
103 }
104
105 static av_always_inline int inv_recenter_nonneg(int v, int m)
106 {
107     if (v > 2 * m)
108         return v;
109     if (v & 1)
110         return m - ((v + 1) >> 1);
111     return m + (v >> 1);
112 }
113
114 // differential forward probability updates
115 static int update_prob(VP56RangeCoder *c, int p)
116 {
117     static const int inv_map_table[MAX_PROB - 1] = {
118           7,  20,  33,  46,  59,  72,  85,  98, 111, 124, 137, 150, 163, 176,
119         189, 202, 215, 228, 241, 254,   1,   2,   3,   4,   5,   6,   8,   9,
120          10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  21,  22,  23,  24,
121          25,  26,  27,  28,  29,  30,  31,  32,  34,  35,  36,  37,  38,  39,
122          40,  41,  42,  43,  44,  45,  47,  48,  49,  50,  51,  52,  53,  54,
123          55,  56,  57,  58,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
124          70,  71,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
125          86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  99, 100,
126         101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 112, 113, 114, 115,
127         116, 117, 118, 119, 120, 121, 122, 123, 125, 126, 127, 128, 129, 130,
128         131, 132, 133, 134, 135, 136, 138, 139, 140, 141, 142, 143, 144, 145,
129         146, 147, 148, 149, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160,
130         161, 162, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
131         177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 190, 191,
132         192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 203, 204, 205, 206,
133         207, 208, 209, 210, 211, 212, 213, 214, 216, 217, 218, 219, 220, 221,
134         222, 223, 224, 225, 226, 227, 229, 230, 231, 232, 233, 234, 235, 236,
135         237, 238, 239, 240, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,
136         252, 253,
137     };
138     int d;
139
140     /* This code is trying to do a differential probability update. For a
141      * current probability A in the range [1, 255], the difference to a new
142      * probability of any value can be expressed differentially as 1-A, 255-A
143      * where some part of this (absolute range) exists both in positive as
144      * well as the negative part, whereas another part only exists in one
145      * half. We're trying to code this shared part differentially, i.e.
146      * times two where the value of the lowest bit specifies the sign, and
147      * the single part is then coded on top of this. This absolute difference
148      * then again has a value of [0, 254], but a bigger value in this range
149      * indicates that we're further away from the original value A, so we
150      * can code this as a VLC code, since higher values are increasingly
151      * unlikely. The first 20 values in inv_map_table[] allow 'cheap, rough'
152      * updates vs. the 'fine, exact' updates further down the range, which
153      * adds one extra dimension to this differential update model. */
154
155     if (!vp8_rac_get(c)) {
156         d = vp8_rac_get_uint(c, 4) + 0;
157     } else if (!vp8_rac_get(c)) {
158         d = vp8_rac_get_uint(c, 4) + 16;
159     } else if (!vp8_rac_get(c)) {
160         d = vp8_rac_get_uint(c, 5) + 32;
161     } else {
162         d = vp8_rac_get_uint(c, 7);
163         if (d >= 65) {
164             d = (d << 1) - 65 + vp8_rac_get(c);
165             d = av_clip(d, 0, MAX_PROB - 65 - 1);
166         }
167         d += 64;
168     }
169
170     return p <= 128
171            ?   1 + inv_recenter_nonneg(inv_map_table[d], p - 1)
172            : 255 - inv_recenter_nonneg(inv_map_table[d], 255 - p);
173 }
174
175 static int decode_frame_header(AVCodecContext *avctx,
176                                const uint8_t *data, int size, int *ref)
177 {
178     VP9Context *s = avctx->priv_data;
179     int c, i, j, k, l, m, n, w, h, max, size2, ret, sharp;
180     int last_invisible;
181     const uint8_t *data2;
182
183     /* general header */
184     if ((ret = init_get_bits8(&s->gb, data, size)) < 0) {
185         av_log(avctx, AV_LOG_ERROR, "Failed to initialize bitstream reader\n");
186         return ret;
187     }
188     if (get_bits(&s->gb, 2) != 0x2) { // frame marker
189         av_log(avctx, AV_LOG_ERROR, "Invalid frame marker\n");
190         return AVERROR_INVALIDDATA;
191     }
192     s->profile = get_bits1(&s->gb);
193     if (get_bits1(&s->gb)) { // reserved bit
194         av_log(avctx, AV_LOG_ERROR, "Reserved bit should be zero\n");
195         return AVERROR_INVALIDDATA;
196     }
197     if (get_bits1(&s->gb)) {
198         *ref = get_bits(&s->gb, 3);
199         return 0;
200     }
201
202     s->last_keyframe = s->keyframe;
203     s->keyframe      = !get_bits1(&s->gb);
204
205     last_invisible = s->invisible;
206     s->invisible   = !get_bits1(&s->gb);
207     s->errorres    = get_bits1(&s->gb);
208     // FIXME disable this upon resolution change
209     s->use_last_frame_mvs = !s->errorres && !last_invisible;
210
211     if (s->keyframe) {
212         if (get_bits_long(&s->gb, 24) != VP9_SYNCCODE) { // synccode
213             av_log(avctx, AV_LOG_ERROR, "Invalid sync code\n");
214             return AVERROR_INVALIDDATA;
215         }
216         s->colorspace = get_bits(&s->gb, 3);
217         if (s->colorspace == 7) { // RGB = profile 1
218             av_log(avctx, AV_LOG_ERROR, "RGB not supported in profile 0\n");
219             return AVERROR_INVALIDDATA;
220         }
221         s->fullrange = get_bits1(&s->gb);
222         // for profile 1, here follows the subsampling bits
223         s->refreshrefmask = 0xff;
224         w = get_bits(&s->gb, 16) + 1;
225         h = get_bits(&s->gb, 16) + 1;
226         if (get_bits1(&s->gb)) // display size
227             skip_bits(&s->gb, 32);
228     } else {
229         s->intraonly = s->invisible ? get_bits1(&s->gb) : 0;
230         s->resetctx  = s->errorres ? 0 : get_bits(&s->gb, 2);
231         if (s->intraonly) {
232             if (get_bits_long(&s->gb, 24) != VP9_SYNCCODE) { // synccode
233                 av_log(avctx, AV_LOG_ERROR, "Invalid sync code\n");
234                 return AVERROR_INVALIDDATA;
235             }
236             s->refreshrefmask = get_bits(&s->gb, 8);
237             w = get_bits(&s->gb, 16) + 1;
238             h = get_bits(&s->gb, 16) + 1;
239             if (get_bits1(&s->gb)) // display size
240                 skip_bits(&s->gb, 32);
241         } else {
242             s->refreshrefmask = get_bits(&s->gb, 8);
243             s->refidx[0]      = get_bits(&s->gb, 3);
244             s->signbias[0]    = get_bits1(&s->gb);
245             s->refidx[1]      = get_bits(&s->gb, 3);
246             s->signbias[1]    = get_bits1(&s->gb);
247             s->refidx[2]      = get_bits(&s->gb, 3);
248             s->signbias[2]    = get_bits1(&s->gb);
249             if (!s->refs[s->refidx[0]]->buf[0] ||
250                 !s->refs[s->refidx[1]]->buf[0] ||
251                 !s->refs[s->refidx[2]]->buf[0]) {
252                 av_log(avctx, AV_LOG_ERROR,
253                        "Not all references are available\n");
254                 return AVERROR_INVALIDDATA;
255             }
256             if (get_bits1(&s->gb)) {
257                 w = s->refs[s->refidx[0]]->width;
258                 h = s->refs[s->refidx[0]]->height;
259             } else if (get_bits1(&s->gb)) {
260                 w = s->refs[s->refidx[1]]->width;
261                 h = s->refs[s->refidx[1]]->height;
262             } else if (get_bits1(&s->gb)) {
263                 w = s->refs[s->refidx[2]]->width;
264                 h = s->refs[s->refidx[2]]->height;
265             } else {
266                 w = get_bits(&s->gb, 16) + 1;
267                 h = get_bits(&s->gb, 16) + 1;
268             }
269             if (get_bits1(&s->gb)) // display size
270                 skip_bits(&s->gb, 32);
271             s->highprecisionmvs = get_bits1(&s->gb);
272             s->filtermode       = get_bits1(&s->gb) ? FILTER_SWITCHABLE :
273                                   get_bits(&s->gb, 2);
274             s->allowcompinter   = s->signbias[0] != s->signbias[1] ||
275                                   s->signbias[0] != s->signbias[2];
276             if (s->allowcompinter) {
277                 if (s->signbias[0] == s->signbias[1]) {
278                     s->fixcompref    = 2;
279                     s->varcompref[0] = 0;
280                     s->varcompref[1] = 1;
281                 } else if (s->signbias[0] == s->signbias[2]) {
282                     s->fixcompref    = 1;
283                     s->varcompref[0] = 0;
284                     s->varcompref[1] = 2;
285                 } else {
286                     s->fixcompref    = 0;
287                     s->varcompref[0] = 1;
288                     s->varcompref[1] = 2;
289                 }
290             }
291         }
292     }
293
294     s->refreshctx   = s->errorres ? 0 : get_bits1(&s->gb);
295     s->parallelmode = s->errorres ? 1 : get_bits1(&s->gb);
296     s->framectxid   = c = get_bits(&s->gb, 2);
297
298     /* loopfilter header data */
299     s->filter.level = get_bits(&s->gb, 6);
300     sharp           = get_bits(&s->gb, 3);
301     /* If sharpness changed, reinit lim/mblim LUTs. if it didn't change,
302      * keep the old cache values since they are still valid. */
303     if (s->filter.sharpness != sharp)
304         memset(s->filter.lim_lut, 0, sizeof(s->filter.lim_lut));
305     s->filter.sharpness = sharp;
306     if ((s->lf_delta.enabled = get_bits1(&s->gb))) {
307         if (get_bits1(&s->gb)) {
308             for (i = 0; i < 4; i++)
309                 if (get_bits1(&s->gb))
310                     s->lf_delta.ref[i] = get_bits_with_sign(&s->gb, 6);
311             for (i = 0; i < 2; i++)
312                 if (get_bits1(&s->gb))
313                     s->lf_delta.mode[i] = get_bits_with_sign(&s->gb, 6);
314         }
315     } else {
316         memset(&s->lf_delta, 0, sizeof(s->lf_delta));
317     }
318
319     /* quantization header data */
320     s->yac_qi      = get_bits(&s->gb, 8);
321     s->ydc_qdelta  = get_bits1(&s->gb) ? get_bits_with_sign(&s->gb, 4) : 0;
322     s->uvdc_qdelta = get_bits1(&s->gb) ? get_bits_with_sign(&s->gb, 4) : 0;
323     s->uvac_qdelta = get_bits1(&s->gb) ? get_bits_with_sign(&s->gb, 4) : 0;
324     s->lossless    = s->yac_qi == 0 && s->ydc_qdelta == 0 &&
325                      s->uvdc_qdelta == 0 && s->uvac_qdelta == 0;
326
327     /* segmentation header info */
328     if ((s->segmentation.enabled = get_bits1(&s->gb))) {
329         if ((s->segmentation.update_map = get_bits1(&s->gb))) {
330             for (i = 0; i < 7; i++)
331                 s->prob.seg[i] = get_bits1(&s->gb) ?
332                                  get_bits(&s->gb, 8) : 255;
333             if ((s->segmentation.temporal = get_bits1(&s->gb)))
334                 for (i = 0; i < 3; i++)
335                     s->prob.segpred[i] = get_bits1(&s->gb) ?
336                                          get_bits(&s->gb, 8) : 255;
337         }
338
339         if (get_bits1(&s->gb)) {
340             s->segmentation.absolute_vals = get_bits1(&s->gb);
341             for (i = 0; i < 8; i++) {
342                 if ((s->segmentation.feat[i].q_enabled = get_bits1(&s->gb)))
343                     s->segmentation.feat[i].q_val = get_bits_with_sign(&s->gb, 8);
344                 if ((s->segmentation.feat[i].lf_enabled = get_bits1(&s->gb)))
345                     s->segmentation.feat[i].lf_val = get_bits_with_sign(&s->gb, 6);
346                 if ((s->segmentation.feat[i].ref_enabled = get_bits1(&s->gb)))
347                     s->segmentation.feat[i].ref_val = get_bits(&s->gb, 2);
348                 s->segmentation.feat[i].skip_enabled = get_bits1(&s->gb);
349             }
350         }
351     } else {
352         s->segmentation.feat[0].q_enabled    = 0;
353         s->segmentation.feat[0].lf_enabled   = 0;
354         s->segmentation.feat[0].skip_enabled = 0;
355         s->segmentation.feat[0].ref_enabled  = 0;
356     }
357
358     // set qmul[] based on Y/UV, AC/DC and segmentation Q idx deltas
359     for (i = 0; i < (s->segmentation.enabled ? 8 : 1); i++) {
360         int qyac, qydc, quvac, quvdc, lflvl, sh;
361
362         if (s->segmentation.feat[i].q_enabled) {
363             if (s->segmentation.absolute_vals)
364                 qyac = s->segmentation.feat[i].q_val;
365             else
366                 qyac = s->yac_qi + s->segmentation.feat[i].q_val;
367         } else {
368             qyac = s->yac_qi;
369         }
370         qydc  = av_clip_uintp2(qyac + s->ydc_qdelta, 8);
371         quvdc = av_clip_uintp2(qyac + s->uvdc_qdelta, 8);
372         quvac = av_clip_uintp2(qyac + s->uvac_qdelta, 8);
373         qyac  = av_clip_uintp2(qyac, 8);
374
375         s->segmentation.feat[i].qmul[0][0] = ff_vp9_dc_qlookup[qydc];
376         s->segmentation.feat[i].qmul[0][1] = ff_vp9_ac_qlookup[qyac];
377         s->segmentation.feat[i].qmul[1][0] = ff_vp9_dc_qlookup[quvdc];
378         s->segmentation.feat[i].qmul[1][1] = ff_vp9_ac_qlookup[quvac];
379
380         sh = s->filter.level >= 32;
381         if (s->segmentation.feat[i].lf_enabled) {
382             if (s->segmentation.absolute_vals)
383                 lflvl = s->segmentation.feat[i].lf_val;
384             else
385                 lflvl = s->filter.level + s->segmentation.feat[i].lf_val;
386         } else {
387             lflvl = s->filter.level;
388         }
389         s->segmentation.feat[i].lflvl[0][0] =
390         s->segmentation.feat[i].lflvl[0][1] =
391             av_clip_uintp2(lflvl + (s->lf_delta.ref[0] << sh), 6);
392         for (j = 1; j < 4; j++) {
393             s->segmentation.feat[i].lflvl[j][0] =
394                 av_clip_uintp2(lflvl + ((s->lf_delta.ref[j] +
395                                          s->lf_delta.mode[0]) << sh), 6);
396             s->segmentation.feat[i].lflvl[j][1] =
397                 av_clip_uintp2(lflvl + ((s->lf_delta.ref[j] +
398                                          s->lf_delta.mode[1]) << sh), 6);
399         }
400     }
401
402     /* tiling info */
403     if ((ret = update_size(avctx, w, h)) < 0) {
404         av_log(avctx, AV_LOG_ERROR,
405                "Failed to initialize decoder for %dx%d\n", w, h);
406         return ret;
407     }
408     for (s->tiling.log2_tile_cols = 0;
409          (s->sb_cols >> s->tiling.log2_tile_cols) > 64;
410          s->tiling.log2_tile_cols++) ;
411     for (max = 0; (s->sb_cols >> max) >= 4; max++) ;
412     max = FFMAX(0, max - 1);
413     while (max > s->tiling.log2_tile_cols) {
414         if (get_bits1(&s->gb))
415             s->tiling.log2_tile_cols++;
416         else
417             break;
418     }
419     s->tiling.log2_tile_rows = decode012(&s->gb);
420     s->tiling.tile_rows      = 1 << s->tiling.log2_tile_rows;
421     if (s->tiling.tile_cols != (1 << s->tiling.log2_tile_cols)) {
422         s->tiling.tile_cols = 1 << s->tiling.log2_tile_cols;
423         s->c_b              = av_fast_realloc(s->c_b, &s->c_b_size,
424                                               sizeof(VP56RangeCoder) *
425                                               s->tiling.tile_cols);
426         if (!s->c_b) {
427             av_log(avctx, AV_LOG_ERROR,
428                    "Ran out of memory during range coder init\n");
429             return AVERROR(ENOMEM);
430         }
431     }
432
433     if (s->keyframe || s->errorres || s->intraonly) {
434         s->prob_ctx[0].p =
435         s->prob_ctx[1].p =
436         s->prob_ctx[2].p =
437         s->prob_ctx[3].p = ff_vp9_default_probs;
438         memcpy(s->prob_ctx[0].coef, ff_vp9_default_coef_probs,
439                sizeof(ff_vp9_default_coef_probs));
440         memcpy(s->prob_ctx[1].coef, ff_vp9_default_coef_probs,
441                sizeof(ff_vp9_default_coef_probs));
442         memcpy(s->prob_ctx[2].coef, ff_vp9_default_coef_probs,
443                sizeof(ff_vp9_default_coef_probs));
444         memcpy(s->prob_ctx[3].coef, ff_vp9_default_coef_probs,
445                sizeof(ff_vp9_default_coef_probs));
446     }
447
448     // next 16 bits is size of the rest of the header (arith-coded)
449     size2 = get_bits(&s->gb, 16);
450     data2 = align_get_bits(&s->gb);
451     if (size2 > size - (data2 - data)) {
452         av_log(avctx, AV_LOG_ERROR, "Invalid compressed header size\n");
453         return AVERROR_INVALIDDATA;
454     }
455     ff_vp56_init_range_decoder(&s->c, data2, size2);
456     if (vp56_rac_get_prob_branchy(&s->c, 128)) { // marker bit
457         av_log(avctx, AV_LOG_ERROR, "Marker bit was set\n");
458         return AVERROR_INVALIDDATA;
459     }
460
461     if (s->keyframe || s->intraonly)
462         memset(s->counts.coef, 0,
463                sizeof(s->counts.coef) + sizeof(s->counts.eob));
464     else
465         memset(&s->counts, 0, sizeof(s->counts));
466
467     /* FIXME is it faster to not copy here, but do it down in the fw updates
468      * as explicit copies if the fw update is missing (and skip the copy upon
469      * fw update)? */
470     s->prob.p = s->prob_ctx[c].p;
471
472     // txfm updates
473     if (s->lossless) {
474         s->txfmmode = TX_4X4;
475     } else {
476         s->txfmmode = vp8_rac_get_uint(&s->c, 2);
477         if (s->txfmmode == 3)
478             s->txfmmode += vp8_rac_get(&s->c);
479
480         if (s->txfmmode == TX_SWITCHABLE) {
481             for (i = 0; i < 2; i++)
482                 if (vp56_rac_get_prob_branchy(&s->c, 252))
483                     s->prob.p.tx8p[i] = update_prob(&s->c, s->prob.p.tx8p[i]);
484             for (i = 0; i < 2; i++)
485                 for (j = 0; j < 2; j++)
486                     if (vp56_rac_get_prob_branchy(&s->c, 252))
487                         s->prob.p.tx16p[i][j] =
488                             update_prob(&s->c, s->prob.p.tx16p[i][j]);
489             for (i = 0; i < 2; i++)
490                 for (j = 0; j < 3; j++)
491                     if (vp56_rac_get_prob_branchy(&s->c, 252))
492                         s->prob.p.tx32p[i][j] =
493                             update_prob(&s->c, s->prob.p.tx32p[i][j]);
494         }
495     }
496
497     // coef updates
498     for (i = 0; i < 4; i++) {
499         uint8_t (*ref)[2][6][6][3] = s->prob_ctx[c].coef[i];
500         if (vp8_rac_get(&s->c)) {
501             for (j = 0; j < 2; j++)
502                 for (k = 0; k < 2; k++)
503                     for (l = 0; l < 6; l++)
504                         for (m = 0; m < 6; m++) {
505                             uint8_t *p = s->prob.coef[i][j][k][l][m];
506                             uint8_t *r = ref[j][k][l][m];
507                             if (m >= 3 && l == 0) // dc only has 3 pt
508                                 break;
509                             for (n = 0; n < 3; n++) {
510                                 if (vp56_rac_get_prob_branchy(&s->c, 252))
511                                     p[n] = update_prob(&s->c, r[n]);
512                                 else
513                                     p[n] = r[n];
514                             }
515                             p[3] = 0;
516                         }
517         } else {
518             for (j = 0; j < 2; j++)
519                 for (k = 0; k < 2; k++)
520                     for (l = 0; l < 6; l++)
521                         for (m = 0; m < 6; m++) {
522                             uint8_t *p = s->prob.coef[i][j][k][l][m];
523                             uint8_t *r = ref[j][k][l][m];
524                             if (m > 3 && l == 0) // dc only has 3 pt
525                                 break;
526                             memcpy(p, r, 3);
527                             p[3] = 0;
528                         }
529         }
530         if (s->txfmmode == i)
531             break;
532     }
533
534     // mode updates
535     for (i = 0; i < 3; i++)
536         if (vp56_rac_get_prob_branchy(&s->c, 252))
537             s->prob.p.skip[i] = update_prob(&s->c, s->prob.p.skip[i]);
538     if (!s->keyframe && !s->intraonly) {
539         for (i = 0; i < 7; i++)
540             for (j = 0; j < 3; j++)
541                 if (vp56_rac_get_prob_branchy(&s->c, 252))
542                     s->prob.p.mv_mode[i][j] =
543                         update_prob(&s->c, s->prob.p.mv_mode[i][j]);
544
545         if (s->filtermode == FILTER_SWITCHABLE)
546             for (i = 0; i < 4; i++)
547                 for (j = 0; j < 2; j++)
548                     if (vp56_rac_get_prob_branchy(&s->c, 252))
549                         s->prob.p.filter[i][j] =
550                             update_prob(&s->c, s->prob.p.filter[i][j]);
551
552         for (i = 0; i < 4; i++)
553             if (vp56_rac_get_prob_branchy(&s->c, 252))
554                 s->prob.p.intra[i] = update_prob(&s->c, s->prob.p.intra[i]);
555
556         if (s->allowcompinter) {
557             s->comppredmode = vp8_rac_get(&s->c);
558             if (s->comppredmode)
559                 s->comppredmode += vp8_rac_get(&s->c);
560             if (s->comppredmode == PRED_SWITCHABLE)
561                 for (i = 0; i < 5; i++)
562                     if (vp56_rac_get_prob_branchy(&s->c, 252))
563                         s->prob.p.comp[i] =
564                             update_prob(&s->c, s->prob.p.comp[i]);
565         } else {
566             s->comppredmode = PRED_SINGLEREF;
567         }
568
569         if (s->comppredmode != PRED_COMPREF) {
570             for (i = 0; i < 5; i++) {
571                 if (vp56_rac_get_prob_branchy(&s->c, 252))
572                     s->prob.p.single_ref[i][0] =
573                         update_prob(&s->c, s->prob.p.single_ref[i][0]);
574                 if (vp56_rac_get_prob_branchy(&s->c, 252))
575                     s->prob.p.single_ref[i][1] =
576                         update_prob(&s->c, s->prob.p.single_ref[i][1]);
577             }
578         }
579
580         if (s->comppredmode != PRED_SINGLEREF) {
581             for (i = 0; i < 5; i++)
582                 if (vp56_rac_get_prob_branchy(&s->c, 252))
583                     s->prob.p.comp_ref[i] =
584                         update_prob(&s->c, s->prob.p.comp_ref[i]);
585         }
586
587         for (i = 0; i < 4; i++)
588             for (j = 0; j < 9; j++)
589                 if (vp56_rac_get_prob_branchy(&s->c, 252))
590                     s->prob.p.y_mode[i][j] =
591                         update_prob(&s->c, s->prob.p.y_mode[i][j]);
592
593         for (i = 0; i < 4; i++)
594             for (j = 0; j < 4; j++)
595                 for (k = 0; k < 3; k++)
596                     if (vp56_rac_get_prob_branchy(&s->c, 252))
597                         s->prob.p.partition[3 - i][j][k] =
598                             update_prob(&s->c,
599                                         s->prob.p.partition[3 - i][j][k]);
600
601         // mv fields don't use the update_prob subexp model for some reason
602         for (i = 0; i < 3; i++)
603             if (vp56_rac_get_prob_branchy(&s->c, 252))
604                 s->prob.p.mv_joint[i] = (vp8_rac_get_uint(&s->c, 7) << 1) | 1;
605
606         for (i = 0; i < 2; i++) {
607             if (vp56_rac_get_prob_branchy(&s->c, 252))
608                 s->prob.p.mv_comp[i].sign =
609                     (vp8_rac_get_uint(&s->c, 7) << 1) | 1;
610
611             for (j = 0; j < 10; j++)
612                 if (vp56_rac_get_prob_branchy(&s->c, 252))
613                     s->prob.p.mv_comp[i].classes[j] =
614                         (vp8_rac_get_uint(&s->c, 7) << 1) | 1;
615
616             if (vp56_rac_get_prob_branchy(&s->c, 252))
617                 s->prob.p.mv_comp[i].class0 =
618                     (vp8_rac_get_uint(&s->c, 7) << 1) | 1;
619
620             for (j = 0; j < 10; j++)
621                 if (vp56_rac_get_prob_branchy(&s->c, 252))
622                     s->prob.p.mv_comp[i].bits[j] =
623                         (vp8_rac_get_uint(&s->c, 7) << 1) | 1;
624         }
625
626         for (i = 0; i < 2; i++) {
627             for (j = 0; j < 2; j++)
628                 for (k = 0; k < 3; k++)
629                     if (vp56_rac_get_prob_branchy(&s->c, 252))
630                         s->prob.p.mv_comp[i].class0_fp[j][k] =
631                             (vp8_rac_get_uint(&s->c, 7) << 1) | 1;
632
633             for (j = 0; j < 3; j++)
634                 if (vp56_rac_get_prob_branchy(&s->c, 252))
635                     s->prob.p.mv_comp[i].fp[j] =
636                         (vp8_rac_get_uint(&s->c, 7) << 1) | 1;
637         }
638
639         if (s->highprecisionmvs) {
640             for (i = 0; i < 2; i++) {
641                 if (vp56_rac_get_prob_branchy(&s->c, 252))
642                     s->prob.p.mv_comp[i].class0_hp =
643                         (vp8_rac_get_uint(&s->c, 7) << 1) | 1;
644
645                 if (vp56_rac_get_prob_branchy(&s->c, 252))
646                     s->prob.p.mv_comp[i].hp =
647                         (vp8_rac_get_uint(&s->c, 7) << 1) | 1;
648             }
649         }
650     }
651
652     return (data2 - data) + size2;
653 }
654
655 static int decode_subblock(AVCodecContext *avctx, int row, int col,
656                            VP9Filter *lflvl,
657                            ptrdiff_t yoff, ptrdiff_t uvoff, enum BlockLevel bl)
658 {
659     VP9Context *s = avctx->priv_data;
660     int c = ((s->above_partition_ctx[col]       >> (3 - bl)) & 1) |
661             (((s->left_partition_ctx[row & 0x7] >> (3 - bl)) & 1) << 1);
662     int ret;
663     const uint8_t *p = s->keyframe ? ff_vp9_default_kf_partition_probs[bl][c]
664                                    : s->prob.p.partition[bl][c];
665     enum BlockPartition bp;
666     ptrdiff_t hbs = 4 >> bl;
667
668     if (bl == BL_8X8) {
669         bp  = vp8_rac_get_tree(&s->c, ff_vp9_partition_tree, p);
670         ret = ff_vp9_decode_block(avctx, row, col, lflvl, yoff, uvoff, bl, bp);
671     } else if (col + hbs < s->cols) {
672         if (row + hbs < s->rows) {
673             bp = vp8_rac_get_tree(&s->c, ff_vp9_partition_tree, p);
674             switch (bp) {
675             case PARTITION_NONE:
676                 ret = ff_vp9_decode_block(avctx, row, col, lflvl, yoff, uvoff,
677                                           bl, bp);
678                 break;
679             case PARTITION_H:
680                 ret = ff_vp9_decode_block(avctx, row, col, lflvl, yoff, uvoff,
681                                           bl, bp);
682                 if (!ret) {
683                     yoff  += hbs * 8 * s->cur_frame->linesize[0];
684                     uvoff += hbs * 4 * s->cur_frame->linesize[1];
685                     ret    = ff_vp9_decode_block(avctx, row + hbs, col, lflvl,
686                                                  yoff, uvoff, bl, bp);
687                 }
688                 break;
689             case PARTITION_V:
690                 ret = ff_vp9_decode_block(avctx, row, col, lflvl, yoff, uvoff,
691                                           bl, bp);
692                 if (!ret) {
693                     yoff  += hbs * 8;
694                     uvoff += hbs * 4;
695                     ret    = ff_vp9_decode_block(avctx, row, col + hbs, lflvl,
696                                                  yoff, uvoff, bl, bp);
697                 }
698                 break;
699             case PARTITION_SPLIT:
700                 ret = decode_subblock(avctx, row, col, lflvl,
701                                       yoff, uvoff, bl + 1);
702                 if (!ret) {
703                     ret = decode_subblock(avctx, row, col + hbs, lflvl,
704                                           yoff + 8 * hbs, uvoff + 4 * hbs,
705                                           bl + 1);
706                     if (!ret) {
707                         yoff  += hbs * 8 * s->cur_frame->linesize[0];
708                         uvoff += hbs * 4 * s->cur_frame->linesize[1];
709                         ret    = decode_subblock(avctx, row + hbs, col, lflvl,
710                                                  yoff, uvoff, bl + 1);
711                         if (!ret) {
712                             ret = decode_subblock(avctx, row + hbs, col + hbs,
713                                                   lflvl, yoff + 8 * hbs,
714                                                   uvoff + 4 * hbs, bl + 1);
715                         }
716                     }
717                 }
718                 break;
719             default:
720                 av_log(avctx, AV_LOG_ERROR, "Unexpected partition %d.", bp);
721                 return AVERROR_INVALIDDATA;
722             }
723         } else if (vp56_rac_get_prob_branchy(&s->c, p[1])) {
724             bp  = PARTITION_SPLIT;
725             ret = decode_subblock(avctx, row, col, lflvl, yoff, uvoff, bl + 1);
726             if (!ret)
727                 ret = decode_subblock(avctx, row, col + hbs, lflvl,
728                                       yoff + 8 * hbs, uvoff + 4 * hbs, bl + 1);
729         } else {
730             bp  = PARTITION_H;
731             ret = ff_vp9_decode_block(avctx, row, col, lflvl, yoff, uvoff,
732                                       bl, bp);
733         }
734     } else if (row + hbs < s->rows) {
735         if (vp56_rac_get_prob_branchy(&s->c, p[2])) {
736             bp  = PARTITION_SPLIT;
737             ret = decode_subblock(avctx, row, col, lflvl, yoff, uvoff, bl + 1);
738             if (!ret) {
739                 yoff  += hbs * 8 * s->cur_frame->linesize[0];
740                 uvoff += hbs * 4 * s->cur_frame->linesize[1];
741                 ret    = decode_subblock(avctx, row + hbs, col, lflvl,
742                                          yoff, uvoff, bl + 1);
743             }
744         } else {
745             bp  = PARTITION_V;
746             ret = ff_vp9_decode_block(avctx, row, col, lflvl, yoff, uvoff,
747                                       bl, bp);
748         }
749     } else {
750         bp  = PARTITION_SPLIT;
751         ret = decode_subblock(avctx, row, col, lflvl, yoff, uvoff, bl + 1);
752     }
753     s->counts.partition[bl][c][bp]++;
754
755     return ret;
756 }
757
758 static void loopfilter_subblock(AVCodecContext *avctx, VP9Filter *lflvl,
759                                 int row, int col,
760                                 ptrdiff_t yoff, ptrdiff_t uvoff)
761 {
762     VP9Context *s = avctx->priv_data;
763     uint8_t *dst   = s->cur_frame->data[0] + yoff, *lvl = lflvl->level;
764     ptrdiff_t ls_y = s->cur_frame->linesize[0], ls_uv = s->cur_frame->linesize[1];
765     int y, x, p;
766
767     /* FIXME: In how far can we interleave the v/h loopfilter calls? E.g.
768      * if you think of them as acting on a 8x8 block max, we can interleave
769      * each v/h within the single x loop, but that only works if we work on
770      * 8 pixel blocks, and we won't always do that (we want at least 16px
771      * to use SSE2 optimizations, perhaps 32 for AVX2). */
772
773     // filter edges between columns, Y plane (e.g. block1 | block2)
774     for (y = 0; y < 8; y += 2, dst += 16 * ls_y, lvl += 16) {
775         uint8_t *ptr = dst, *l = lvl, *hmask1 = lflvl->mask[0][0][y];
776         uint8_t *hmask2 = lflvl->mask[0][0][y + 1];
777         unsigned hm1 = hmask1[0] | hmask1[1] | hmask1[2], hm13 = hmask1[3];
778         unsigned hm2 = hmask2[1] | hmask2[2], hm23 = hmask2[3];
779         unsigned hm  = hm1 | hm2 | hm13 | hm23;
780
781         for (x = 1; hm & ~(x - 1); x <<= 1, ptr += 8, l++) {
782             if (hm1 & x) {
783                 int L = *l, H = L >> 4;
784                 int E = s->filter.mblim_lut[L], I = s->filter.lim_lut[L];
785
786                 if (col || x > 1) {
787                     if (hmask1[0] & x) {
788                         if (hmask2[0] & x) {
789                             av_assert2(l[8] == L);
790                             s->dsp.loop_filter_16[0](ptr, ls_y, E, I, H);
791                         } else {
792                             s->dsp.loop_filter_8[2][0](ptr, ls_y, E, I, H);
793                         }
794                     } else if (hm2 & x) {
795                         L  = l[8];
796                         H |= (L >> 4) << 8;
797                         E |= s->filter.mblim_lut[L] << 8;
798                         I |= s->filter.lim_lut[L] << 8;
799                         s->dsp.loop_filter_mix2[!!(hmask1[1] & x)]
800                                                [!!(hmask2[1] & x)]
801                                                [0](ptr, ls_y, E, I, H);
802                     } else {
803                         s->dsp.loop_filter_8[!!(hmask1[1] & x)]
804                                             [0](ptr, ls_y, E, I, H);
805                     }
806                 }
807             } else if (hm2 & x) {
808                 int L = l[8], H = L >> 4;
809                 int E = s->filter.mblim_lut[L], I = s->filter.lim_lut[L];
810
811                 if (col || x > 1) {
812                     s->dsp.loop_filter_8[!!(hmask2[1] & x)]
813                                         [0](ptr + 8 * ls_y, ls_y, E, I, H);
814                 }
815             }
816             if (hm13 & x) {
817                 int L = *l, H = L >> 4;
818                 int E = s->filter.mblim_lut[L], I = s->filter.lim_lut[L];
819
820                 if (hm23 & x) {
821                     L  = l[8];
822                     H |= (L >> 4) << 8;
823                     E |= s->filter.mblim_lut[L] << 8;
824                     I |= s->filter.lim_lut[L] << 8;
825                     s->dsp.loop_filter_mix2[0][0][0](ptr + 4, ls_y, E, I, H);
826                 } else {
827                     s->dsp.loop_filter_8[0][0](ptr + 4, ls_y, E, I, H);
828                 }
829             } else if (hm23 & x) {
830                 int L = l[8], H = L >> 4;
831                 int E = s->filter.mblim_lut[L], I = s->filter.lim_lut[L];
832
833                 s->dsp.loop_filter_8[0][0](ptr + 8 * ls_y + 4, ls_y, E, I, H);
834             }
835         }
836     }
837
838     //                                          block1
839     // filter edges between rows, Y plane (e.g. ------)
840     //                                          block2
841     dst = s->cur_frame->data[0] + yoff;
842     lvl = lflvl->level;
843     for (y = 0; y < 8; y++, dst += 8 * ls_y, lvl += 8) {
844         uint8_t *ptr = dst, *l = lvl, *vmask = lflvl->mask[0][1][y];
845         unsigned vm = vmask[0] | vmask[1] | vmask[2], vm3 = vmask[3];
846
847         for (x = 1; vm & ~(x - 1); x <<= 2, ptr += 16, l += 2) {
848             if (row || y) {
849                 if (vm & x) {
850                     int L = *l, H = L >> 4;
851                     int E = s->filter.mblim_lut[L], I = s->filter.lim_lut[L];
852
853                     if (vmask[0] & x) {
854                         if (vmask[0] & (x << 1)) {
855                             av_assert2(l[1] == L);
856                             s->dsp.loop_filter_16[1](ptr, ls_y, E, I, H);
857                         } else {
858                             s->dsp.loop_filter_8[2][1](ptr, ls_y, E, I, H);
859                         }
860                     } else if (vm & (x << 1)) {
861                         L  = l[1];
862                         H |= (L >> 4) << 8;
863                         E |= s->filter.mblim_lut[L] << 8;
864                         I |= s->filter.lim_lut[L] << 8;
865                         s->dsp.loop_filter_mix2[!!(vmask[1] &  x)]
866                                                [!!(vmask[1] & (x << 1))]
867                                                [1](ptr, ls_y, E, I, H);
868                     } else {
869                         s->dsp.loop_filter_8[!!(vmask[1] & x)]
870                                             [1](ptr, ls_y, E, I, H);
871                     }
872                 } else if (vm & (x << 1)) {
873                     int L = l[1], H = L >> 4;
874                     int E = s->filter.mblim_lut[L], I = s->filter.lim_lut[L];
875
876                     s->dsp.loop_filter_8[!!(vmask[1] & (x << 1))]
877                                         [1](ptr + 8, ls_y, E, I, H);
878                 }
879             }
880             if (vm3 & x) {
881                 int L = *l, H = L >> 4;
882                 int E = s->filter.mblim_lut[L], I = s->filter.lim_lut[L];
883
884                 if (vm3 & (x << 1)) {
885                     L  = l[1];
886                     H |= (L >> 4) << 8;
887                     E |= s->filter.mblim_lut[L] << 8;
888                     I |= s->filter.lim_lut[L] << 8;
889                     s->dsp.loop_filter_mix2[0][0][1](ptr + ls_y * 4, ls_y, E, I, H);
890                 } else {
891                     s->dsp.loop_filter_8[0][1](ptr + ls_y * 4, ls_y, E, I, H);
892                 }
893             } else if (vm3 & (x << 1)) {
894                 int L = l[1], H = L >> 4;
895                 int E = s->filter.mblim_lut[L], I = s->filter.lim_lut[L];
896
897                 s->dsp.loop_filter_8[0][1](ptr + ls_y * 4 + 8, ls_y, E, I, H);
898             }
899         }
900     }
901
902     // same principle but for U/V planes
903     for (p = 0; p < 2; p++) {
904         lvl = lflvl->level;
905         dst = s->cur_frame->data[1 + p] + uvoff;
906         for (y = 0; y < 8; y += 4, dst += 16 * ls_uv, lvl += 32) {
907             uint8_t *ptr = dst, *l = lvl, *hmask1 = lflvl->mask[1][0][y];
908             uint8_t *hmask2 = lflvl->mask[1][0][y + 2];
909             unsigned hm1 = hmask1[0] | hmask1[1] | hmask1[2];
910             unsigned hm2 = hmask2[1] | hmask2[2], hm = hm1 | hm2;
911
912             for (x = 1; hm & ~(x - 1); x <<= 1, ptr += 4) {
913                 if (col || x > 1) {
914                     if (hm1 & x) {
915                         int L = *l, H = L >> 4;
916                         int E = s->filter.mblim_lut[L];
917                         int I = s->filter.lim_lut[L];
918
919                         if (hmask1[0] & x) {
920                             if (hmask2[0] & x) {
921                                 av_assert2(l[16] == L);
922                                 s->dsp.loop_filter_16[0](ptr, ls_uv, E, I, H);
923                             } else {
924                                 s->dsp.loop_filter_8[2][0](ptr, ls_uv, E, I, H);
925                             }
926                         } else if (hm2 & x) {
927                             L  = l[16];
928                             H |= (L >> 4) << 8;
929                             E |= s->filter.mblim_lut[L] << 8;
930                             I |= s->filter.lim_lut[L] << 8;
931                             s->dsp.loop_filter_mix2[!!(hmask1[1] & x)]
932                                                    [!!(hmask2[1] & x)]
933                                                    [0](ptr, ls_uv, E, I, H);
934                         } else {
935                             s->dsp.loop_filter_8[!!(hmask1[1] & x)]
936                                                 [0](ptr, ls_uv, E, I, H);
937                         }
938                     } else if (hm2 & x) {
939                         int L = l[16], H = L >> 4;
940                         int E = s->filter.mblim_lut[L];
941                         int I = s->filter.lim_lut[L];
942
943                         s->dsp.loop_filter_8[!!(hmask2[1] & x)]
944                                             [0](ptr + 8 * ls_uv, ls_uv, E, I, H);
945                     }
946                 }
947                 if (x & 0xAA)
948                     l += 2;
949             }
950         }
951         lvl = lflvl->level;
952         dst = s->cur_frame->data[1 + p] + uvoff;
953         for (y = 0; y < 8; y++, dst += 4 * ls_uv) {
954             uint8_t *ptr = dst, *l = lvl, *vmask = lflvl->mask[1][1][y];
955             unsigned vm = vmask[0] | vmask[1] | vmask[2];
956
957             for (x = 1; vm & ~(x - 1); x <<= 4, ptr += 16, l += 4) {
958                 if (row || y) {
959                     if (vm & x) {
960                         int L = *l, H = L >> 4;
961                         int E = s->filter.mblim_lut[L];
962                         int I = s->filter.lim_lut[L];
963
964                         if (vmask[0] & x) {
965                             if (vmask[0] & (x << 2)) {
966                                 av_assert2(l[2] == L);
967                                 s->dsp.loop_filter_16[1](ptr, ls_uv, E, I, H);
968                             } else {
969                                 s->dsp.loop_filter_8[2][1](ptr, ls_uv, E, I, H);
970                             }
971                         } else if (vm & (x << 2)) {
972                             L  = l[2];
973                             H |= (L >> 4) << 8;
974                             E |= s->filter.mblim_lut[L] << 8;
975                             I |= s->filter.lim_lut[L] << 8;
976                             s->dsp.loop_filter_mix2[!!(vmask[1] &  x)]
977                                                    [!!(vmask[1] & (x << 2))]
978                                                    [1](ptr, ls_uv, E, I, H);
979                         } else {
980                             s->dsp.loop_filter_8[!!(vmask[1] & x)]
981                                                 [1](ptr, ls_uv, E, I, H);
982                         }
983                     } else if (vm & (x << 2)) {
984                         int L = l[2], H = L >> 4;
985                         int E = s->filter.mblim_lut[L];
986                         int I = s->filter.lim_lut[L];
987
988                         s->dsp.loop_filter_8[!!(vmask[1] & (x << 2))]
989                                             [1](ptr + 8, ls_uv, E, I, H);
990                     }
991                 }
992             }
993             if (y & 1)
994                 lvl += 16;
995         }
996     }
997 }
998
999 static void set_tile_offset(int *start, int *end, int idx, int log2_n, int n)
1000 {
1001     int sb_start =  (idx      * n) >> log2_n;
1002     int sb_end   = ((idx + 1) * n) >> log2_n;
1003     *start = FFMIN(sb_start, n) << 3;
1004     *end   = FFMIN(sb_end,   n) << 3;
1005 }
1006
1007 static int vp9_decode_frame(AVCodecContext *avctx, AVFrame *frame,
1008                             int *got_frame, const uint8_t *data, int size)
1009 {
1010     VP9Context *s = avctx->priv_data;
1011     int ret, tile_row, tile_col, i, ref = -1, row, col;
1012     ptrdiff_t yoff = 0, uvoff = 0;
1013
1014     ret = decode_frame_header(avctx, data, size, &ref);
1015     if (ret < 0) {
1016         return ret;
1017     } else if (!ret) {
1018         if (!s->refs[ref]->buf[0]) {
1019             av_log(avctx, AV_LOG_ERROR,
1020                    "Requested reference %d not available\n", ref);
1021             return AVERROR_INVALIDDATA;
1022         }
1023
1024         ret = av_frame_ref(frame, s->refs[ref]);
1025         if (ret < 0)
1026             return ret;
1027         *got_frame = 1;
1028         return 0;
1029     }
1030     data += ret;
1031     size -= ret;
1032
1033     s->cur_frame = frame;
1034
1035     av_frame_unref(s->cur_frame);
1036     if ((ret = ff_get_buffer(avctx, s->cur_frame,
1037                              s->refreshrefmask ? AV_GET_BUFFER_FLAG_REF : 0)) < 0)
1038         return ret;
1039     s->cur_frame->key_frame = s->keyframe;
1040     s->cur_frame->pict_type = s->keyframe ? AV_PICTURE_TYPE_I
1041                                           : AV_PICTURE_TYPE_P;
1042
1043     if (s->fullrange)
1044         avctx->color_range = AVCOL_RANGE_JPEG;
1045     else
1046         avctx->color_range = AVCOL_RANGE_MPEG;
1047
1048     switch (s->colorspace) {
1049     case 1: avctx->colorspace = AVCOL_SPC_BT470BG; break;
1050     case 2: avctx->colorspace = AVCOL_SPC_BT709; break;
1051     case 3: avctx->colorspace = AVCOL_SPC_SMPTE170M; break;
1052     case 4: avctx->colorspace = AVCOL_SPC_SMPTE240M; break;
1053     }
1054
1055     // main tile decode loop
1056     memset(s->above_partition_ctx, 0, s->cols);
1057     memset(s->above_skip_ctx, 0, s->cols);
1058     if (s->keyframe || s->intraonly)
1059         memset(s->above_mode_ctx, DC_PRED, s->cols * 2);
1060     else
1061         memset(s->above_mode_ctx, NEARESTMV, s->cols);
1062     memset(s->above_y_nnz_ctx, 0, s->sb_cols * 16);
1063     memset(s->above_uv_nnz_ctx[0], 0, s->sb_cols * 8);
1064     memset(s->above_uv_nnz_ctx[1], 0, s->sb_cols * 8);
1065     memset(s->above_segpred_ctx, 0, s->cols);
1066     for (tile_row = 0; tile_row < s->tiling.tile_rows; tile_row++) {
1067         set_tile_offset(&s->tiling.tile_row_start, &s->tiling.tile_row_end,
1068                         tile_row, s->tiling.log2_tile_rows, s->sb_rows);
1069         for (tile_col = 0; tile_col < s->tiling.tile_cols; tile_col++) {
1070             int64_t tile_size;
1071
1072             if (tile_col == s->tiling.tile_cols - 1 &&
1073                 tile_row == s->tiling.tile_rows - 1) {
1074                 tile_size = size;
1075             } else {
1076                 tile_size = AV_RB32(data);
1077                 data     += 4;
1078                 size     -= 4;
1079             }
1080             if (tile_size > size)
1081                 return AVERROR_INVALIDDATA;
1082             ff_vp56_init_range_decoder(&s->c_b[tile_col], data, tile_size);
1083             if (vp56_rac_get_prob_branchy(&s->c_b[tile_col], 128)) // marker bit
1084                 return AVERROR_INVALIDDATA;
1085             data += tile_size;
1086             size -= tile_size;
1087         }
1088
1089         for (row = s->tiling.tile_row_start;
1090              row < s->tiling.tile_row_end;
1091              row += 8, yoff += s->cur_frame->linesize[0] * 64,
1092              uvoff += s->cur_frame->linesize[1] * 32) {
1093             VP9Filter *lflvl = s->lflvl;
1094             ptrdiff_t yoff2 = yoff, uvoff2 = uvoff;
1095
1096             for (tile_col = 0; tile_col < s->tiling.tile_cols; tile_col++) {
1097                 set_tile_offset(&s->tiling.tile_col_start,
1098                                 &s->tiling.tile_col_end,
1099                                 tile_col, s->tiling.log2_tile_cols, s->sb_cols);
1100
1101                 memset(s->left_partition_ctx, 0, 8);
1102                 memset(s->left_skip_ctx, 0, 8);
1103                 if (s->keyframe || s->intraonly)
1104                     memset(s->left_mode_ctx, DC_PRED, 16);
1105                 else
1106                     memset(s->left_mode_ctx, NEARESTMV, 8);
1107                 memset(s->left_y_nnz_ctx, 0, 16);
1108                 memset(s->left_uv_nnz_ctx, 0, 16);
1109                 memset(s->left_segpred_ctx, 0, 8);
1110
1111                 memcpy(&s->c, &s->c_b[tile_col], sizeof(s->c));
1112                 for (col = s->tiling.tile_col_start;
1113                      col < s->tiling.tile_col_end;
1114                      col += 8, yoff2 += 64, uvoff2 += 32, lflvl++) {
1115                     // FIXME integrate with lf code (i.e. zero after each
1116                     // use, similar to invtxfm coefficients, or similar)
1117                     memset(lflvl->mask, 0, sizeof(lflvl->mask));
1118
1119                     if ((ret = decode_subblock(avctx, row, col, lflvl,
1120                                                yoff2, uvoff2, BL_64X64)) < 0)
1121                         return ret;
1122                 }
1123                 memcpy(&s->c_b[tile_col], &s->c, sizeof(s->c));
1124             }
1125
1126             // backup pre-loopfilter reconstruction data for intra
1127             // prediction of next row of sb64s
1128             if (row + 8 < s->rows) {
1129                 memcpy(s->intra_pred_data[0],
1130                        s->cur_frame->data[0] + yoff +
1131                        63 * s->cur_frame->linesize[0],
1132                        8 * s->cols);
1133                 memcpy(s->intra_pred_data[1],
1134                        s->cur_frame->data[1] + uvoff +
1135                        31 * s->cur_frame->linesize[1],
1136                        4 * s->cols);
1137                 memcpy(s->intra_pred_data[2],
1138                        s->cur_frame->data[2] + uvoff +
1139                        31 * s->cur_frame->linesize[2],
1140                        4 * s->cols);
1141             }
1142
1143             // loopfilter one row
1144             if (s->filter.level) {
1145                 yoff2  = yoff;
1146                 uvoff2 = uvoff;
1147                 lflvl  = s->lflvl;
1148                 for (col = 0; col < s->cols;
1149                      col += 8, yoff2 += 64, uvoff2 += 32, lflvl++)
1150                     loopfilter_subblock(avctx, lflvl, row, col, yoff2, uvoff2);
1151             }
1152         }
1153     }
1154
1155     // bw adaptivity (or in case of parallel decoding mode, fw adaptivity
1156     // probability maintenance between frames)
1157     if (s->refreshctx) {
1158         if (s->parallelmode) {
1159             int j, k, l, m;
1160             for (i = 0; i < 4; i++) {
1161                 for (j = 0; j < 2; j++)
1162                     for (k = 0; k < 2; k++)
1163                         for (l = 0; l < 6; l++)
1164                             for (m = 0; m < 6; m++)
1165                                 memcpy(s->prob_ctx[s->framectxid].coef[i][j][k][l][m],
1166                                        s->prob.coef[i][j][k][l][m], 3);
1167                 if (s->txfmmode == i)
1168                     break;
1169             }
1170             s->prob_ctx[s->framectxid].p = s->prob.p;
1171         } else {
1172             ff_vp9_adapt_probs(s);
1173         }
1174     }
1175     FFSWAP(VP9MVRefPair *, s->mv[0], s->mv[1]);
1176
1177     // ref frame setup
1178     for (i = 0; i < 8; i++)
1179         if (s->refreshrefmask & (1 << i)) {
1180             av_frame_unref(s->refs[i]);
1181             ret = av_frame_ref(s->refs[i], s->cur_frame);
1182             if (ret < 0)
1183                 return ret;
1184         }
1185
1186     if (s->invisible)
1187         av_frame_unref(s->cur_frame);
1188     else
1189         *got_frame = 1;
1190
1191     return 0;
1192 }
1193
1194 static int vp9_decode_packet(AVCodecContext *avctx, void *frame,
1195                              int *got_frame, AVPacket *avpkt)
1196 {
1197     const uint8_t *data = avpkt->data;
1198     int size            = avpkt->size;
1199     int marker, ret;
1200
1201     /* Read superframe index - this is a collection of individual frames
1202      * that together lead to one visible frame */
1203     marker = data[size - 1];
1204     if ((marker & 0xe0) == 0xc0) {
1205         int nbytes   = 1 + ((marker >> 3) & 0x3);
1206         int n_frames = 1 + (marker & 0x7);
1207         int idx_sz   = 2 + n_frames * nbytes;
1208
1209         if (size >= idx_sz && data[size - idx_sz] == marker) {
1210             const uint8_t *idx = data + size + 1 - idx_sz;
1211
1212             while (n_frames--) {
1213                 unsigned sz = AV_RL32(idx);
1214
1215                 if (nbytes < 4)
1216                     sz &= (1 << (8 * nbytes)) - 1;
1217                 idx += nbytes;
1218
1219                 if (sz > size) {
1220                     av_log(avctx, AV_LOG_ERROR,
1221                            "Superframe packet size too big: %u > %d\n",
1222                            sz, size);
1223                     return AVERROR_INVALIDDATA;
1224                 }
1225
1226                 ret = vp9_decode_frame(avctx, frame, got_frame, data, sz);
1227                 if (ret < 0)
1228                     return ret;
1229                 data += sz;
1230                 size -= sz;
1231             }
1232             return size;
1233         }
1234     }
1235
1236     /* If we get here, there was no valid superframe index, i.e. this is just
1237      * one whole single frame. Decode it as such from the complete input buf. */
1238     if ((ret = vp9_decode_frame(avctx, frame, got_frame, data, size)) < 0)
1239         return ret;
1240     return size;
1241 }
1242
1243 static av_cold int vp9_decode_free(AVCodecContext *avctx)
1244 {
1245     VP9Context *s = avctx->priv_data;
1246     int i;
1247
1248     for (i = 0; i < FF_ARRAY_ELEMS(s->refs); i++)
1249         av_frame_free(&s->refs[i]);
1250
1251     av_freep(&s->c_b);
1252     av_freep(&s->above_partition_ctx);
1253
1254     return 0;
1255 }
1256
1257 static av_cold int vp9_decode_init(AVCodecContext *avctx)
1258 {
1259     VP9Context *s = avctx->priv_data;
1260     int i;
1261
1262     avctx->pix_fmt = AV_PIX_FMT_YUV420P;
1263
1264     ff_vp9dsp_init(&s->dsp);
1265     ff_videodsp_init(&s->vdsp, 8);
1266
1267     for (i = 0; i < FF_ARRAY_ELEMS(s->refs); i++) {
1268         s->refs[i] = av_frame_alloc();
1269         if (!s->refs[i]) {
1270             vp9_decode_free(avctx);
1271             return AVERROR(ENOMEM);
1272         }
1273     }
1274
1275     s->filter.sharpness = -1;
1276
1277     return 0;
1278 }
1279
1280 AVCodec ff_vp9_decoder = {
1281     .name           = "vp9",
1282     .long_name      = NULL_IF_CONFIG_SMALL("Google VP9"),
1283     .type           = AVMEDIA_TYPE_VIDEO,
1284     .id             = AV_CODEC_ID_VP9,
1285     .priv_data_size = sizeof(VP9Context),
1286     .init           = vp9_decode_init,
1287     .decode         = vp9_decode_packet,
1288     .flush          = vp9_decode_flush,
1289     .close          = vp9_decode_free,
1290     .capabilities   = CODEC_CAP_DR1,
1291 };