]> git.sesse.net Git - ffmpeg/blob - libavfilter/dnn_backend_tf.c
avformat/mxfdec: add support for opAtom without index
[ffmpeg] / libavfilter / dnn_backend_tf.c
1 /*
2  * Copyright (c) 2018 Sergey Lavrushkin
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21 /**
22  * @file
23  * DNN tensorflow backend implementation.
24  */
25
#include "dnn_backend_tf.h"
#include "dnn_srcnn.h"
#include "dnn_espcn.h"
#include "libavformat/avio.h"

#include <limits.h>
#include <stdint.h>

#include <tensorflow/c/c_api.h>
32
33 typedef struct TFModel{
34     TF_Graph* graph;
35     TF_Session* session;
36     TF_Status* status;
37     TF_Output input, output;
38     TF_Tensor* input_tensor;
39     DNNData* output_data;
40 } TFModel;
41
/**
 * TF_Buffer data_deallocator for payloads allocated with av_malloc().
 * The length argument is part of the callback signature and is not needed.
 */
static void free_buffer(void *data, size_t length)
{
    (void)length;
    av_freep(&data);
}
46
47 static TF_Buffer* read_graph(const char* model_filename)
48 {
49     TF_Buffer* graph_buf;
50     unsigned char* graph_data = NULL;
51     AVIOContext* model_file_context;
52     long size, bytes_read;
53
54     if (avio_open(&model_file_context, model_filename, AVIO_FLAG_READ) < 0){
55         return NULL;
56     }
57
58     size = avio_size(model_file_context);
59
60     graph_data = av_malloc(size);
61     if (!graph_data){
62         avio_closep(&model_file_context);
63         return NULL;
64     }
65     bytes_read = avio_read(model_file_context, graph_data, size);
66     avio_closep(&model_file_context);
67     if (bytes_read != size){
68         av_freep(&graph_data);
69         return NULL;
70     }
71
72     graph_buf = TF_NewBuffer();
73     graph_buf->data = (void*)graph_data;
74     graph_buf->length = size;
75     graph_buf->data_deallocator = free_buffer;
76
77     return graph_buf;
78 }
79
80 static DNNReturnType set_input_output_tf(void* model, DNNData* input, DNNData* output)
81 {
82     TFModel* tf_model = (TFModel*)model;
83     int64_t input_dims[] = {1, input->height, input->width, input->channels};
84     TF_SessionOptions* sess_opts;
85     const TF_Operation* init_op = TF_GraphOperationByName(tf_model->graph, "init");
86     TF_Tensor* output_tensor;
87
88     // Input operation should be named 'x'
89     tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x");
90     if (!tf_model->input.oper){
91         return DNN_ERROR;
92     }
93     tf_model->input.index = 0;
94     if (tf_model->input_tensor){
95         TF_DeleteTensor(tf_model->input_tensor);
96     }
97     tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4,
98                                                input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float));
99     if (!tf_model->input_tensor){
100         return DNN_ERROR;
101     }
102     input->data = (float*)TF_TensorData(tf_model->input_tensor);
103
104     // Output operation should be named 'y'
105     tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y");
106     if (!tf_model->output.oper){
107         return DNN_ERROR;
108     }
109     tf_model->output.index = 0;
110
111     if (tf_model->session){
112         TF_CloseSession(tf_model->session, tf_model->status);
113         TF_DeleteSession(tf_model->session, tf_model->status);
114     }
115
116     sess_opts = TF_NewSessionOptions();
117     tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status);
118     TF_DeleteSessionOptions(sess_opts);
119     if (TF_GetCode(tf_model->status) != TF_OK)
120     {
121         return DNN_ERROR;
122     }
123
124     // Run initialization operation with name "init" if it is present in graph
125     if (init_op){
126         TF_SessionRun(tf_model->session, NULL,
127                       NULL, NULL, 0,
128                       NULL, NULL, 0,
129                       &init_op, 1, NULL, tf_model->status);
130         if (TF_GetCode(tf_model->status) != TF_OK)
131         {
132             return DNN_ERROR;
133         }
134     }
135
136     // Execute network to get output height, width and number of channels
137     TF_SessionRun(tf_model->session, NULL,
138                   &tf_model->input, &tf_model->input_tensor, 1,
139                   &tf_model->output, &output_tensor, 1,
140                   NULL, 0, NULL, tf_model->status);
141     if (TF_GetCode(tf_model->status) != TF_OK){
142         return DNN_ERROR;
143     }
144     else{
145         output->height = TF_Dim(output_tensor, 1);
146         output->width = TF_Dim(output_tensor, 2);
147         output->channels = TF_Dim(output_tensor, 3);
148         output->data = av_malloc(output->height * output->width * output->channels * sizeof(float));
149         if (!output->data){
150             return DNN_ERROR;
151         }
152         tf_model->output_data = output;
153         TF_DeleteTensor(output_tensor);
154     }
155
156     return DNN_SUCCESS;
157 }
158
159 DNNModel* ff_dnn_load_model_tf(const char* model_filename)
160 {
161     DNNModel* model = NULL;
162     TFModel* tf_model = NULL;
163     TF_Buffer* graph_def;
164     TF_ImportGraphDefOptions* graph_opts;
165
166     model = av_malloc(sizeof(DNNModel));
167     if (!model){
168         return NULL;
169     }
170
171     tf_model = av_malloc(sizeof(TFModel));
172     if (!tf_model){
173         av_freep(&model);
174         return NULL;
175     }
176     tf_model->session = NULL;
177     tf_model->input_tensor = NULL;
178     tf_model->output_data = NULL;
179
180     graph_def = read_graph(model_filename);
181     if (!graph_def){
182         av_freep(&tf_model);
183         av_freep(&model);
184         return NULL;
185     }
186     tf_model->graph = TF_NewGraph();
187     tf_model->status = TF_NewStatus();
188     graph_opts = TF_NewImportGraphDefOptions();
189     TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
190     TF_DeleteImportGraphDefOptions(graph_opts);
191     TF_DeleteBuffer(graph_def);
192     if (TF_GetCode(tf_model->status) != TF_OK){
193         TF_DeleteGraph(tf_model->graph);
194         TF_DeleteStatus(tf_model->status);
195         av_freep(&tf_model);
196         av_freep(&model);
197         return NULL;
198     }
199
200     model->model = (void*)tf_model;
201     model->set_input_output = &set_input_output_tf;
202
203     return model;
204 }
205
206 DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
207 {
208     DNNModel* model = NULL;
209     TFModel* tf_model = NULL;
210     TF_Buffer* graph_def;
211     unsigned char* graph_data = NULL;
212     TF_ImportGraphDefOptions* graph_opts;
213
214     graph_def = TF_NewBuffer();
215     switch (model_type){
216     case DNN_SRCNN:
217         graph_data = av_malloc(srcnn_tf_size);
218         if (!graph_data){
219             TF_DeleteBuffer(graph_def);
220             return NULL;
221         }
222         memcpy(graph_data, srcnn_tf_model, srcnn_tf_size);
223         graph_def->data = (void*)graph_data;
224         graph_def->length = srcnn_tf_size;
225         graph_def->data_deallocator = free_buffer;
226         break;
227     case DNN_ESPCN:
228         graph_data = av_malloc(espcn_tf_size);
229         if (!graph_data){
230             TF_DeleteBuffer(graph_def);
231             return NULL;
232         }
233         memcpy(graph_data, espcn_tf_model, espcn_tf_size);
234         graph_def->data = (void*)graph_data;
235         graph_def->length = espcn_tf_size;
236         graph_def->data_deallocator = free_buffer;
237         break;
238     default:
239         TF_DeleteBuffer(graph_def);
240         return NULL;
241     }
242
243     model = av_malloc(sizeof(DNNModel));
244     if (!model){
245         TF_DeleteBuffer(graph_def);
246         return NULL;
247     }
248
249     tf_model = av_malloc(sizeof(TFModel));
250     if (!tf_model){
251         TF_DeleteBuffer(graph_def);
252         av_freep(&model);
253         return NULL;
254     }
255     tf_model->session = NULL;
256     tf_model->input_tensor = NULL;
257     tf_model->output_data = NULL;
258
259     tf_model->graph = TF_NewGraph();
260     tf_model->status = TF_NewStatus();
261     graph_opts = TF_NewImportGraphDefOptions();
262     TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
263     TF_DeleteImportGraphDefOptions(graph_opts);
264     TF_DeleteBuffer(graph_def);
265     if (TF_GetCode(tf_model->status) != TF_OK){
266         TF_DeleteGraph(tf_model->graph);
267         TF_DeleteStatus(tf_model->status);
268         av_freep(&tf_model);
269         av_freep(&model);
270         return NULL;
271     }
272
273     model->model = (void*)tf_model;
274     model->set_input_output = &set_input_output_tf;
275
276     return model;
277 }
278
279 DNNReturnType ff_dnn_execute_model_tf(const DNNModel* model)
280 {
281     TFModel* tf_model = (TFModel*)model->model;
282     TF_Tensor* output_tensor;
283
284     TF_SessionRun(tf_model->session, NULL,
285                   &tf_model->input, &tf_model->input_tensor, 1,
286                   &tf_model->output, &output_tensor, 1,
287                   NULL, 0, NULL, tf_model->status);
288
289     if (TF_GetCode(tf_model->status) != TF_OK){
290         return DNN_ERROR;
291     }
292     else{
293         memcpy(tf_model->output_data->data, TF_TensorData(output_tensor),
294                tf_model->output_data->height * tf_model->output_data->width *
295                tf_model->output_data->channels * sizeof(float));
296         TF_DeleteTensor(output_tensor);
297
298         return DNN_SUCCESS;
299     }
300 }
301
302 void ff_dnn_free_model_tf(DNNModel** model)
303 {
304     TFModel* tf_model;
305
306     if (*model){
307         tf_model = (TFModel*)(*model)->model;
308         if (tf_model->graph){
309             TF_DeleteGraph(tf_model->graph);
310         }
311         if (tf_model->session){
312             TF_CloseSession(tf_model->session, tf_model->status);
313             TF_DeleteSession(tf_model->session, tf_model->status);
314         }
315         if (tf_model->status){
316             TF_DeleteStatus(tf_model->status);
317         }
318         if (tf_model->input_tensor){
319             TF_DeleteTensor(tf_model->input_tensor);
320         }
321         av_freep(&tf_model->output_data->data);
322         av_freep(&tf_model);
323         av_freep(model);
324     }
325 }