]> git.sesse.net Git - ffmpeg/blob - libavcodec/mss3.c
Merge remote-tracking branch 'qatar/master'
[ffmpeg] / libavcodec / mss3.c
1 /*
2  * Microsoft Screen 3 (aka Microsoft ATC Screen) decoder
3  * Copyright (c) 2012 Konstantin Shishkov
4  *
5  * This file is part of FFmpeg.
6  *
7  * FFmpeg is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * FFmpeg is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with FFmpeg; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  */
21
22 /**
23  * @file
24  * Microsoft Screen 3 (aka Microsoft ATC Screen) decoder
25  */
26
27 #include "avcodec.h"
28 #include "bytestream.h"
29
30 #define HEADER_SIZE 27
31
32 #define MODEL2_SCALE       13
33 #define MODEL_SCALE        15
34 #define MODEL256_SEC_SCALE  9
35
36 typedef struct Model2 {
37     int      upd_val, till_rescale;
38     unsigned zero_freq,  zero_weight;
39     unsigned total_freq, total_weight;
40 } Model2;
41
42 typedef struct Model {
43     int weights[16], freqs[16];
44     int num_syms;
45     int tot_weight;
46     int upd_val, max_upd_val, till_rescale;
47 } Model;
48
49 typedef struct Model256 {
50     int weights[256], freqs[256];
51     int tot_weight;
52     int secondary[68];
53     int sec_size;
54     int upd_val, max_upd_val, till_rescale;
55 } Model256;
56
57 #define RAC_BOTTOM 0x01000000
58 typedef struct RangeCoder {
59     const uint8_t *src, *src_end;
60
61     uint32_t range, low;
62     int got_error;
63 } RangeCoder;
64
65 enum BlockType {
66     FILL_BLOCK = 0,
67     IMAGE_BLOCK,
68     DCT_BLOCK,
69     HAAR_BLOCK,
70     SKIP_BLOCK
71 };
72
73 typedef struct BlockTypeContext {
74     int      last_type;
75     Model    bt_model[5];
76 } BlockTypeContext;
77
78 typedef struct FillBlockCoder {
79     int      fill_val;
80     Model    coef_model;
81 } FillBlockCoder;
82
83 typedef struct ImageBlockCoder {
84     Model256 esc_model, vec_entry_model;
85     Model    vec_size_model;
86     Model    vq_model[125];
87 } ImageBlockCoder;
88
89 typedef struct DCTBlockCoder {
90     int      *prev_dc;
91     int      prev_dc_stride;
92     int      prev_dc_height;
93     int      quality;
94     uint16_t qmat[64];
95     Model    dc_model;
96     Model2   sign_model;
97     Model256 ac_model;
98 } DCTBlockCoder;
99
100 typedef struct HaarBlockCoder {
101     int      quality, scale;
102     Model256 coef_model;
103     Model    coef_hi_model;
104 } HaarBlockCoder;
105
106 typedef struct MSS3Context {
107     AVCodecContext   *avctx;
108     AVFrame          pic;
109
110     int              got_error;
111     RangeCoder       coder;
112     BlockTypeContext btype[3];
113     FillBlockCoder   fill_coder[3];
114     ImageBlockCoder  image_coder[3];
115     DCTBlockCoder    dct_coder[3];
116     HaarBlockCoder   haar_coder[3];
117
118     int              dctblock[64];
119     int              hblock[16 * 16];
120 } MSS3Context;
121
122 static const uint8_t mss3_luma_quant[64] = {
123     16,  11,  10,  16,  24,  40,  51,  61,
124     12,  12,  14,  19,  26,  58,  60,  55,
125     14,  13,  16,  24,  40,  57,  69,  56,
126     14,  17,  22,  29,  51,  87,  80,  62,
127     18,  22,  37,  56,  68, 109, 103,  77,
128     24,  35,  55,  64,  81, 104, 113,  92,
129     49,  64,  78,  87, 103, 121, 120, 101,
130     72,  92,  95,  98, 112, 100, 103,  99
131 };
132
133 static const uint8_t mss3_chroma_quant[64] = {
134     17, 18, 24, 47, 99, 99, 99, 99,
135     18, 21, 26, 66, 99, 99, 99, 99,
136     24, 26, 56, 99, 99, 99, 99, 99,
137     47, 66, 99, 99, 99, 99, 99, 99,
138     99, 99, 99, 99, 99, 99, 99, 99,
139     99, 99, 99, 99, 99, 99, 99, 99,
140     99, 99, 99, 99, 99, 99, 99, 99,
141     99, 99, 99, 99, 99, 99, 99, 99
142 };
143
144 const uint8_t zigzag_scan[64] = {
145     0,   1,  8, 16,  9,  2,  3, 10,
146     17, 24, 32, 25, 18, 11,  4,  5,
147     12, 19, 26, 33, 40, 48, 41, 34,
148     27, 20, 13,  6,  7, 14, 21, 28,
149     35, 42, 49, 56, 57, 50, 43, 36,
150     29, 22, 15, 23, 30, 37, 44, 51,
151     58, 59, 52, 45, 38, 31, 39, 46,
152     53, 60, 61, 54, 47, 55, 62, 63
153 };
154
155
156 static void model2_reset(Model2 *m)
157 {
158     m->zero_weight  = 1;
159     m->total_weight = 2;
160     m->zero_freq    = 0x1000;
161     m->total_freq   = 0x2000;
162     m->upd_val      = 4;
163     m->till_rescale = 4;
164 }
165
166 static void model2_update(Model2 *m, int bit)
167 {
168     unsigned scale;
169
170     if (!bit)
171         m->zero_weight++;
172     m->till_rescale--;
173     if (m->till_rescale)
174         return;
175
176     m->total_weight += m->upd_val;
177     if (m->total_weight > 0x2000) {
178         m->total_weight = (m->total_weight + 1) >> 1;
179         m->zero_weight  = (m->zero_weight  + 1) >> 1;
180         if (m->total_weight == m->zero_weight)
181             m->total_weight = m->zero_weight + 1;
182     }
183     m->upd_val = m->upd_val * 5 >> 2;
184     if (m->upd_val > 64)
185         m->upd_val = 64;
186     scale = 0x80000000u / m->total_weight;
187     m->zero_freq    = m->zero_weight  * scale >> 18;
188     m->total_freq   = m->total_weight * scale >> 18;
189     m->till_rescale = m->upd_val;
190 }
191
192 static void model_update(Model *m, int val)
193 {
194     int i, sum = 0;
195     unsigned scale;
196
197     m->weights[val]++;
198     m->till_rescale--;
199     if (m->till_rescale)
200         return;
201     m->tot_weight += m->upd_val;
202
203     if (m->tot_weight > 0x8000) {
204         m->tot_weight = 0;
205         for (i = 0; i < m->num_syms; i++) {
206             m->weights[i]  = (m->weights[i] + 1) >> 1;
207             m->tot_weight +=  m->weights[i];
208         }
209     }
210     scale = 0x80000000u / m->tot_weight;
211     for (i = 0; i < m->num_syms; i++) {
212         m->freqs[i] = sum * scale >> 16;
213         sum += m->weights[i];
214     }
215
216     m->upd_val = m->upd_val * 5 >> 2;
217     if (m->upd_val > m->max_upd_val)
218         m->upd_val = m->max_upd_val;
219     m->till_rescale = m->upd_val;
220 }
221
222 static void model_reset(Model *m)
223 {
224     int i;
225
226     m->tot_weight   = 0;
227     for (i = 0; i < m->num_syms - 1; i++)
228         m->weights[i] = 1;
229     m->weights[m->num_syms - 1] = 0;
230
231     m->upd_val      = m->num_syms;
232     m->till_rescale = 1;
233     model_update(m, m->num_syms - 1);
234     m->till_rescale =
235     m->upd_val      = (m->num_syms + 6) >> 1;
236 }
237
238 static av_cold void model_init(Model *m, int num_syms)
239 {
240     m->num_syms    = num_syms;
241     m->max_upd_val = 8 * num_syms + 48;
242
243     model_reset(m);
244 }
245
246 static void model256_update(Model256 *m, int val)
247 {
248     int i, sum = 0;
249     unsigned scale;
250     int send, sidx = 1;
251
252     m->weights[val]++;
253     m->till_rescale--;
254     if (m->till_rescale)
255         return;
256     m->tot_weight += m->upd_val;
257
258     if (m->tot_weight > 0x8000) {
259         m->tot_weight = 0;
260         for (i = 0; i < 256; i++) {
261             m->weights[i]  = (m->weights[i] + 1) >> 1;
262             m->tot_weight +=  m->weights[i];
263         }
264     }
265     scale = 0x80000000u / m->tot_weight;
266     m->secondary[0] = 0;
267     for (i = 0; i < 256; i++) {
268         m->freqs[i] = sum * scale >> 16;
269         sum += m->weights[i];
270         send = m->freqs[i] >> MODEL256_SEC_SCALE;
271         while (sidx <= send)
272             m->secondary[sidx++] = i - 1;
273     }
274     while (sidx < m->sec_size)
275         m->secondary[sidx++] = 255;
276
277     m->upd_val = m->upd_val * 5 >> 2;
278     if (m->upd_val > m->max_upd_val)
279         m->upd_val = m->max_upd_val;
280     m->till_rescale = m->upd_val;
281 }
282
283 static void model256_reset(Model256 *m)
284 {
285     int i;
286
287     for (i = 0; i < 255; i++)
288         m->weights[i] = 1;
289     m->weights[255] = 0;
290
291     m->tot_weight   = 0;
292     m->upd_val      = 256;
293     m->till_rescale = 1;
294     model256_update(m, 255);
295     m->till_rescale =
296     m->upd_val      = (256 + 6) >> 1;
297 }
298
299 static av_cold void model256_init(Model256 *m)
300 {
301     m->max_upd_val = 8 * 256 + 48;
302     m->sec_size    = (1 << 6) + 2;
303
304     model256_reset(m);
305 }
306
307 static void rac_init(RangeCoder *c, const uint8_t *src, int size)
308 {
309     int i;
310
311     c->src       = src;
312     c->src_end   = src + size;
313     c->low       = 0;
314     for (i = 0; i < FFMIN(size, 4); i++)
315         c->low = (c->low << 8) | *c->src++;
316     c->range     = 0xFFFFFFFF;
317     c->got_error = 0;
318 }
319
320 static void rac_normalise(RangeCoder *c)
321 {
322     for (;;) {
323         c->range <<= 8;
324         c->low   <<= 8;
325         if (c->src < c->src_end) {
326             c->low |= *c->src++;
327         } else if (!c->low) {
328             c->got_error = 1;
329             return;
330         }
331         if (c->range >= RAC_BOTTOM)
332             return;
333     }
334 }
335
336 static int rac_get_bit(RangeCoder *c)
337 {
338     int bit;
339
340     c->range >>= 1;
341
342     bit = (c->range <= c->low);
343     if (bit)
344         c->low -= c->range;
345
346     if (c->range < RAC_BOTTOM)
347         rac_normalise(c);
348
349     return bit;
350 }
351
352 static int rac_get_bits(RangeCoder *c, int nbits)
353 {
354     int val;
355
356     c->range >>= nbits;
357     val = c->low / c->range;
358     c->low -= c->range * val;
359
360     if (c->range < RAC_BOTTOM)
361         rac_normalise(c);
362
363     return val;
364 }
365
366 static int rac_get_model2_sym(RangeCoder *c, Model2 *m)
367 {
368     int bit, helper;
369
370     helper = m->zero_freq * (c->range >> MODEL2_SCALE);
371     bit    = (c->low >= helper);
372     if (bit) {
373         c->low   -= helper;
374         c->range -= helper;
375     } else {
376         c->range  = helper;
377     }
378
379     if (c->range < RAC_BOTTOM)
380         rac_normalise(c);
381
382     model2_update(m, bit);
383
384     return bit;
385 }
386
387 static int rac_get_model_sym(RangeCoder *c, Model *m)
388 {
389     int prob, prob2, helper, val;
390     int end, end2;
391
392     prob       = 0;
393     prob2      = c->range;
394     c->range >>= MODEL_SCALE;
395     val        = 0;
396     end        = m->num_syms >> 1;
397     end2       = m->num_syms;
398     do {
399         helper = m->freqs[end] * c->range;
400         if (helper <= c->low) {
401             val   = end;
402             prob  = helper;
403         } else {
404             end2  = end;
405             prob2 = helper;
406         }
407         end = (end2 + val) >> 1;
408     } while (end != val);
409     c->low  -= prob;
410     c->range = prob2 - prob;
411     if (c->range < RAC_BOTTOM)
412         rac_normalise(c);
413
414     model_update(m, val);
415
416     return val;
417 }
418
419 static int rac_get_model256_sym(RangeCoder *c, Model256 *m)
420 {
421     int prob, prob2, helper, val;
422     int start, end;
423     int ssym;
424
425     prob2      = c->range;
426     c->range >>= MODEL_SCALE;
427
428     helper     = c->low / c->range;
429     ssym       = helper >> MODEL256_SEC_SCALE;
430     val        = m->secondary[ssym];
431
432     end = start = m->secondary[ssym + 1] + 1;
433     while (end > val + 1) {
434         ssym = (end + val) >> 1;
435         if (m->freqs[ssym] <= helper) {
436             end = start;
437             val = ssym;
438         } else {
439             end   = (end + val) >> 1;
440             start = ssym;
441         }
442     }
443     prob = m->freqs[val] * c->range;
444     if (val != 255)
445         prob2 = m->freqs[val + 1] * c->range;
446
447     c->low  -= prob;
448     c->range = prob2 - prob;
449     if (c->range < RAC_BOTTOM)
450         rac_normalise(c);
451
452     model256_update(m, val);
453
454     return val;
455 }
456
457 static int decode_block_type(RangeCoder *c, BlockTypeContext *bt)
458 {
459     bt->last_type = rac_get_model_sym(c, &bt->bt_model[bt->last_type]);
460
461     return bt->last_type;
462 }
463
464 static int decode_coeff(RangeCoder *c, Model *m)
465 {
466     int val, sign;
467
468     val = rac_get_model_sym(c, m);
469     if (val) {
470         sign = rac_get_bit(c);
471         if (val > 1) {
472             val--;
473             val = (1 << val) + rac_get_bits(c, val);
474         }
475         if (!sign)
476             val = -val;
477     }
478
479     return val;
480 }
481
482 static void decode_fill_block(RangeCoder *c, FillBlockCoder *fc,
483                               uint8_t *dst, int stride, int block_size)
484 {
485     int i;
486
487     fc->fill_val += decode_coeff(c, &fc->coef_model);
488
489     for (i = 0; i < block_size; i++, dst += stride)
490         memset(dst, fc->fill_val, block_size);
491 }
492
493 static void decode_image_block(RangeCoder *c, ImageBlockCoder *ic,
494                                uint8_t *dst, int stride, int block_size)
495 {
496     int i, j;
497     int vec_size;
498     int vec[4];
499     int prev_line[16];
500     int A, B, C;
501
502     vec_size = rac_get_model_sym(c, &ic->vec_size_model) + 2;
503     for (i = 0; i < vec_size; i++)
504         vec[i] = rac_get_model256_sym(c, &ic->vec_entry_model);
505     for (; i < 4; i++)
506         vec[i] = 0;
507     memset(prev_line, 0, sizeof(prev_line));
508
509     for (j = 0; j < block_size; j++) {
510         A = 0;
511         B = 0;
512         for (i = 0; i < block_size; i++) {
513             C = B;
514             B = prev_line[i];
515             A = rac_get_model_sym(c, &ic->vq_model[A + B * 5 + C * 25]);
516
517             prev_line[i] = A;
518             if (A < 4)
519                dst[i] = vec[A];
520             else
521                dst[i] = rac_get_model256_sym(c, &ic->esc_model);
522         }
523         dst += stride;
524     }
525 }
526
527 static int decode_dct(RangeCoder *c, DCTBlockCoder *bc, int *block,
528                       int bx, int by)
529 {
530     int skip, val, sign, pos = 1, zz_pos, dc;
531     int blk_pos = bx + by * bc->prev_dc_stride;
532
533     memset(block, 0, sizeof(*block) * 64);
534
535     dc = decode_coeff(c, &bc->dc_model);
536     if (by) {
537         if (bx) {
538             int l, tl, t;
539
540             l  = bc->prev_dc[blk_pos - 1];
541             tl = bc->prev_dc[blk_pos - 1 - bc->prev_dc_stride];
542             t  = bc->prev_dc[blk_pos     - bc->prev_dc_stride];
543
544             if (FFABS(t - tl) <= FFABS(l - tl))
545                 dc += l;
546             else
547                 dc += t;
548         } else {
549             dc += bc->prev_dc[blk_pos - bc->prev_dc_stride];
550         }
551     } else if (bx) {
552         dc += bc->prev_dc[bx - 1];
553     }
554     bc->prev_dc[blk_pos] = dc;
555     block[0]             = dc * bc->qmat[0];
556
557     while (pos < 64) {
558         val = rac_get_model256_sym(c, &bc->ac_model);
559         if (!val)
560             return 0;
561         if (val == 0xF0) {
562             pos += 16;
563             continue;
564         }
565         skip = val >> 4;
566         val  = val & 0xF;
567         if (!val)
568             return -1;
569         pos += skip;
570         if (pos >= 64)
571             return -1;
572
573         sign = rac_get_model2_sym(c, &bc->sign_model);
574         if (val > 1) {
575             val--;
576             val = (1 << val) + rac_get_bits(c, val);
577         }
578         if (!sign)
579             val = -val;
580
581         zz_pos = zigzag_scan[pos];
582         block[zz_pos] = val * bc->qmat[zz_pos];
583         pos++;
584     }
585
586     return pos == 64 ? 0 : -1;
587 }
588
589 #define DCT_TEMPLATE(blk, step, SOP, shift)                         \
590     const int t0 = -39409 * blk[7 * step] -  58980 * blk[1 * step]; \
591     const int t1 =  39410 * blk[1 * step] -  58980 * blk[7 * step]; \
592     const int t2 = -33410 * blk[5 * step] - 167963 * blk[3 * step]; \
593     const int t3 =  33410 * blk[3 * step] - 167963 * blk[5 * step]; \
594     const int t4 =          blk[3 * step] +          blk[7 * step]; \
595     const int t5 =          blk[1 * step] +          blk[5 * step]; \
596     const int t6 =  77062 * t4            +  51491 * t5;            \
597     const int t7 =  77062 * t5            -  51491 * t4;            \
598     const int t8 =  35470 * blk[2 * step] -  85623 * blk[6 * step]; \
599     const int t9 =  35470 * blk[6 * step] +  85623 * blk[2 * step]; \
600     const int tA = SOP(blk[0 * step] - blk[4 * step]);              \
601     const int tB = SOP(blk[0 * step] + blk[4 * step]);              \
602                                                                     \
603     blk[0 * step] = (  t1 + t6  + t9 + tB) >> shift;                \
604     blk[1 * step] = (  t3 + t7  + t8 + tA) >> shift;                \
605     blk[2 * step] = (  t2 + t6  - t8 + tA) >> shift;                \
606     blk[3 * step] = (  t0 + t7  - t9 + tB) >> shift;                \
607     blk[4 * step] = (-(t0 + t7) - t9 + tB) >> shift;                \
608     blk[5 * step] = (-(t2 + t6) - t8 + tA) >> shift;                \
609     blk[6 * step] = (-(t3 + t7) + t8 + tA) >> shift;                \
610     blk[7 * step] = (-(t1 + t6) + t9 + tB) >> shift;                \
611
612 #define SOP_ROW(a) ((a) << 16) + 0x2000
613 #define SOP_COL(a) ((a + 32) << 16)
614
615 static void dct_put(uint8_t *dst, int stride, int *block)
616 {
617     int i, j;
618     int *ptr;
619
620     ptr = block;
621     for (i = 0; i < 8; i++) {
622         DCT_TEMPLATE(ptr, 1, SOP_ROW, 13);
623         ptr += 8;
624     }
625
626     ptr = block;
627     for (i = 0; i < 8; i++) {
628         DCT_TEMPLATE(ptr, 8, SOP_COL, 22);
629         ptr++;
630     }
631
632     ptr = block;
633     for (j = 0; j < 8; j++) {
634         for (i = 0; i < 8; i++)
635             dst[i] = av_clip_uint8(ptr[i] + 128);
636         dst += stride;
637         ptr += 8;
638     }
639 }
640
641 static void decode_dct_block(RangeCoder *c, DCTBlockCoder *bc,
642                              uint8_t *dst, int stride, int block_size,
643                              int *block, int mb_x, int mb_y)
644 {
645     int i, j;
646     int bx, by;
647     int nblocks = block_size >> 3;
648
649     bx = mb_x * nblocks;
650     by = mb_y * nblocks;
651
652     for (j = 0; j < nblocks; j++) {
653         for (i = 0; i < nblocks; i++) {
654             if (decode_dct(c, bc, block, bx + i, by + j)) {
655                 c->got_error = 1;
656                 return;
657             }
658             dct_put(dst + i * 8, stride, block);
659         }
660         dst += 8 * stride;
661     }
662 }
663
664 static void decode_haar_block(RangeCoder *c, HaarBlockCoder *hc,
665                               uint8_t *dst, int stride, int block_size,
666                               int *block)
667 {
668     const int hsize = block_size >> 1;
669     int A, B, C, D, t1, t2, t3, t4;
670     int i, j;
671
672     for (j = 0; j < block_size; j++) {
673         for (i = 0; i < block_size; i++) {
674             if (i < hsize && j < hsize)
675                 block[i] = rac_get_model256_sym(c, &hc->coef_model);
676             else
677                 block[i] = decode_coeff(c, &hc->coef_hi_model);
678             block[i] *= hc->scale;
679         }
680         block += block_size;
681     }
682     block -= block_size * block_size;
683
684     for (j = 0; j < hsize; j++) {
685         for (i = 0; i < hsize; i++) {
686             A = block[i];
687             B = block[i + hsize];
688             C = block[i + hsize * block_size];
689             D = block[i + hsize * block_size + hsize];
690
691             t1 = A - B;
692             t2 = C - D;
693             t3 = A + B;
694             t4 = C + D;
695             dst[i * 2]              = av_clip_uint8(t1 - t2);
696             dst[i * 2 + stride]     = av_clip_uint8(t1 + t2);
697             dst[i * 2 + 1]          = av_clip_uint8(t3 - t4);
698             dst[i * 2 + 1 + stride] = av_clip_uint8(t3 + t4);
699         }
700         block += block_size;
701         dst   += stride * 2;
702     }
703 }
704
705 static void gen_quant_mat(uint16_t *qmat, const uint8_t *ref, float scale)
706 {
707     int i;
708
709     for (i = 0; i < 64; i++)
710         qmat[i] = (uint16_t)(ref[i] * scale + 50.0) / 100;
711 }
712
713 static void reset_coders(MSS3Context *ctx, int quality)
714 {
715     int i, j;
716
717     for (i = 0; i < 3; i++) {
718         ctx->btype[i].last_type = SKIP_BLOCK;
719         for (j = 0; j < 5; j++)
720             model_reset(&ctx->btype[i].bt_model[j]);
721         ctx->fill_coder[i].fill_val = 0;
722         model_reset(&ctx->fill_coder[i].coef_model);
723         model256_reset(&ctx->image_coder[i].esc_model);
724         model256_reset(&ctx->image_coder[i].vec_entry_model);
725         model_reset(&ctx->image_coder[i].vec_size_model);
726         for (j = 0; j < 125; j++)
727             model_reset(&ctx->image_coder[i].vq_model[j]);
728         if (ctx->dct_coder[i].quality != quality) {
729             float scale;
730             ctx->dct_coder[i].quality = quality;
731             if (quality > 50)
732                 scale = 200.0f - 2 * quality;
733             else
734                 scale = 5000.0f / quality;
735             gen_quant_mat(ctx->dct_coder[i].qmat,
736                           i ? mss3_chroma_quant : mss3_luma_quant,
737                           scale);
738         }
739         memset(ctx->dct_coder[i].prev_dc, 0,
740                sizeof(*ctx->dct_coder[i].prev_dc) *
741                ctx->dct_coder[i].prev_dc_stride *
742                ctx->dct_coder[i].prev_dc_height);
743         model_reset(&ctx->dct_coder[i].dc_model);
744         model2_reset(&ctx->dct_coder[i].sign_model);
745         model256_reset(&ctx->dct_coder[i].ac_model);
746         if (ctx->haar_coder[i].quality != quality) {
747             ctx->haar_coder[i].quality = quality;
748             ctx->haar_coder[i].scale   = 17 - 7 * quality / 50;
749         }
750         model_reset(&ctx->haar_coder[i].coef_hi_model);
751         model256_reset(&ctx->haar_coder[i].coef_model);
752     }
753 }
754
755 static av_cold void init_coders(MSS3Context *ctx)
756 {
757     int i, j;
758
759     for (i = 0; i < 3; i++) {
760         for (j = 0; j < 5; j++)
761             model_init(&ctx->btype[i].bt_model[j], 5);
762         model_init(&ctx->fill_coder[i].coef_model, 12);
763         model256_init(&ctx->image_coder[i].esc_model);
764         model256_init(&ctx->image_coder[i].vec_entry_model);
765         model_init(&ctx->image_coder[i].vec_size_model, 3);
766         for (j = 0; j < 125; j++)
767             model_init(&ctx->image_coder[i].vq_model[j], 5);
768         model_init(&ctx->dct_coder[i].dc_model, 12);
769         model256_init(&ctx->dct_coder[i].ac_model);
770         model_init(&ctx->haar_coder[i].coef_hi_model, 12);
771         model256_init(&ctx->haar_coder[i].coef_model);
772     }
773 }
774
775 static int mss3_decode_frame(AVCodecContext *avctx, void *data, int *data_size,
776                              AVPacket *avpkt)
777 {
778     const uint8_t *buf = avpkt->data;
779     int buf_size = avpkt->size;
780     MSS3Context *c = avctx->priv_data;
781     RangeCoder *acoder = &c->coder;
782     GetByteContext gb;
783     uint8_t *dst[3];
784     int dec_width, dec_height, dec_x, dec_y, quality, keyframe;
785     int x, y, i, mb_width, mb_height, blk_size, btype;
786     int ret;
787
788     if (buf_size < HEADER_SIZE) {
789         av_log(avctx, AV_LOG_ERROR,
790                "Frame should have at least %d bytes, got %d instead\n",
791                HEADER_SIZE, buf_size);
792         return AVERROR_INVALIDDATA;
793     }
794
795     bytestream2_init(&gb, buf, buf_size);
796     keyframe   = bytestream2_get_be32(&gb);
797     if (keyframe & ~0x301) {
798         av_log(avctx, AV_LOG_ERROR, "Invalid frame type %X\n", keyframe);
799         return AVERROR_INVALIDDATA;
800     }
801     keyframe   = !(keyframe & 1);
802     bytestream2_skip(&gb, 6);
803     dec_x      = bytestream2_get_be16(&gb);
804     dec_y      = bytestream2_get_be16(&gb);
805     dec_width  = bytestream2_get_be16(&gb);
806     dec_height = bytestream2_get_be16(&gb);
807
808     if (dec_x + dec_width > avctx->width ||
809         dec_y + dec_height > avctx->height ||
810         (dec_width | dec_height) & 0xF) {
811         av_log(avctx, AV_LOG_ERROR, "Invalid frame dimensions %dx%d +%d,%d\n",
812                dec_width, dec_height, dec_x, dec_y);
813         return AVERROR_INVALIDDATA;
814     }
815     bytestream2_skip(&gb, 4);
816     quality    = bytestream2_get_byte(&gb);
817     if (quality < 1 || quality > 100) {
818         av_log(avctx, AV_LOG_ERROR, "Invalid quality setting %d\n", quality);
819         return AVERROR_INVALIDDATA;
820     }
821     bytestream2_skip(&gb, 4);
822
823     if (keyframe && !bytestream2_get_bytes_left(&gb)) {
824         av_log(avctx, AV_LOG_ERROR, "Keyframe without data found\n");
825         return AVERROR_INVALIDDATA;
826     }
827     if (!keyframe && c->got_error)
828         return buf_size;
829     c->got_error = 0;
830
831     c->pic.reference    = 3;
832     c->pic.buffer_hints = FF_BUFFER_HINTS_VALID | FF_BUFFER_HINTS_PRESERVE |
833                           FF_BUFFER_HINTS_REUSABLE;
834     if ((ret = avctx->reget_buffer(avctx, &c->pic)) < 0) {
835         av_log(avctx, AV_LOG_ERROR, "reget_buffer() failed\n");
836         return ret;
837     }
838     c->pic.key_frame = keyframe;
839     c->pic.pict_type = keyframe ? AV_PICTURE_TYPE_I : AV_PICTURE_TYPE_P;
840     if (!bytestream2_get_bytes_left(&gb)) {
841         *data_size = sizeof(AVFrame);
842         *(AVFrame*)data = c->pic;
843
844         return buf_size;
845     }
846
847     reset_coders(c, quality);
848
849     rac_init(acoder, buf + HEADER_SIZE, buf_size - HEADER_SIZE);
850
851     mb_width  = dec_width  >> 4;
852     mb_height = dec_height >> 4;
853     dst[0] = c->pic.data[0] + dec_x     +  dec_y      * c->pic.linesize[0];
854     dst[1] = c->pic.data[1] + dec_x / 2 + (dec_y / 2) * c->pic.linesize[1];
855     dst[2] = c->pic.data[2] + dec_x / 2 + (dec_y / 2) * c->pic.linesize[2];
856     for (y = 0; y < mb_height; y++) {
857         for (x = 0; x < mb_width; x++) {
858             for (i = 0; i < 3; i++) {
859                 blk_size = 8 << !i;
860
861                 btype = decode_block_type(acoder, c->btype + i);
862                 switch (btype) {
863                 case FILL_BLOCK:
864                     decode_fill_block(acoder, c->fill_coder + i,
865                                       dst[i] + x * blk_size,
866                                       c->pic.linesize[i], blk_size);
867                     break;
868                 case IMAGE_BLOCK:
869                     decode_image_block(acoder, c->image_coder + i,
870                                        dst[i] + x * blk_size,
871                                        c->pic.linesize[i], blk_size);
872                     break;
873                 case DCT_BLOCK:
874                     decode_dct_block(acoder, c->dct_coder + i,
875                                      dst[i] + x * blk_size,
876                                      c->pic.linesize[i], blk_size,
877                                      c->dctblock, x, y);
878                     break;
879                 case HAAR_BLOCK:
880                     decode_haar_block(acoder, c->haar_coder + i,
881                                       dst[i] + x * blk_size,
882                                       c->pic.linesize[i], blk_size,
883                                       c->hblock);
884                     break;
885                 }
886                 if (c->got_error || acoder->got_error) {
887                     av_log(avctx, AV_LOG_ERROR, "Error decoding block %d,%d\n",
888                            x, y);
889                     c->got_error = 1;
890                     return AVERROR_INVALIDDATA;
891                 }
892             }
893         }
894         dst[0] += c->pic.linesize[0] * 16;
895         dst[1] += c->pic.linesize[1] * 8;
896         dst[2] += c->pic.linesize[2] * 8;
897     }
898
899     *data_size = sizeof(AVFrame);
900     *(AVFrame*)data = c->pic;
901
902     return buf_size;
903 }
904
905 static av_cold int mss3_decode_init(AVCodecContext *avctx)
906 {
907     MSS3Context * const c = avctx->priv_data;
908     int i;
909
910     c->avctx = avctx;
911
912     if ((avctx->width & 0xF) || (avctx->height & 0xF)) {
913         av_log(avctx, AV_LOG_ERROR,
914                "Image dimensions should be a multiple of 16.\n");
915         return AVERROR_INVALIDDATA;
916     }
917
918     c->got_error = 0;
919     for (i = 0; i < 3; i++) {
920         int b_width  = avctx->width  >> (2 + !!i);
921         int b_height = avctx->height >> (2 + !!i);
922         c->dct_coder[i].prev_dc_stride = b_width;
923         c->dct_coder[i].prev_dc_height = b_height;
924         c->dct_coder[i].prev_dc = av_malloc(sizeof(*c->dct_coder[i].prev_dc) *
925                                             b_width * b_height);
926         if (!c->dct_coder[i].prev_dc) {
927             av_log(avctx, AV_LOG_ERROR, "Cannot allocate buffer\n");
928             while (i >= 0) {
929                 av_freep(&c->dct_coder[i].prev_dc);
930                 i--;
931             }
932             return AVERROR(ENOMEM);
933         }
934     }
935
936     avctx->pix_fmt     = PIX_FMT_YUV420P;
937     avctx->coded_frame = &c->pic;
938
939     init_coders(c);
940
941     return 0;
942 }
943
944 static av_cold int mss3_decode_end(AVCodecContext *avctx)
945 {
946     MSS3Context * const c = avctx->priv_data;
947     int i;
948
949     if (c->pic.data[0])
950         avctx->release_buffer(avctx, &c->pic);
951     for (i = 0; i < 3; i++)
952         av_freep(&c->dct_coder[i].prev_dc);
953
954     return 0;
955 }
956
957 AVCodec ff_msa1_decoder = {
958     .name           = "msa1",
959     .type           = AVMEDIA_TYPE_VIDEO,
960     .id             = CODEC_ID_MSA1,
961     .priv_data_size = sizeof(MSS3Context),
962     .init           = mss3_decode_init,
963     .close          = mss3_decode_end,
964     .decode         = mss3_decode_frame,
965     .capabilities   = CODEC_CAP_DR1,
966     .long_name      = NULL_IF_CONFIG_SMALL("MS ATC Screen"),
967 };