]> git.sesse.net Git - ffmpeg/blob - libavfilter/dnn/dnn_backend_openvino.c
dnn: move output name from DNNModel.set_input_output to DNNModule.execute_model
[ffmpeg] / libavfilter / dnn / dnn_backend_openvino.c
1 /*
2  * Copyright (c) 2020
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21 /**
22  * @file
23  * DNN OpenVINO backend implementation.
24  */
25
26 #include "dnn_backend_openvino.h"
27 #include "libavformat/avio.h"
28 #include "libavutil/avassert.h"
29 #include <c_api/ie_c_api.h>
30
31 typedef struct OVModel{
32     ie_core_t *core;
33     ie_network_t *network;
34     ie_executable_network_t *exe_network;
35     ie_infer_request_t *infer_request;
36     ie_blob_t *input_blob;
37 } OVModel;
38
39 static DNNDataType precision_to_datatype(precision_e precision)
40 {
41     switch (precision)
42     {
43     case FP32:
44         return DNN_FLOAT;
45     default:
46         av_assert0(!"not supported yet.");
47         return DNN_FLOAT;
48     }
49 }
50
51 static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input_name)
52 {
53     OVModel *ov_model = (OVModel *)model;
54     char *model_input_name = NULL;
55     IEStatusCode status;
56     size_t model_input_count = 0;
57     dimensions_t dims;
58     precision_e precision;
59
60     status = ie_network_get_inputs_number(ov_model->network, &model_input_count);
61     if (status != OK)
62         return DNN_ERROR;
63
64     for (size_t i = 0; i < model_input_count; i++) {
65         status = ie_network_get_input_name(ov_model->network, i, &model_input_name);
66         if (status != OK)
67             return DNN_ERROR;
68         if (strcmp(model_input_name, input_name) == 0) {
69             ie_network_name_free(&model_input_name);
70             status |= ie_network_get_input_dims(ov_model->network, input_name, &dims);
71             status |= ie_network_get_input_precision(ov_model->network, input_name, &precision);
72             if (status != OK)
73                 return DNN_ERROR;
74
75             // The order of dims in the openvino is fixed and it is always NCHW for 4-D data.
76             // while we pass NHWC data from FFmpeg to openvino
77             status = ie_network_set_input_layout(ov_model->network, input_name, NHWC);
78             if (status != OK)
79                 return DNN_ERROR;
80
81             input->channels = dims.dims[1];
82             input->height   = dims.dims[2];
83             input->width    = dims.dims[3];
84             input->dt       = precision_to_datatype(precision);
85             return DNN_SUCCESS;
86         }
87
88         ie_network_name_free(&model_input_name);
89     }
90
91     return DNN_ERROR;
92 }
93
94 static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input_name)
95 {
96     OVModel *ov_model = (OVModel *)model;
97     IEStatusCode status;
98     dimensions_t dims;
99     precision_e precision;
100     ie_blob_buffer_t blob_buffer;
101
102     status = ie_exec_network_create_infer_request(ov_model->exe_network, &ov_model->infer_request);
103     if (status != OK)
104         goto err;
105
106     status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &ov_model->input_blob);
107     if (status != OK)
108         goto err;
109
110     status |= ie_blob_get_dims(ov_model->input_blob, &dims);
111     status |= ie_blob_get_precision(ov_model->input_blob, &precision);
112     if (status != OK)
113         goto err;
114
115     av_assert0(input->channels == dims.dims[1]);
116     av_assert0(input->height   == dims.dims[2]);
117     av_assert0(input->width    == dims.dims[3]);
118     av_assert0(input->dt       == precision_to_datatype(precision));
119
120     status = ie_blob_get_buffer(ov_model->input_blob, &blob_buffer);
121     if (status != OK)
122         goto err;
123     input->data = blob_buffer.buffer;
124
125     return DNN_SUCCESS;
126
127 err:
128     if (ov_model->input_blob)
129         ie_blob_free(&ov_model->input_blob);
130     if (ov_model->infer_request)
131         ie_infer_request_free(&ov_model->infer_request);
132     return DNN_ERROR;
133 }
134
135 DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options)
136 {
137     DNNModel *model = NULL;
138     OVModel *ov_model = NULL;
139     IEStatusCode status;
140     ie_config_t config = {NULL, NULL, NULL};
141
142     model = av_malloc(sizeof(DNNModel));
143     if (!model){
144         return NULL;
145     }
146
147     ov_model = av_mallocz(sizeof(OVModel));
148     if (!ov_model)
149         goto err;
150
151     status = ie_core_create("", &ov_model->core);
152     if (status != OK)
153         goto err;
154
155     status = ie_core_read_network(ov_model->core, model_filename, NULL, &ov_model->network);
156     if (status != OK)
157         goto err;
158
159     status = ie_core_load_network(ov_model->core, ov_model->network, "CPU", &config, &ov_model->exe_network);
160     if (status != OK)
161         goto err;
162
163     model->model = (void *)ov_model;
164     model->set_input = &set_input_ov;
165     model->get_input = &get_input_ov;
166     model->options = options;
167
168     return model;
169
170 err:
171     if (model)
172         av_freep(&model);
173     if (ov_model) {
174         if (ov_model->exe_network)
175             ie_exec_network_free(&ov_model->exe_network);
176         if (ov_model->network)
177             ie_network_free(&ov_model->network);
178         if (ov_model->core)
179             ie_core_free(&ov_model->core);
180         av_freep(&ov_model);
181     }
182     return NULL;
183 }
184
185 DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
186 {
187     dimensions_t dims;
188     precision_e precision;
189     ie_blob_buffer_t blob_buffer;
190     OVModel *ov_model = (OVModel *)model->model;
191     IEStatusCode status = ie_infer_request_infer(ov_model->infer_request);
192     if (status != OK)
193         return DNN_ERROR;
194
195     for (uint32_t i = 0; i < nb_output; ++i) {
196         const char *output_name = output_names[i];
197         ie_blob_t *output_blob = NULL;
198         status = ie_infer_request_get_blob(ov_model->infer_request, output_name, &output_blob);
199         if (status != OK)
200             return DNN_ERROR;
201
202         status = ie_blob_get_buffer(output_blob, &blob_buffer);
203         if (status != OK)
204             return DNN_ERROR;
205
206         status |= ie_blob_get_dims(output_blob, &dims);
207         status |= ie_blob_get_precision(output_blob, &precision);
208         if (status != OK)
209             return DNN_ERROR;
210
211         outputs[i].channels = dims.dims[1];
212         outputs[i].height   = dims.dims[2];
213         outputs[i].width    = dims.dims[3];
214         outputs[i].dt       = precision_to_datatype(precision);
215         outputs[i].data     = blob_buffer.buffer;
216     }
217
218     return DNN_SUCCESS;
219 }
220
221 void ff_dnn_free_model_ov(DNNModel **model)
222 {
223     if (*model){
224         OVModel *ov_model = (OVModel *)(*model)->model;
225         if (ov_model->input_blob)
226             ie_blob_free(&ov_model->input_blob);
227         if (ov_model->infer_request)
228             ie_infer_request_free(&ov_model->infer_request);
229         if (ov_model->exe_network)
230             ie_exec_network_free(&ov_model->exe_network);
231         if (ov_model->network)
232             ie_network_free(&ov_model->network);
233         if (ov_model->core)
234             ie_core_free(&ov_model->core);
235         av_freep(&ov_model);
236         av_freep(model);
237     }
238 }