]> git.sesse.net Git - ffmpeg/blob - libavfilter/dnn/dnn_io_proc.c
dnn: change dnn interface to replace DNNData* with AVFrame*
[ffmpeg] / libavfilter / dnn / dnn_io_proc.c
1 /*
2  * Copyright (c) 2020
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21 #include "dnn_io_proc.h"
22 #include "libavutil/imgutils.h"
23 #include "libswscale/swscale.h"
24
25 DNNReturnType proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
26 {
27     struct SwsContext *sws_ctx;
28     int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
29     if (output->dt != DNN_FLOAT) {
30         av_log(log_ctx, AV_LOG_ERROR, "do not support data type rather than DNN_FLOAT\n");
31         return DNN_ERROR;
32     }
33
34     switch (frame->format) {
35     case AV_PIX_FMT_RGB24:
36     case AV_PIX_FMT_BGR24:
37         sws_ctx = sws_getContext(frame->width * 3,
38                                  frame->height,
39                                  AV_PIX_FMT_GRAYF32,
40                                  frame->width * 3,
41                                  frame->height,
42                                  AV_PIX_FMT_GRAY8,
43                                  0, NULL, NULL, NULL);
44         sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0},
45                            (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0}, 0, frame->height,
46                            (uint8_t * const*)frame->data, frame->linesize);
47         sws_freeContext(sws_ctx);
48         return DNN_SUCCESS;
49     case AV_PIX_FMT_GRAYF32:
50         av_image_copy_plane(frame->data[0], frame->linesize[0],
51                             output->data, bytewidth,
52                             bytewidth, frame->height);
53         return DNN_SUCCESS;
54     case AV_PIX_FMT_YUV420P:
55     case AV_PIX_FMT_YUV422P:
56     case AV_PIX_FMT_YUV444P:
57     case AV_PIX_FMT_YUV410P:
58     case AV_PIX_FMT_YUV411P:
59     case AV_PIX_FMT_GRAY8:
60         sws_ctx = sws_getContext(frame->width,
61                                  frame->height,
62                                  AV_PIX_FMT_GRAYF32,
63                                  frame->width,
64                                  frame->height,
65                                  AV_PIX_FMT_GRAY8,
66                                  0, NULL, NULL, NULL);
67         sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0},
68                            (const int[4]){frame->width * sizeof(float), 0, 0, 0}, 0, frame->height,
69                            (uint8_t * const*)frame->data, frame->linesize);
70         sws_freeContext(sws_ctx);
71         return DNN_SUCCESS;
72     default:
73         av_log(log_ctx, AV_LOG_ERROR, "do not support frame format %d\n", frame->format);
74         return DNN_ERROR;
75     }
76
77     return DNN_SUCCESS;
78 }
79
80 DNNReturnType proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
81 {
82     struct SwsContext *sws_ctx;
83     int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
84     if (input->dt != DNN_FLOAT) {
85         av_log(log_ctx, AV_LOG_ERROR, "do not support data type rather than DNN_FLOAT\n");
86         return DNN_ERROR;
87     }
88
89     switch (frame->format) {
90     case AV_PIX_FMT_RGB24:
91     case AV_PIX_FMT_BGR24:
92         sws_ctx = sws_getContext(frame->width * 3,
93                                  frame->height,
94                                  AV_PIX_FMT_GRAY8,
95                                  frame->width * 3,
96                                  frame->height,
97                                  AV_PIX_FMT_GRAYF32,
98                                  0, NULL, NULL, NULL);
99         sws_scale(sws_ctx, (const uint8_t **)frame->data,
100                            frame->linesize, 0, frame->height,
101                            (uint8_t * const*)(&input->data),
102                            (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0});
103         sws_freeContext(sws_ctx);
104         break;
105     case AV_PIX_FMT_GRAYF32:
106         av_image_copy_plane(input->data, bytewidth,
107                             frame->data[0], frame->linesize[0],
108                             bytewidth, frame->height);
109         break;
110     case AV_PIX_FMT_YUV420P:
111     case AV_PIX_FMT_YUV422P:
112     case AV_PIX_FMT_YUV444P:
113     case AV_PIX_FMT_YUV410P:
114     case AV_PIX_FMT_YUV411P:
115     case AV_PIX_FMT_GRAY8:
116         sws_ctx = sws_getContext(frame->width,
117                                  frame->height,
118                                  AV_PIX_FMT_GRAY8,
119                                  frame->width,
120                                  frame->height,
121                                  AV_PIX_FMT_GRAYF32,
122                                  0, NULL, NULL, NULL);
123         sws_scale(sws_ctx, (const uint8_t **)frame->data,
124                            frame->linesize, 0, frame->height,
125                            (uint8_t * const*)(&input->data),
126                            (const int [4]){frame->width * sizeof(float), 0, 0, 0});
127         sws_freeContext(sws_ctx);
128         break;
129     default:
130         av_log(log_ctx, AV_LOG_ERROR, "do not support frame format %d\n", frame->format);
131         return DNN_ERROR;
132     }
133
134     return DNN_SUCCESS;
135 }