]> git.sesse.net Git - ffmpeg/blob - libavfilter/dnn/dnn_backend_openvino.c
5d6d3ed542fc07890315c13262b06c7f63489f28
[ffmpeg] / libavfilter / dnn / dnn_backend_openvino.c
1 /*
2  * Copyright (c) 2020
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21 /**
22  * @file
23  * DNN OpenVINO backend implementation.
24  */
25
26 #include "dnn_backend_openvino.h"
27 #include "libavformat/avio.h"
28 #include "libavutil/avassert.h"
29 #include <c_api/ie_c_api.h>
30
31 typedef struct OVContext {
32     const AVClass *class;
33 } OVContext;
34
35 typedef struct OVModel{
36     OVContext ctx;
37     ie_core_t *core;
38     ie_network_t *network;
39     ie_executable_network_t *exe_network;
40     ie_infer_request_t *infer_request;
41     ie_blob_t *input_blob;
42 } OVModel;
43
44 static const AVClass dnn_openvino_class = {
45     .class_name = "dnn_openvino",
46     .item_name  = av_default_item_name,
47     .option     = NULL,
48     .version    = LIBAVUTIL_VERSION_INT,
49     .category   = AV_CLASS_CATEGORY_FILTER,
50 };
51
52 static DNNDataType precision_to_datatype(precision_e precision)
53 {
54     switch (precision)
55     {
56     case FP32:
57         return DNN_FLOAT;
58     default:
59         av_assert0(!"not supported yet.");
60         return DNN_FLOAT;
61     }
62 }
63
64 static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input_name)
65 {
66     OVModel *ov_model = (OVModel *)model;
67     OVContext *ctx = &ov_model->ctx;
68     char *model_input_name = NULL;
69     IEStatusCode status;
70     size_t model_input_count = 0;
71     dimensions_t dims;
72     precision_e precision;
73
74     status = ie_network_get_inputs_number(ov_model->network, &model_input_count);
75     if (status != OK) {
76         av_log(ctx, AV_LOG_ERROR, "Failed to get input count\n");
77         return DNN_ERROR;
78     }
79
80     for (size_t i = 0; i < model_input_count; i++) {
81         status = ie_network_get_input_name(ov_model->network, i, &model_input_name);
82         if (status != OK) {
83             av_log(ctx, AV_LOG_ERROR, "Failed to get No.%d input's name\n", (int)i);
84             return DNN_ERROR;
85         }
86         if (strcmp(model_input_name, input_name) == 0) {
87             ie_network_name_free(&model_input_name);
88             status |= ie_network_get_input_dims(ov_model->network, input_name, &dims);
89             status |= ie_network_get_input_precision(ov_model->network, input_name, &precision);
90             if (status != OK) {
91                 av_log(ctx, AV_LOG_ERROR, "Failed to get No.%d input's dims or precision\n", (int)i);
92                 return DNN_ERROR;
93             }
94
95             // The order of dims in the openvino is fixed and it is always NCHW for 4-D data.
96             // while we pass NHWC data from FFmpeg to openvino
97             status = ie_network_set_input_layout(ov_model->network, input_name, NHWC);
98             if (status != OK) {
99                 av_log(ctx, AV_LOG_ERROR, "Input \"%s\" does not match layout NHWC\n", input_name);
100                 return DNN_ERROR;
101             }
102
103             input->channels = dims.dims[1];
104             input->height   = dims.dims[2];
105             input->width    = dims.dims[3];
106             input->dt       = precision_to_datatype(precision);
107             return DNN_SUCCESS;
108         }
109
110         ie_network_name_free(&model_input_name);
111     }
112
113     av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", model_input_name);
114     return DNN_ERROR;
115 }
116
117 static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input_name)
118 {
119     OVModel *ov_model = (OVModel *)model;
120     OVContext *ctx = &ov_model->ctx;
121     IEStatusCode status;
122     dimensions_t dims;
123     precision_e precision;
124     ie_blob_buffer_t blob_buffer;
125
126     status = ie_exec_network_create_infer_request(ov_model->exe_network, &ov_model->infer_request);
127     if (status != OK)
128         goto err;
129
130     status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &ov_model->input_blob);
131     if (status != OK)
132         goto err;
133
134     status |= ie_blob_get_dims(ov_model->input_blob, &dims);
135     status |= ie_blob_get_precision(ov_model->input_blob, &precision);
136     if (status != OK)
137         goto err;
138
139     av_assert0(input->channels == dims.dims[1]);
140     av_assert0(input->height   == dims.dims[2]);
141     av_assert0(input->width    == dims.dims[3]);
142     av_assert0(input->dt       == precision_to_datatype(precision));
143
144     status = ie_blob_get_buffer(ov_model->input_blob, &blob_buffer);
145     if (status != OK)
146         goto err;
147     input->data = blob_buffer.buffer;
148
149     return DNN_SUCCESS;
150
151 err:
152     if (ov_model->input_blob)
153         ie_blob_free(&ov_model->input_blob);
154     if (ov_model->infer_request)
155         ie_infer_request_free(&ov_model->infer_request);
156     av_log(ctx, AV_LOG_ERROR, "Failed to create inference instance or get input data/dims/precision/memory\n");
157     return DNN_ERROR;
158 }
159
160 DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options)
161 {
162     DNNModel *model = NULL;
163     OVModel *ov_model = NULL;
164     IEStatusCode status;
165     ie_config_t config = {NULL, NULL, NULL};
166
167     model = av_malloc(sizeof(DNNModel));
168     if (!model){
169         return NULL;
170     }
171
172     ov_model = av_mallocz(sizeof(OVModel));
173     if (!ov_model)
174         goto err;
175     ov_model->ctx.class = &dnn_openvino_class;
176
177     status = ie_core_create("", &ov_model->core);
178     if (status != OK)
179         goto err;
180
181     status = ie_core_read_network(ov_model->core, model_filename, NULL, &ov_model->network);
182     if (status != OK)
183         goto err;
184
185     status = ie_core_load_network(ov_model->core, ov_model->network, "CPU", &config, &ov_model->exe_network);
186     if (status != OK)
187         goto err;
188
189     model->model = (void *)ov_model;
190     model->set_input = &set_input_ov;
191     model->get_input = &get_input_ov;
192     model->options = options;
193
194     return model;
195
196 err:
197     if (model)
198         av_freep(&model);
199     if (ov_model) {
200         if (ov_model->exe_network)
201             ie_exec_network_free(&ov_model->exe_network);
202         if (ov_model->network)
203             ie_network_free(&ov_model->network);
204         if (ov_model->core)
205             ie_core_free(&ov_model->core);
206         av_freep(&ov_model);
207     }
208     return NULL;
209 }
210
211 DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
212 {
213     dimensions_t dims;
214     precision_e precision;
215     ie_blob_buffer_t blob_buffer;
216     OVModel *ov_model = (OVModel *)model->model;
217     OVContext *ctx = &ov_model->ctx;
218     IEStatusCode status = ie_infer_request_infer(ov_model->infer_request);
219     if (status != OK) {
220         av_log(ctx, AV_LOG_ERROR, "Failed to start synchronous model inference\n");
221         return DNN_ERROR;
222     }
223
224     for (uint32_t i = 0; i < nb_output; ++i) {
225         const char *output_name = output_names[i];
226         ie_blob_t *output_blob = NULL;
227         status = ie_infer_request_get_blob(ov_model->infer_request, output_name, &output_blob);
228         if (status != OK) {
229             av_log(ctx, AV_LOG_ERROR, "Failed to get model output data\n");
230             return DNN_ERROR;
231         }
232
233         status = ie_blob_get_buffer(output_blob, &blob_buffer);
234         if (status != OK) {
235             av_log(ctx, AV_LOG_ERROR, "Failed to access output memory\n");
236             return DNN_ERROR;
237         }
238
239         status |= ie_blob_get_dims(output_blob, &dims);
240         status |= ie_blob_get_precision(output_blob, &precision);
241         if (status != OK) {
242             av_log(ctx, AV_LOG_ERROR, "Failed to get dims or precision of output\n");
243             return DNN_ERROR;
244         }
245
246         outputs[i].channels = dims.dims[1];
247         outputs[i].height   = dims.dims[2];
248         outputs[i].width    = dims.dims[3];
249         outputs[i].dt       = precision_to_datatype(precision);
250         outputs[i].data     = blob_buffer.buffer;
251     }
252
253     return DNN_SUCCESS;
254 }
255
256 void ff_dnn_free_model_ov(DNNModel **model)
257 {
258     if (*model){
259         OVModel *ov_model = (OVModel *)(*model)->model;
260         if (ov_model->input_blob)
261             ie_blob_free(&ov_model->input_blob);
262         if (ov_model->infer_request)
263             ie_infer_request_free(&ov_model->infer_request);
264         if (ov_model->exe_network)
265             ie_exec_network_free(&ov_model->exe_network);
266         if (ov_model->network)
267             ie_network_free(&ov_model->network);
268         if (ov_model->core)
269             ie_core_free(&ov_model->core);
270         av_freep(&ov_model);
271         av_freep(model);
272     }
273 }