]> git.sesse.net Git - ffmpeg/blob - libavcodec/agm.c
avcodec/agm: add support for higher compression
[ffmpeg] / libavcodec / agm.c
1 /*
2  * Amuse Graphics Movie decoder
3  *
4  * Copyright (c) 2018 Paul B Mahol
5  *
6  * This file is part of FFmpeg.
7  *
8  * FFmpeg is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * FFmpeg is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with FFmpeg; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21  */
22
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26
27 #define BITSTREAM_READER_LE
28
29 #include "avcodec.h"
30 #include "bytestream.h"
31 #include "copy_block.h"
32 #include "get_bits.h"
33 #include "idctdsp.h"
34 #include "internal.h"
35
36 static const uint8_t unscaled_luma[64] = {
37     16, 11, 10, 16, 24, 40, 51, 61, 12, 12, 14, 19,
38     26, 58, 60, 55, 14, 13, 16, 24, 40, 57, 69, 56,
39     14, 17, 22, 29, 51, 87, 80, 62, 18, 22, 37, 56,
40     68,109,103, 77, 24, 35, 55, 64, 81,104,113, 92,
41     49, 64, 78, 87,103,121,120,101, 72, 92, 95, 98,
42     112,100,103,99
43 };
44
45 static const uint8_t unscaled_chroma[64] = {
46     17, 18, 24, 47, 99, 99, 99, 99, 18, 21, 26, 66,
47     99, 99, 99, 99, 24, 26, 56, 99, 99, 99, 99, 99,
48     47, 66, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99,
49     99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99,
50     99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99,
51     99, 99, 99, 99
52 };
53
54 typedef struct MotionVector {
55     int16_t x, y;
56 } MotionVector;
57
58 typedef struct AGMContext {
59     const AVClass  *class;
60     AVCodecContext *avctx;
61     GetBitContext   gb;
62     GetByteContext  gbyte;
63
64     int key_frame;
65     int bitstream_size;
66     int compression;
67     int blocks_w;
68     int blocks_h;
69     int size[3];
70     int plus;
71     unsigned flags;
72     unsigned fflags;
73
74     uint8_t *output;
75     unsigned output_size;
76
77     MotionVector *mvectors;
78     unsigned      mvectors_size;
79
80     VLC vlc;
81
82     AVFrame *prev_frame;
83
84     int luma_quant_matrix[64];
85     int chroma_quant_matrix[64];
86
87     ScanTable scantable;
88     DECLARE_ALIGNED(32, int16_t, block)[64];
89
90     int16_t *wblocks;
91     unsigned wblocks_size;
92
93     int      *map;
94     unsigned  map_size;
95
96     IDCTDSPContext idsp;
97 } AGMContext;
98
99 static int read_code(GetBitContext *gb, int *oskip, int *level, int *map, int mode)
100 {
101     int len = 0, skip = 0, max;
102
103     if (show_bits(gb, 2)) {
104         switch (show_bits(gb, 4)) {
105         case 1:
106         case 9:
107             len = 1;
108             skip = 3;
109             break;
110         case 2:
111             len = 3;
112             skip = 4;
113             break;
114         case 3:
115             len = 7;
116             skip = 4;
117             break;
118         case 5:
119         case 13:
120             len = 2;
121             skip = 3;
122             break;
123         case 6:
124             len = 4;
125             skip = 4;
126             break;
127         case 7:
128             len = 8;
129             skip = 4;
130             break;
131         case 10:
132             len = 5;
133             skip = 4;
134             break;
135         case 11:
136             len = 9;
137             skip = 4;
138             break;
139         case 14:
140             len = 6;
141             skip = 4;
142             break;
143         case 15:
144             len = ((show_bits(gb, 5) & 0x10) | 0xA0) >> 4;
145             skip = 5;
146             break;
147         default:
148             return AVERROR_INVALIDDATA;
149         }
150
151         skip_bits(gb, skip);
152         *level = get_bits(gb, len);
153         *map = 1;
154         *oskip = 0;
155         max = 1 << (len - 1);
156         if (*level < max)
157             *level = -(max + *level);
158     } else if (show_bits(gb, 3) & 4) {
159         skip_bits(gb, 3);
160         if (mode == 1) {
161             if (show_bits(gb, 4)) {
162                 if (show_bits(gb, 4) == 1) {
163                     skip_bits(gb, 4);
164                     *oskip = get_bits(gb, 16);
165                 } else {
166                     *oskip = get_bits(gb, 4);
167                 }
168             } else {
169                 skip_bits(gb, 4);
170                 *oskip = get_bits(gb, 10);
171             }
172         } else if (mode == 0) {
173             *oskip = get_bits(gb, 10);
174         }
175         *level = 0;
176     } else {
177         skip_bits(gb, 3);
178         if (mode == 0)
179             *oskip = get_bits(gb, 4);
180         else if (mode == 1)
181             *oskip = 0;
182         *level = 0;
183     }
184
185     return 0;
186 }
187
188 static int decode_intra_blocks(AGMContext *s, GetBitContext *gb,
189                                const int *quant_matrix, int *skip, int *dc_level)
190 {
191     const uint8_t *scantable = s->scantable.permutated;
192     int level, ret, map = 0;
193
194     memset(s->wblocks, 0, s->wblocks_size);
195
196     for (int i = 0; i < 64; i++) {
197         int16_t *block = s->wblocks + scantable[i];
198
199         for (int j = 0; j < s->blocks_w;) {
200             if (*skip > 0) {
201                 int rskip;
202
203                 rskip = FFMIN(*skip, s->blocks_w - j);
204                 j += rskip;
205                 if (i == 0) {
206                     for (int k = 0; k < rskip; k++)
207                         block[64 * k] = *dc_level * quant_matrix[0];
208                 }
209                 block += rskip * 64;
210                 *skip -= rskip;
211             } else {
212                 ret = read_code(gb, skip, &level, &map, s->flags & 1);
213                 if (ret < 0)
214                     return ret;
215
216                 if (i == 0)
217                     *dc_level += level;
218
219                 block[0] = (i == 0 ? *dc_level : level) * quant_matrix[i];
220                 block += 64;
221                 j++;
222             }
223         }
224     }
225
226     return 0;
227 }
228
229 static int decode_inter_blocks(AGMContext *s, GetBitContext *gb,
230                                const int *quant_matrix, int *skip,
231                                int *map)
232 {
233     const uint8_t *scantable = s->scantable.permutated;
234     int level, ret;
235
236     memset(s->wblocks, 0, s->wblocks_size);
237     memset(s->map, 0, s->map_size);
238
239     for (int i = 0; i < 64; i++) {
240         int16_t *block = s->wblocks + scantable[i];
241
242         for (int j = 0; j < s->blocks_w;) {
243             if (*skip > 0) {
244                 int rskip;
245
246                 rskip = FFMIN(*skip, s->blocks_w - j);
247                 j += rskip;
248                 block += rskip * 64;
249                 *skip -= rskip;
250             } else {
251                 ret = read_code(gb, skip, &level, &map[j], s->flags & 1);
252                 if (ret < 0)
253                     return ret;
254
255                 block[0] = level * quant_matrix[i];
256                 block += 64;
257                 j++;
258             }
259         }
260     }
261
262     return 0;
263 }
264
265 static int decode_intra_block(AGMContext *s, GetBitContext *gb,
266                               const int *quant_matrix, int *skip, int *dc_level)
267 {
268     const uint8_t *scantable = s->scantable.permutated;
269     const int offset = s->plus ? 0 : 1024;
270     int16_t *block = s->block;
271     int level, ret, map = 0;
272
273     memset(block, 0, sizeof(s->block));
274
275     if (*skip > 0) {
276         (*skip)--;
277     } else {
278         ret = read_code(gb, skip, &level, &map, s->flags & 1);
279         if (ret < 0)
280             return ret;
281         *dc_level += level;
282     }
283     block[scantable[0]] = offset + *dc_level * quant_matrix[0];
284
285     for (int i = 1; i < 64;) {
286         if (*skip > 0) {
287             int rskip;
288
289             rskip = FFMIN(*skip, 64 - i);
290             i += rskip;
291             *skip -= rskip;
292         } else {
293             ret = read_code(gb, skip, &level, &map, s->flags & 1);
294             if (ret < 0)
295                 return ret;
296
297             block[scantable[i]] = level * quant_matrix[i];
298             i++;
299         }
300     }
301
302     return 0;
303 }
304
305 static int decode_intra_plane(AGMContext *s, GetBitContext *gb, int size,
306                               const int *quant_matrix, AVFrame *frame,
307                               int plane)
308 {
309     int ret, skip = 0, dc_level = 0;
310     const int offset = s->plus ? 0 : 1024;
311
312     if ((ret = init_get_bits8(gb, s->gbyte.buffer, size)) < 0)
313         return ret;
314
315     if (s->flags & 1) {
316         av_fast_padded_malloc(&s->wblocks, &s->wblocks_size,
317                               64 * s->blocks_w * sizeof(*s->wblocks));
318         if (!s->wblocks)
319             return AVERROR(ENOMEM);
320
321         for (int y = 0; y < s->blocks_h; y++) {
322             ret = decode_intra_blocks(s, gb, quant_matrix, &skip, &dc_level);
323             if (ret < 0)
324                 return ret;
325
326             for (int x = 0; x < s->blocks_w; x++) {
327                 s->wblocks[64 * x] += offset;
328                 s->idsp.idct_put(frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
329                                  frame->linesize[plane], s->wblocks + 64 * x);
330             }
331         }
332     } else {
333         for (int y = 0; y < s->blocks_h; y++) {
334             for (int x = 0; x < s->blocks_w; x++) {
335                 ret = decode_intra_block(s, gb, quant_matrix, &skip, &dc_level);
336                 if (ret < 0)
337                     return ret;
338
339                 s->idsp.idct_put(frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
340                                  frame->linesize[plane], s->block);
341             }
342         }
343     }
344
345     align_get_bits(gb);
346     if (get_bits_left(gb) < 0)
347         av_log(s->avctx, AV_LOG_WARNING, "overread\n");
348     if (get_bits_left(gb) > 0)
349         av_log(s->avctx, AV_LOG_WARNING, "underread: %d\n", get_bits_left(gb));
350
351     return 0;
352 }
353
354 static int decode_inter_block(AGMContext *s, GetBitContext *gb,
355                               const int *quant_matrix, int *skip,
356                               int *map)
357 {
358     const uint8_t *scantable = s->scantable.permutated;
359     int16_t *block = s->block;
360     int level, ret;
361
362     memset(block, 0, sizeof(s->block));
363
364     for (int i = 0; i < 64;) {
365         if (*skip > 0) {
366             int rskip;
367
368             rskip = FFMIN(*skip, 64 - i);
369             i += rskip;
370             *skip -= rskip;
371         } else {
372             ret = read_code(gb, skip, &level, map, s->flags & 1);
373             if (ret < 0)
374                 return ret;
375
376             block[scantable[i]] = level * quant_matrix[i];
377             i++;
378         }
379     }
380
381     return 0;
382 }
383
384 static int decode_inter_plane(AGMContext *s, GetBitContext *gb, int size,
385                               const int *quant_matrix, AVFrame *frame,
386                               AVFrame *prev, int plane)
387 {
388     int ret, skip = 0;
389
390     if ((ret = init_get_bits8(gb, s->gbyte.buffer, size)) < 0)
391         return ret;
392
393     if (s->flags == 3) {
394         av_fast_padded_malloc(&s->wblocks, &s->wblocks_size,
395                               64 * s->blocks_w * sizeof(*s->wblocks));
396         if (!s->wblocks)
397             return AVERROR(ENOMEM);
398
399         av_fast_padded_malloc(&s->map, &s->map_size,
400                               s->blocks_w * sizeof(*s->map));
401         if (!s->map)
402             return AVERROR(ENOMEM);
403
404         for (int y = 0; y < s->blocks_h; y++) {
405             ret = decode_inter_blocks(s, gb, quant_matrix, &skip, s->map);
406             if (ret < 0)
407                 return ret;
408
409             for (int x = 0; x < s->blocks_w; x++) {
410                 int shift = plane == 0;
411                 int mvpos = (y >> shift) * (s->blocks_w >> shift) + (x >> shift);
412                 int orig_mv_x = s->mvectors[mvpos].x;
413                 int mv_x = s->mvectors[mvpos].x / (1 + !shift);
414                 int mv_y = s->mvectors[mvpos].y / (1 + !shift);
415                 int h = s->avctx->coded_height >> !shift;
416                 int w = s->avctx->coded_width  >> !shift;
417                 int map = s->map[x];
418
419                 if (orig_mv_x >= -32) {
420                     if (y * 8 + mv_y < 0 || y * 8 + mv_y >= h ||
421                         x * 8 + mv_x < 0 || x * 8 + mv_x >= w)
422                         return AVERROR_INVALIDDATA;
423
424                     copy_block8(frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
425                                 prev->data[plane] + ((s->blocks_h - 1 - y) * 8 - mv_y) * prev->linesize[plane] + (x * 8 + mv_x),
426                                 frame->linesize[plane], prev->linesize[plane], 8);
427                     if (map) {
428                         s->idsp.idct(s->wblocks + x * 64);
429                         for (int i = 0; i < 64; i++)
430                             s->wblocks[i + x * 64] = (s->wblocks[i + x * 64] + 1) & 0xFFFC;
431                         s->idsp.add_pixels_clamped(&s->wblocks[x*64], frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
432                                                    frame->linesize[plane]);
433                     }
434                 } else if (map) {
435                     s->idsp.idct_put(frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
436                                      frame->linesize[plane], s->wblocks + x * 64);
437                 }
438             }
439         }
440     } else if (s->flags & 2) {
441         for (int y = 0; y < s->blocks_h; y++) {
442             for (int x = 0; x < s->blocks_w; x++) {
443                 int shift = plane == 0;
444                 int mvpos = (y >> shift) * (s->blocks_w >> shift) + (x >> shift);
445                 int orig_mv_x = s->mvectors[mvpos].x;
446                 int mv_x = s->mvectors[mvpos].x / (1 + !shift);
447                 int mv_y = s->mvectors[mvpos].y / (1 + !shift);
448                 int h = s->avctx->coded_height >> !shift;
449                 int w = s->avctx->coded_width  >> !shift;
450                 int map = 0;
451
452                 ret = decode_inter_block(s, gb, quant_matrix, &skip, &map);
453                 if (ret < 0)
454                     return ret;
455
456                 if (orig_mv_x >= -32) {
457                     if (y * 8 + mv_y < 0 || y * 8 + mv_y >= h ||
458                         x * 8 + mv_x < 0 || x * 8 + mv_x >= w)
459                         return AVERROR_INVALIDDATA;
460
461                     copy_block8(frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
462                                 prev->data[plane] + ((s->blocks_h - 1 - y) * 8 - mv_y) * prev->linesize[plane] + (x * 8 + mv_x),
463                                 frame->linesize[plane], prev->linesize[plane], 8);
464                     if (map) {
465                         s->idsp.idct(s->block);
466                         for (int i = 0; i < 64; i++)
467                             s->block[i] = (s->block[i] + 1) & 0xFFFC;
468                         s->idsp.add_pixels_clamped(s->block, frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
469                                                    frame->linesize[plane]);
470                     }
471                 } else if (map) {
472                     s->idsp.idct_put(frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
473                                      frame->linesize[plane], s->block);
474                 }
475             }
476         }
477     } else if (s->flags & 1) {
478         av_fast_padded_malloc(&s->wblocks, &s->wblocks_size,
479                               64 * s->blocks_w * sizeof(*s->wblocks));
480         if (!s->wblocks)
481             return AVERROR(ENOMEM);
482
483         av_fast_padded_malloc(&s->map, &s->map_size,
484                               s->blocks_w * sizeof(*s->map));
485         if (!s->map)
486             return AVERROR(ENOMEM);
487
488         for (int y = 0; y < s->blocks_h; y++) {
489             ret = decode_inter_blocks(s, gb, quant_matrix, &skip, s->map);
490             if (ret < 0)
491                 return ret;
492
493             for (int x = 0; x < s->blocks_w; x++) {
494                 if (!s->map[x])
495                     continue;
496                 s->idsp.idct_add(frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
497                                  frame->linesize[plane], s->wblocks + 64 * x);
498             }
499         }
500     } else {
501         for (int y = 0; y < s->blocks_h; y++) {
502             for (int x = 0; x < s->blocks_w; x++) {
503                 int map = 0;
504
505                 ret = decode_inter_block(s, gb, quant_matrix, &skip, &map);
506                 if (ret < 0)
507                     return ret;
508
509                 if (!map)
510                     continue;
511                 s->idsp.idct_add(frame->data[plane] + (s->blocks_h - 1 - y) * 8 * frame->linesize[plane] + x * 8,
512                                  frame->linesize[plane], s->block);
513             }
514         }
515     }
516
517     align_get_bits(gb);
518     if (get_bits_left(gb) < 0)
519         av_log(s->avctx, AV_LOG_WARNING, "overread\n");
520     if (get_bits_left(gb) > 0)
521         av_log(s->avctx, AV_LOG_WARNING, "underread: %d\n", get_bits_left(gb));
522
523     return 0;
524 }
525
526 static void compute_quant_matrix(AGMContext *s, double qscale)
527 {
528     int luma[64], chroma[64];
529     double f = 1.0 - fabs(qscale);
530
531     if (!s->key_frame && (s->flags & 2)) {
532         if (qscale >= 0.0) {
533             for (int i = 0; i < 64; i++) {
534                 luma[i]   = FFMAX(1, 16 * f);
535                 chroma[i] = FFMAX(1, 16 * f);
536             }
537         } else {
538             for (int i = 0; i < 64; i++) {
539                 luma[i]   = FFMAX(1, 16 - qscale * 32);
540                 chroma[i] = FFMAX(1, 16 - qscale * 32);
541             }
542         }
543     } else {
544         if (qscale >= 0.0) {
545             for (int i = 0; i < 64; i++) {
546                 luma[i]   = FFMAX(1, unscaled_luma  [(i & 7) * 8 + (i >> 3)] * f);
547                 chroma[i] = FFMAX(1, unscaled_chroma[(i & 7) * 8 + (i >> 3)] * f);
548             }
549         } else {
550             for (int i = 0; i < 64; i++) {
551                 luma[i]   = FFMAX(1, 255.0 - (255 - unscaled_luma  [(i & 7) * 8 + (i >> 3)]) * f);
552                 chroma[i] = FFMAX(1, 255.0 - (255 - unscaled_chroma[(i & 7) * 8 + (i >> 3)]) * f);
553             }
554         }
555     }
556
557     for (int i = 0; i < 64; i++) {
558         int pos = ff_zigzag_direct[i];
559
560         s->luma_quant_matrix[i]   = luma[pos]   * ((pos / 8) & 1 ? -1 : 1);
561         s->chroma_quant_matrix[i] = chroma[pos] * ((pos / 8) & 1 ? -1 : 1);
562     }
563 }
564
565 static int decode_intra(AVCodecContext *avctx, GetBitContext *gb, AVFrame *frame)
566 {
567     AGMContext *s = avctx->priv_data;
568     int ret;
569
570     compute_quant_matrix(s, (2 * s->compression - 100) / 100.0);
571
572     s->blocks_w = avctx->coded_width  >> 3;
573     s->blocks_h = avctx->coded_height >> 3;
574
575     ret = decode_intra_plane(s, gb, s->size[0], s->luma_quant_matrix, frame, 0);
576     if (ret < 0)
577         return ret;
578
579     bytestream2_skip(&s->gbyte, s->size[0]);
580
581     s->blocks_w = avctx->coded_width  >> 4;
582     s->blocks_h = avctx->coded_height >> 4;
583
584     ret = decode_intra_plane(s, gb, s->size[1], s->chroma_quant_matrix, frame, 2);
585     if (ret < 0)
586         return ret;
587
588     bytestream2_skip(&s->gbyte, s->size[1]);
589
590     s->blocks_w = avctx->coded_width  >> 4;
591     s->blocks_h = avctx->coded_height >> 4;
592
593     ret = decode_intra_plane(s, gb, s->size[2], s->chroma_quant_matrix, frame, 1);
594     if (ret < 0)
595         return ret;
596
597     return 0;
598 }
599
600 static int decode_motion_vectors(AVCodecContext *avctx, GetBitContext *gb)
601 {
602     AGMContext *s = avctx->priv_data;
603     int nb_mvs = ((avctx->height + 15) >> 4) * ((avctx->width + 15) >> 4);
604     int ret, skip = 0, value, map;
605
606     av_fast_padded_malloc(&s->mvectors, &s->mvectors_size,
607                           nb_mvs * sizeof(*s->mvectors));
608     if (!s->mvectors)
609         return AVERROR(ENOMEM);
610
611     if ((ret = init_get_bits8(gb, s->gbyte.buffer, bytestream2_get_bytes_left(&s->gbyte) -
612                                                    (s->size[0] + s->size[1] + s->size[2]))) < 0)
613         return ret;
614
615     memset(s->mvectors, 0, sizeof(*s->mvectors) * nb_mvs);
616
617     for (int i = 0; i < nb_mvs; i++) {
618         ret = read_code(gb, &skip, &value, &map, 1);
619         if (ret < 0)
620             return ret;
621         s->mvectors[i].x = value;
622         i += skip;
623     }
624
625     for (int i = 0; i < nb_mvs; i++) {
626         ret = read_code(gb, &skip, &value, &map, 1);
627         if (ret < 0)
628             return ret;
629         s->mvectors[i].y = value;
630         i += skip;
631     }
632
633     if (get_bits_left(gb) <= 0)
634         return AVERROR_INVALIDDATA;
635     skip = (get_bits_count(gb) >> 3) + 1;
636     bytestream2_skip(&s->gbyte, skip);
637
638     return 0;
639 }
640
641 static int decode_inter(AVCodecContext *avctx, GetBitContext *gb,
642                         AVFrame *frame, AVFrame *prev)
643 {
644     AGMContext *s = avctx->priv_data;
645     int ret;
646
647     compute_quant_matrix(s, (2 * s->compression - 100) / 100.0);
648
649     if (s->flags & 2) {
650         ret = decode_motion_vectors(avctx, gb);
651         if (ret < 0)
652             return ret;
653     }
654
655     s->blocks_w = avctx->coded_width  >> 3;
656     s->blocks_h = avctx->coded_height >> 3;
657
658     ret = decode_inter_plane(s, gb, s->size[0], s->luma_quant_matrix, frame, prev, 0);
659     if (ret < 0)
660         return ret;
661
662     bytestream2_skip(&s->gbyte, s->size[0]);
663
664     s->blocks_w = avctx->coded_width  >> 4;
665     s->blocks_h = avctx->coded_height >> 4;
666
667     ret = decode_inter_plane(s, gb, s->size[1], s->chroma_quant_matrix, frame, prev, 2);
668     if (ret < 0)
669         return ret;
670
671     bytestream2_skip(&s->gbyte, s->size[1]);
672
673     s->blocks_w = avctx->coded_width  >> 4;
674     s->blocks_h = avctx->coded_height >> 4;
675
676     ret = decode_inter_plane(s, gb, s->size[2], s->chroma_quant_matrix, frame, prev, 1);
677     if (ret < 0)
678         return ret;
679
680     return 0;
681 }
682
683 typedef struct Node {
684     int parent;
685     int child[2];
686 } Node;
687
688 static void get_tree_codes(uint32_t *codes, Node *nodes, int idx, uint32_t pfx, int bitpos)
689 {
690     if (idx < 256 && idx >= 0) {
691         codes[idx] = pfx;
692     } else {
693         get_tree_codes(codes, nodes, nodes[idx].child[0], pfx + (0 << bitpos), bitpos + 1);
694         get_tree_codes(codes, nodes, nodes[idx].child[1], pfx + (1 << bitpos), bitpos + 1);
695     }
696 }
697
698 static void make_new_tree(const uint8_t *bitlens, uint32_t *codes)
699 {
700     int zlcount = 0, curlen, idx, nindex, last, llast;
701     int blcounts[32] = { 0 };
702     int syms[8192];
703     Node nodes[512];
704     int node_idx[1024];
705     int old_idx[512];
706
707     for (int i = 0; i < 256; i++) {
708         int bitlen = bitlens[i];
709         int blcount = blcounts[bitlen];
710
711         zlcount += bitlen < 1;
712         syms[(bitlen << 8) + blcount] = i;
713         blcounts[bitlen]++;
714     }
715
716     for (int i = 0; i < 512; i++) {
717         nodes[i].child[0] = -1;
718         nodes[i].child[1] = -1;
719     }
720
721     for (int i = 0; i < 256; i++) {
722         node_idx[i] = 257 + i;;
723     }
724
725     curlen = 1;
726     node_idx[512] = 256;
727     last = 255;
728     nindex = 1;
729
730     for (curlen = 1; curlen < 32; curlen++) {
731         if (blcounts[curlen] > 0) {
732             int max_zlcount = zlcount + blcounts[curlen];
733
734             for (int i = 0; zlcount < 256 && zlcount < max_zlcount; zlcount++, i++) {
735                 int p = node_idx[nindex - 1 + 512];
736                 int ch = syms[256 * curlen + i];
737
738                 if (nodes[p].child[0] == -1) {
739                     nodes[p].child[0] = ch;
740                 } else {
741                     nodes[p].child[1] = ch;
742                     nindex--;
743                 }
744                 nodes[ch].parent = p;
745             }
746         }
747         llast = last - 1;
748         idx = 0;
749         while (nindex > 0) {
750             int p, ch;
751
752             last = llast - idx;
753             p = node_idx[nindex - 1 + 512];
754             ch = node_idx[last];
755             if (nodes[p].child[0] == -1) {
756                 nodes[p].child[0] = ch;
757             } else {
758                 nodes[p].child[1] = ch;
759                 nindex--;
760             }
761             old_idx[idx] = ch;
762             nodes[ch].parent = p;
763             if (idx == llast)
764                 goto next;
765             idx++;
766             if (nindex <= 0) {
767                 for (int i = 0; i < idx; i++)
768                     node_idx[512 + i] = old_idx[i];
769             }
770         }
771         nindex = idx;
772     }
773
774 next:
775
776     get_tree_codes(codes, nodes, 256, 0, 0);
777 }
778
779 static int build_huff(const uint8_t *bitlen, VLC *vlc)
780 {
781     uint32_t new_codes[256];
782     uint8_t bits[256];
783     uint8_t symbols[256];
784     uint32_t codes[256];
785     int nb_codes = 0;
786
787     make_new_tree(bitlen, new_codes);
788
789     for (int i = 0; i < 256; i++) {
790         if (bitlen[i]) {
791             bits[nb_codes] = bitlen[i];
792             codes[nb_codes] = new_codes[i];
793             symbols[nb_codes] = i;
794             nb_codes++;
795         }
796     }
797
798     ff_free_vlc(vlc);
799     return ff_init_vlc_sparse(vlc, 13, nb_codes,
800                               bits, 1, 1,
801                               codes, 4, 4,
802                               symbols, 1, 1,
803                               INIT_VLC_LE);
804 }
805
806 static int decode_huffman2(AVCodecContext *avctx, int header, int size)
807 {
808     AGMContext *s = avctx->priv_data;
809     GetBitContext *gb = &s->gb;
810     uint8_t lens[256];
811     unsigned output_size;
812     int ret, x, len;
813
814     if ((ret = init_get_bits8(gb, s->gbyte.buffer,
815                               bytestream2_get_bytes_left(&s->gbyte))) < 0)
816         return ret;
817
818     output_size = get_bits_long(gb, 32);
819
820     av_fast_padded_malloc(&s->output, &s->output_size, output_size);
821     if (!s->output)
822         return AVERROR(ENOMEM);
823
824     x = get_bits(gb, 1);
825     len = 4 + get_bits(gb, 1);
826     if (x) {
827         int cb[8] = { 0 };
828         int count = get_bits(gb, 3) + 1;
829
830         for (int i = 0; i < count; i++)
831             cb[i] = get_bits(gb, len);
832
833         for (int i = 0; i < 256; i++) {
834             int idx = get_bits(gb, 3);
835             lens[i] = cb[idx];
836         }
837     } else {
838         for (int i = 0; i < 256; i++)
839             lens[i] = get_bits(gb, len);
840     }
841
842     if ((ret = build_huff(lens, &s->vlc)) < 0)
843         return ret;
844
845     x = 0;
846     while (get_bits_left(gb) > 0 && x < output_size) {
847         int val = get_vlc2(gb, s->vlc.table, s->vlc.bits, 3);
848         if (val < 0)
849             return AVERROR_INVALIDDATA;
850         s->output[x++] = val;
851     }
852
853     return 0;
854 }
855
856 static int decode_frame(AVCodecContext *avctx, void *data,
857                         int *got_frame, AVPacket *avpkt)
858 {
859     AGMContext *s = avctx->priv_data;
860     GetBitContext *gb = &s->gb;
861     GetByteContext *gbyte = &s->gbyte;
862     AVFrame *frame = data;
863     int w, h, width, height, header;
864     unsigned compressed_size;
865     int ret;
866
867     if (!avpkt->size)
868         return 0;
869
870     bytestream2_init(gbyte, avpkt->data, avpkt->size);
871
872     header = bytestream2_get_le32(gbyte);
873     s->fflags = bytestream2_get_le32(gbyte);
874     s->bitstream_size = s->fflags & 0x1FFFFFFF;
875     s->fflags >>= 29;
876     av_log(avctx, AV_LOG_DEBUG, "fflags: %X\n", s->fflags);
877     if (avpkt->size < s->bitstream_size + 8)
878         return AVERROR_INVALIDDATA;
879
880     s->key_frame = avpkt->flags & AV_PKT_FLAG_KEY;
881     frame->key_frame = s->key_frame;
882     frame->pict_type = s->key_frame ? AV_PICTURE_TYPE_I : AV_PICTURE_TYPE_P;
883
884     if (header) {
885         if (avctx->codec_tag == MKTAG('A', 'G', 'M', '0') ||
886             avctx->codec_tag == MKTAG('A', 'G', 'M', '1'))
887             return AVERROR_PATCHWELCOME;
888         else
889             ret = decode_huffman2(avctx, header, (avpkt->size - s->bitstream_size) - 8);
890         if (ret < 0)
891             return ret;
892         bytestream2_init(gbyte, s->output, s->output_size);
893     }
894
895     s->flags = 0;
896     w = bytestream2_get_le32(gbyte);
897     h = bytestream2_get_le32(gbyte);
898     if (w == INT32_MIN || h == INT32_MIN)
899         return AVERROR_INVALIDDATA;
900     if (w < 0) {
901         w = -w;
902         s->flags |= 2;
903     }
904     if (h < 0) {
905         h = -h;
906         s->flags |= 1;
907     }
908
909     width  = avctx->width;
910     height = avctx->height;
911     if (w < width || h < height || w & 7 || h & 7)
912         return AVERROR_INVALIDDATA;
913
914     ret = ff_set_dimensions(avctx, w, h);
915     if (ret < 0)
916         return ret;
917     avctx->width = width;
918     avctx->height = height;
919
920     s->compression = bytestream2_get_le32(gbyte);
921     if (s->compression < 0 || s->compression > 100)
922         return AVERROR_INVALIDDATA;
923
924     for (int i = 0; i < 3; i++)
925         s->size[i] = bytestream2_get_le32(gbyte);
926     if (header)
927         compressed_size = s->output_size;
928     else
929         compressed_size = avpkt->size;
930     if (s->size[0] < 0 || s->size[1] < 0 || s->size[2] < 0 ||
931         32LL + s->size[0] + s->size[1] + s->size[2] > compressed_size) {
932         return AVERROR_INVALIDDATA;
933     }
934
935     if ((ret = ff_get_buffer(avctx, frame, AV_GET_BUFFER_FLAG_REF)) < 0)
936         return ret;
937
938     if (frame->key_frame) {
939         ret = decode_intra(avctx, gb, frame);
940     } else {
941         if (!s->prev_frame->data[0]) {
942             av_log(avctx, AV_LOG_ERROR, "Missing reference frame.\n");
943             return AVERROR_INVALIDDATA;
944         }
945
946         if (!(s->flags & 2)) {
947             ret = av_frame_copy(frame, s->prev_frame);
948             if (ret < 0)
949                 return ret;
950         }
951
952         ret = decode_inter(avctx, gb, frame, s->prev_frame);
953     }
954     if (ret < 0)
955         return ret;
956
957     av_frame_unref(s->prev_frame);
958     if ((ret = av_frame_ref(s->prev_frame, frame)) < 0)
959         return ret;
960
961     frame->crop_top  = avctx->coded_height - avctx->height;
962     frame->crop_left = avctx->coded_width  - avctx->width;
963
964     *got_frame = 1;
965
966     return avpkt->size;
967 }
968
969 static av_cold int decode_init(AVCodecContext *avctx)
970 {
971     AGMContext *s = avctx->priv_data;
972
973     avctx->pix_fmt = AV_PIX_FMT_YUV420P;
974     s->avctx = avctx;
975     s->plus = avctx->codec_tag == MKTAG('A', 'G', 'M', '3') ||
976               avctx->codec_tag == MKTAG('A', 'G', 'M', '7');
977
978     avctx->idct_algo = FF_IDCT_SIMPLE;
979     ff_idctdsp_init(&s->idsp, avctx);
980     ff_init_scantable(s->idsp.idct_permutation, &s->scantable, ff_zigzag_direct);
981
982     s->prev_frame = av_frame_alloc();
983     if (!s->prev_frame)
984         return AVERROR(ENOMEM);
985
986     return 0;
987 }
988
989 static void decode_flush(AVCodecContext *avctx)
990 {
991     AGMContext *s = avctx->priv_data;
992
993     av_frame_unref(s->prev_frame);
994 }
995
996 static av_cold int decode_close(AVCodecContext *avctx)
997 {
998     AGMContext *s = avctx->priv_data;
999
1000     ff_free_vlc(&s->vlc);
1001     av_frame_free(&s->prev_frame);
1002     av_freep(&s->mvectors);
1003     s->mvectors_size = 0;
1004     av_freep(&s->wblocks);
1005     s->wblocks_size = 0;
1006     av_freep(&s->output);
1007     s->output_size = 0;
1008     av_freep(&s->map);
1009     s->map_size = 0;
1010
1011     return 0;
1012 }
1013
1014 AVCodec ff_agm_decoder = {
1015     .name             = "agm",
1016     .long_name        = NULL_IF_CONFIG_SMALL("Amuse Graphics Movie"),
1017     .type             = AVMEDIA_TYPE_VIDEO,
1018     .id               = AV_CODEC_ID_AGM,
1019     .priv_data_size   = sizeof(AGMContext),
1020     .init             = decode_init,
1021     .close            = decode_close,
1022     .decode           = decode_frame,
1023     .flush            = decode_flush,
1024     .capabilities     = AV_CODEC_CAP_DR1,
1025     .caps_internal    = FF_CODEC_CAP_INIT_THREADSAFE |
1026                         FF_CODEC_CAP_INIT_CLEANUP |
1027                         FF_CODEC_CAP_EXPORTS_CROPPING,
1028 };