/*
 * Copyright (c) 2018 Sergey Lavrushkin
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

/**
 * @file
 * DNN TensorFlow backend implementation.
 */

#include "dnn_backend_tf.h"
#include "dnn_backend_native.h"
#include "dnn_backend_native_layer_conv2d.h"
#include "dnn_backend_native_layer_depth2space.h"
#include "libavformat/avio.h"
#include "libavutil/avassert.h"
#include "dnn_backend_native_layer_pad.h"
#include "dnn_backend_native_layer_maximum.h"

#include <tensorflow/c/c_api.h>

typedef struct TFContext {
    const AVClass *class;
} TFContext;

typedef struct TFModel {
    TFContext ctx;
    TF_Graph *graph;
    TF_Session *session;
    TF_Status *status;
    TF_Output input;
    TF_Tensor *input_tensor;
    TF_Tensor **output_tensors;
    uint32_t nb_output;
} TFModel;

static const AVClass dnn_tensorflow_class = {
    .class_name = "dnn_tensorflow",
    .item_name  = av_default_item_name,
    .option     = NULL,
    .version    = LIBAVUTIL_VERSION_INT,
    .category   = AV_CLASS_CATEGORY_FILTER,
};

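/* Deallocator callback installed on the TF_Buffer created in read_graph();
 * frees the graph bytes once TensorFlow releases the buffer. The length
 * argument is unused. */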
static void free_buffer(void *data, size_t length)
{
    av_freep(&data);
}

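/* Read a serialized TensorFlow GraphDef from model_filename into a TF_Buffer
 * using avio; returns NULL on I/O or allocation failure. */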
static TF_Buffer *read_graph(const char *model_filename)
{
    TF_Buffer *graph_buf;
    unsigned char *graph_data = NULL;
    AVIOContext *model_file_context;
    long size, bytes_read;

    if (avio_open(&model_file_context, model_filename, AVIO_FLAG_READ) < 0){
        return NULL;
    }

    size = avio_size(model_file_context);

    graph_data = av_malloc(size);
    if (!graph_data){
        avio_closep(&model_file_context);
        return NULL;
    }
    bytes_read = avio_read(model_file_context, graph_data, size);
    avio_closep(&model_file_context);
    if (bytes_read != size){
        av_freep(&graph_data);
        return NULL;
    }

    graph_buf = TF_NewBuffer();
    graph_buf->data = (void *)graph_data;
    graph_buf->length = size;
    graph_buf->data_deallocator = free_buffer;

    return graph_buf;
}

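/* Allocate a 4-D NHWC input tensor with batch size 1 whose element type
 * matches the DNNData type (float or uint8). */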
static TF_Tensor *allocate_input_tensor(const DNNData *input)
{
    TF_DataType dt;
    size_t size;
    int64_t input_dims[] = {1, input->height, input->width, input->channels};
    switch (input->dt) {
    case DNN_FLOAT:
        dt = TF_FLOAT;
        size = sizeof(float);
        break;
    case DNN_UINT8:
        dt = TF_UINT8;
        size = 1;
        break;
    default:
        av_assert0(!"should not reach here");
    }

    return TF_AllocateTensor(dt, input_dims, 4,
                             input_dims[1] * input_dims[2] * input_dims[3] * size);
}

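/* Look up the operation named input_name in the graph and report its data
 * type and NHWC shape back through *input. Only batch size 1 is supported. */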
static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input_name)
{
    TFModel *tf_model = (TFModel *)model;
    TFContext *ctx = &tf_model->ctx;
    TF_Status *status;
    int64_t dims[4];

    TF_Output tf_output;
    tf_output.oper = TF_GraphOperationByName(tf_model->graph, input_name);
    if (!tf_output.oper) {
        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
        return DNN_ERROR;
    }

    tf_output.index = 0;
    input->dt = TF_OperationOutputType(tf_output);

    status = TF_NewStatus();
    TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status);
    if (TF_GetCode(status) != TF_OK){
        TF_DeleteStatus(status);
        av_log(ctx, AV_LOG_ERROR, "Failed to get input tensor shape: number of dimensions incorrect\n");
        return DNN_ERROR;
    }
    TF_DeleteStatus(status);

    // currently only NHWC is supported
    av_assert0(dims[0] == 1);
    input->height = dims[1];
    input->width = dims[2];
    input->channels = dims[3];

    return DNN_SUCCESS;
}

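/* Bind input_name as the model input: allocate the input tensor, expose its
 * data pointer through *input, (re)create the TF session, and run the
 * optional "init" operation if the graph contains one. */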
static DNNReturnType set_input_tf(void *model, DNNData *input, const char *input_name)
{
    TFModel *tf_model = (TFModel *)model;
    TFContext *ctx = &tf_model->ctx;
    TF_SessionOptions *sess_opts;
    const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");

    // Input operation
    tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
    if (!tf_model->input.oper){
        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
        return DNN_ERROR;
    }
    tf_model->input.index = 0;
    if (tf_model->input_tensor){
        TF_DeleteTensor(tf_model->input_tensor);
    }
    tf_model->input_tensor = allocate_input_tensor(input);
    if (!tf_model->input_tensor){
        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n");
        return DNN_ERROR;
    }
    input->data = (float *)TF_TensorData(tf_model->input_tensor);

    // session
    if (tf_model->session){
        TF_CloseSession(tf_model->session, tf_model->status);
        TF_DeleteSession(tf_model->session, tf_model->status);
    }

    sess_opts = TF_NewSessionOptions();
    tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status);
    TF_DeleteSessionOptions(sess_opts);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to create new session with model graph\n");
        return DNN_ERROR;
    }

    // Run initialization operation with name "init" if it is present in graph
    if (init_op){
        TF_SessionRun(tf_model->session, NULL,
                      NULL, NULL, 0,
                      NULL, NULL, 0,
                      &init_op, 1, NULL, tf_model->status);
        if (TF_GetCode(tf_model->status) != TF_OK){
            av_log(ctx, AV_LOG_ERROR, "Failed to run session when initializing\n");
            return DNN_ERROR;
        }
    }

    return DNN_SUCCESS;
}

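/* Import a serialized TensorFlow GraphDef (frozen .pb file) into
 * tf_model->graph. Returns DNN_ERROR if the file cannot be read or the graph
 * cannot be imported. */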
static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
{
    TFContext *ctx = &tf_model->ctx;
    TF_Buffer *graph_def;
    TF_ImportGraphDefOptions *graph_opts;

    graph_def = read_graph(model_filename);
    if (!graph_def){
        av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename);
        return DNN_ERROR;
    }
    tf_model->graph = TF_NewGraph();
    tf_model->status = TF_NewStatus();
    graph_opts = TF_NewImportGraphDefOptions();
    TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
    TF_DeleteImportGraphDefOptions(graph_opts);
    TF_DeleteBuffer(graph_def);
    if (TF_GetCode(tf_model->status) != TF_OK){
        TF_DeleteGraph(tf_model->graph);
        TF_DeleteStatus(tf_model->status);
        av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n");
        return DNN_ERROR;
    }

    return DNN_SUCCESS;
}

#define NAME_BUFFER_SIZE 256

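/* Append one convolutional layer of the native model to the TF graph:
 * a Const kernel (transposed via transpose_op into the filter layout Conv2D
 * expects), Conv2D with VALID padding, BiasAdd, and the activation function.
 * *cur_op is advanced to the last operation added. */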
static DNNReturnType add_conv_layer(TFModel *tf_model, TF_Operation *transpose_op, TF_Operation **cur_op,
                                    ConvolutionalParams *params, const int layer)
{
    TFContext *ctx = &tf_model->ctx;
    TF_Operation *op;
    TF_OperationDescription *op_desc;
    TF_Output input;
    int64_t strides[] = {1, 1, 1, 1};
    TF_Tensor *tensor;
    int64_t dims[4];
    int dims_len;
    char name_buffer[NAME_BUFFER_SIZE];
    int32_t size;

    size = params->input_num * params->output_num * params->kernel_size * params->kernel_size;
    input.index = 0;

    snprintf(name_buffer, NAME_BUFFER_SIZE, "conv_kernel%d", layer);
    op_desc = TF_NewOperation(tf_model->graph, "Const", name_buffer);
    TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
    dims[0] = params->output_num;
    dims[1] = params->kernel_size;
    dims[2] = params->kernel_size;
    dims[3] = params->input_num;
    dims_len = 4;
    tensor = TF_AllocateTensor(TF_FLOAT, dims, dims_len, size * sizeof(float));
    memcpy(TF_TensorData(tensor), params->kernel, size * sizeof(float));
    TF_SetAttrTensor(op_desc, "value", tensor, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to set value for kernel of conv layer %d\n", layer);
        return DNN_ERROR;
    }
    op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add kernel to conv layer %d\n", layer);
        return DNN_ERROR;
    }

    snprintf(name_buffer, NAME_BUFFER_SIZE, "transpose%d", layer);
    op_desc = TF_NewOperation(tf_model->graph, "Transpose", name_buffer);
    input.oper = op;
    TF_AddInput(op_desc, input);
    input.oper = transpose_op;
    TF_AddInput(op_desc, input);
    TF_SetAttrType(op_desc, "T", TF_FLOAT);
    TF_SetAttrType(op_desc, "Tperm", TF_INT32);
    op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add transpose to conv layer %d\n", layer);
        return DNN_ERROR;
    }

    snprintf(name_buffer, NAME_BUFFER_SIZE, "conv2d%d", layer);
    op_desc = TF_NewOperation(tf_model->graph, "Conv2D", name_buffer);
    input.oper = *cur_op;
    TF_AddInput(op_desc, input);
    input.oper = op;
    TF_AddInput(op_desc, input);
    TF_SetAttrType(op_desc, "T", TF_FLOAT);
    TF_SetAttrIntList(op_desc, "strides", strides, 4);
    TF_SetAttrString(op_desc, "padding", "VALID", 5);
    *cur_op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add conv2d to conv layer %d\n", layer);
        return DNN_ERROR;
    }

    snprintf(name_buffer, NAME_BUFFER_SIZE, "conv_biases%d", layer);
    op_desc = TF_NewOperation(tf_model->graph, "Const", name_buffer);
    TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
    dims[0] = params->output_num;
    dims_len = 1;
    tensor = TF_AllocateTensor(TF_FLOAT, dims, dims_len, params->output_num * sizeof(float));
    memcpy(TF_TensorData(tensor), params->biases, params->output_num * sizeof(float));
    TF_SetAttrTensor(op_desc, "value", tensor, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to set value for conv_biases of conv layer %d\n", layer);
        return DNN_ERROR;
    }
    op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add conv_biases to conv layer %d\n", layer);
        return DNN_ERROR;
    }

    snprintf(name_buffer, NAME_BUFFER_SIZE, "bias_add%d", layer);
    op_desc = TF_NewOperation(tf_model->graph, "BiasAdd", name_buffer);
    input.oper = *cur_op;
    TF_AddInput(op_desc, input);
    input.oper = op;
    TF_AddInput(op_desc, input);
    TF_SetAttrType(op_desc, "T", TF_FLOAT);
    *cur_op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add bias_add to conv layer %d\n", layer);
        return DNN_ERROR;
    }

    snprintf(name_buffer, NAME_BUFFER_SIZE, "activation%d", layer);
    switch (params->activation){
    case RELU:
        op_desc = TF_NewOperation(tf_model->graph, "Relu", name_buffer);
        break;
    case TANH:
        op_desc = TF_NewOperation(tf_model->graph, "Tanh", name_buffer);
        break;
    case SIGMOID:
        op_desc = TF_NewOperation(tf_model->graph, "Sigmoid", name_buffer);
        break;
    default:
        av_log(ctx, AV_LOG_ERROR, "Unsupported convolutional activation function\n");
        return DNN_ERROR;
    }
    input.oper = *cur_op;
    TF_AddInput(op_desc, input);
    TF_SetAttrType(op_desc, "T", TF_FLOAT);
    *cur_op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add activation function to conv layer %d\n", layer);
        return DNN_ERROR;
    }

    return DNN_SUCCESS;
}

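/* Append a DepthToSpace operation (sub-pixel upscaling) with the native
 * layer's block_size to the TF graph. */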
static DNNReturnType add_depth_to_space_layer(TFModel *tf_model, TF_Operation **cur_op,
                                              DepthToSpaceParams *params, const int layer)
{
    TFContext *ctx = &tf_model->ctx;
    TF_OperationDescription *op_desc;
    TF_Output input;
    char name_buffer[NAME_BUFFER_SIZE];

    snprintf(name_buffer, NAME_BUFFER_SIZE, "depth_to_space%d", layer);
    op_desc = TF_NewOperation(tf_model->graph, "DepthToSpace", name_buffer);
    input.oper = *cur_op;
    input.index = 0;
    TF_AddInput(op_desc, input);
    TF_SetAttrType(op_desc, "T", TF_FLOAT);
    TF_SetAttrInt(op_desc, "block_size", params->block_size);
    *cur_op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add depth_to_space to layer %d\n", layer);
        return DNN_ERROR;
    }

    return DNN_SUCCESS;
}

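/* Append a MirrorPad operation in SYMMETRIC mode: a 4x2 Const tensor holds
 * the per-dimension paddings taken from the native layer parameters. */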
static DNNReturnType add_pad_layer(TFModel *tf_model, TF_Operation **cur_op,
                                   LayerPadParams *params, const int layer)
{
    TFContext *ctx = &tf_model->ctx;
    TF_Operation *op;
    TF_Tensor *tensor;
    TF_OperationDescription *op_desc;
    TF_Output input;
    int32_t *pads;
    int64_t pads_shape[] = {4, 2};

    char name_buffer[NAME_BUFFER_SIZE];
    snprintf(name_buffer, NAME_BUFFER_SIZE, "pad%d", layer);

    op_desc = TF_NewOperation(tf_model->graph, "Const", name_buffer);
    TF_SetAttrType(op_desc, "dtype", TF_INT32);
    tensor = TF_AllocateTensor(TF_INT32, pads_shape, 2, 4 * 2 * sizeof(int32_t));
    pads = (int32_t *)TF_TensorData(tensor);
    pads[0] = params->paddings[0][0];
    pads[1] = params->paddings[0][1];
    pads[2] = params->paddings[1][0];
    pads[3] = params->paddings[1][1];
    pads[4] = params->paddings[2][0];
    pads[5] = params->paddings[2][1];
    pads[6] = params->paddings[3][0];
    pads[7] = params->paddings[3][1];
    TF_SetAttrTensor(op_desc, "value", tensor, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to set value for pad of layer %d\n", layer);
        return DNN_ERROR;
    }
    op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add pad to layer %d\n", layer);
        return DNN_ERROR;
    }

    // use a per-layer name so that multiple pad layers do not collide
    snprintf(name_buffer, NAME_BUFFER_SIZE, "mirror_pad%d", layer);
    op_desc = TF_NewOperation(tf_model->graph, "MirrorPad", name_buffer);
    input.oper = *cur_op;
    input.index = 0;
    TF_AddInput(op_desc, input);
    input.oper = op;
    TF_AddInput(op_desc, input);
    TF_SetAttrType(op_desc, "T", TF_FLOAT);
    TF_SetAttrType(op_desc, "Tpaddings", TF_INT32);
    TF_SetAttrString(op_desc, "mode", "SYMMETRIC", 9);
    *cur_op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add mirror_pad to layer %d\n", layer);
        return DNN_ERROR;
    }

    return DNN_SUCCESS;
}

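/* Append a Maximum operation that takes the element-wise maximum of the
 * current output and the scalar constant params->val.y. */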
static DNNReturnType add_maximum_layer(TFModel *tf_model, TF_Operation **cur_op,
                                       DnnLayerMaximumParams *params, const int layer)
{
    TFContext *ctx = &tf_model->ctx;
    TF_Operation *op;
    TF_Tensor *tensor;
    TF_OperationDescription *op_desc;
    TF_Output input;
    float *y;

    char name_buffer[NAME_BUFFER_SIZE];
    snprintf(name_buffer, NAME_BUFFER_SIZE, "maximum/y%d", layer);

    op_desc = TF_NewOperation(tf_model->graph, "Const", name_buffer);
    TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
    tensor = TF_AllocateTensor(TF_FLOAT, NULL, 0, TF_DataTypeSize(TF_FLOAT));
    y = (float *)TF_TensorData(tensor);
    *y = params->val.y;
    TF_SetAttrTensor(op_desc, "value", tensor, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to set value for maximum/y of layer %d\n", layer);
        return DNN_ERROR;
    }
    op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add maximum/y to layer %d\n", layer);
        return DNN_ERROR;
    }

    snprintf(name_buffer, NAME_BUFFER_SIZE, "maximum%d", layer);
    op_desc = TF_NewOperation(tf_model->graph, "Maximum", name_buffer);
    input.oper = *cur_op;
    input.index = 0;
    TF_AddInput(op_desc, input);
    input.oper = op;
    TF_AddInput(op_desc, input);
    TF_SetAttrType(op_desc, "T", TF_FLOAT);
    *cur_op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        av_log(ctx, AV_LOG_ERROR, "Failed to add maximum to layer %d\n", layer);
        return DNN_ERROR;
    }

    return DNN_SUCCESS;
}

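/* Fallback path: load the file as a native FFmpeg DNN model and rebuild it as
 * a TensorFlow graph, converting each native layer with the add_*_layer
 * helpers above. The graph input is a Placeholder named "x" and the output is
 * an Identity operation named "y". */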
static DNNReturnType load_native_model(TFModel *tf_model, const char *model_filename)
{
    TFContext *ctx = &tf_model->ctx;
    int32_t layer;
    TF_OperationDescription *op_desc;
    TF_Operation *op;
    TF_Operation *transpose_op;
    TF_Tensor *tensor;
    TF_Output input;
    int32_t *transpose_perm;
    int64_t transpose_perm_shape[] = {4};
    int64_t input_shape[] = {1, -1, -1, -1};
    DNNReturnType layer_add_res;
    DNNModel *model = NULL;
    NativeModel *native_model;

    model = ff_dnn_load_model_native(model_filename, NULL, NULL);
    if (!model){
        av_log(ctx, AV_LOG_ERROR, "Failed to load native model\n");
        return DNN_ERROR;
    }

    native_model = (NativeModel *)model->model;
    tf_model->graph = TF_NewGraph();
    tf_model->status = TF_NewStatus();

#define CLEANUP_ON_ERROR(tf_model) \
    { \
        ff_dnn_free_model_native(&model); \
        TF_DeleteGraph(tf_model->graph); \
        TF_DeleteStatus(tf_model->status); \
        av_log(ctx, AV_LOG_ERROR, "Failed to set value or add operator to layer\n"); \
        return DNN_ERROR; \
    }

    op_desc = TF_NewOperation(tf_model->graph, "Placeholder", "x");
    TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
    TF_SetAttrShape(op_desc, "shape", input_shape, 4);
    op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        CLEANUP_ON_ERROR(tf_model);
    }

    op_desc = TF_NewOperation(tf_model->graph, "Const", "transpose_perm");
    TF_SetAttrType(op_desc, "dtype", TF_INT32);
    tensor = TF_AllocateTensor(TF_INT32, transpose_perm_shape, 1, 4 * sizeof(int32_t));
    transpose_perm = (int32_t *)TF_TensorData(tensor);
    transpose_perm[0] = 1;
    transpose_perm[1] = 2;
    transpose_perm[2] = 3;
    transpose_perm[3] = 0;
    TF_SetAttrTensor(op_desc, "value", tensor, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        CLEANUP_ON_ERROR(tf_model);
    }
    transpose_op = TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        CLEANUP_ON_ERROR(tf_model);
    }

    for (layer = 0; layer < native_model->layers_num; ++layer){
        switch (native_model->layers[layer].type){
        case DLT_INPUT:
            layer_add_res = DNN_SUCCESS;
            break;
        case DLT_CONV2D:
            layer_add_res = add_conv_layer(tf_model, transpose_op, &op,
                                           (ConvolutionalParams *)native_model->layers[layer].params, layer);
            break;
        case DLT_DEPTH_TO_SPACE:
            layer_add_res = add_depth_to_space_layer(tf_model, &op,
                                                     (DepthToSpaceParams *)native_model->layers[layer].params, layer);
            break;
        case DLT_MIRROR_PAD:
            layer_add_res = add_pad_layer(tf_model, &op,
                                          (LayerPadParams *)native_model->layers[layer].params, layer);
            break;
        case DLT_MAXIMUM:
            layer_add_res = add_maximum_layer(tf_model, &op,
                                              (DnnLayerMaximumParams *)native_model->layers[layer].params, layer);
            break;
        default:
            CLEANUP_ON_ERROR(tf_model);
        }

        if (layer_add_res != DNN_SUCCESS){
            CLEANUP_ON_ERROR(tf_model);
        }
    }

    op_desc = TF_NewOperation(tf_model->graph, "Identity", "y");
    input.oper = op;
    input.index = 0;
    TF_AddInput(op_desc, input);
    TF_FinishOperation(op_desc, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK){
        CLEANUP_ON_ERROR(tf_model);
    }

    ff_dnn_free_model_native(&model);

    return DNN_SUCCESS;
}

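/* Public entry point: load model_filename as a TensorFlow GraphDef, falling
 * back to converting a native FFmpeg DNN model if that fails. The options
 * string and userdata pointer are stored on the returned DNNModel for the
 * calling filter to use. */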
DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, void *userdata)
{
    DNNModel *model = NULL;
    TFModel *tf_model = NULL;

    model = av_malloc(sizeof(DNNModel));
    if (!model){
        return NULL;
    }

    tf_model = av_mallocz(sizeof(TFModel));
    if (!tf_model){
        av_freep(&model);
        return NULL;
    }
    tf_model->ctx.class = &dnn_tensorflow_class;

    if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){
        if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){
            av_freep(&tf_model);
            av_freep(&model);

            return NULL;
        }
    }

    model->model = (void *)tf_model;
    model->set_input = &set_input_tf;
    model->get_input = &get_input_tf;
    model->options = options;
    model->userdata = userdata;

    return model;
}

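/* Run the session on the previously bound input and fetch nb_output outputs
 * by operation name; NHWC dimensions, data pointer and type of each result
 * are reported back through the outputs array. The output tensors remain
 * owned by the model and are released on the next call or in
 * ff_dnn_free_model_tf(). */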
DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
{
    TF_Output *tf_outputs;
    TFModel *tf_model = (TFModel *)model->model;
    TFContext *ctx = &tf_model->ctx;

    tf_outputs = av_malloc_array(nb_output, sizeof(*tf_outputs));
    if (tf_outputs == NULL) {
        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for *tf_outputs\n");
        return DNN_ERROR;
    }

    if (tf_model->output_tensors) {
        for (uint32_t i = 0; i < tf_model->nb_output; ++i) {
            if (tf_model->output_tensors[i]) {
                TF_DeleteTensor(tf_model->output_tensors[i]);
                tf_model->output_tensors[i] = NULL;
            }
        }
    }
    av_freep(&tf_model->output_tensors);
    tf_model->nb_output = nb_output;
    tf_model->output_tensors = av_mallocz_array(nb_output, sizeof(*tf_model->output_tensors));
    if (!tf_model->output_tensors) {
        av_freep(&tf_outputs);
        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for output tensor\n");
        return DNN_ERROR;
    }

    for (uint32_t i = 0; i < nb_output; ++i) {
        tf_outputs[i].oper = TF_GraphOperationByName(tf_model->graph, output_names[i]);
        if (!tf_outputs[i].oper) {
            av_freep(&tf_outputs);
            av_log(ctx, AV_LOG_ERROR, "Could not find output \"%s\" in model\n", output_names[i]);
            return DNN_ERROR;
        }
        tf_outputs[i].index = 0;
    }

    TF_SessionRun(tf_model->session, NULL,
                  &tf_model->input, &tf_model->input_tensor, 1,
                  tf_outputs, tf_model->output_tensors, nb_output,
                  NULL, 0, NULL, tf_model->status);
    if (TF_GetCode(tf_model->status) != TF_OK) {
        av_freep(&tf_outputs);
        av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
        return DNN_ERROR;
    }

    for (uint32_t i = 0; i < nb_output; ++i) {
        outputs[i].height = TF_Dim(tf_model->output_tensors[i], 1);
        outputs[i].width = TF_Dim(tf_model->output_tensors[i], 2);
        outputs[i].channels = TF_Dim(tf_model->output_tensors[i], 3);
        outputs[i].data = TF_TensorData(tf_model->output_tensors[i]);
        outputs[i].dt = TF_TensorType(tf_model->output_tensors[i]);
    }

    av_freep(&tf_outputs);
    return DNN_SUCCESS;
}

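/* Release all TensorFlow objects owned by the model (graph, session, status,
 * input and output tensors) and free the DNNModel itself. */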
void ff_dnn_free_model_tf(DNNModel **model)
{
    TFModel *tf_model;

    if (*model){
        tf_model = (TFModel *)(*model)->model;
        if (tf_model->graph){
            TF_DeleteGraph(tf_model->graph);
        }
        if (tf_model->session){
            TF_CloseSession(tf_model->session, tf_model->status);
            TF_DeleteSession(tf_model->session, tf_model->status);
        }
        if (tf_model->status){
            TF_DeleteStatus(tf_model->status);
        }
        if (tf_model->input_tensor){
            TF_DeleteTensor(tf_model->input_tensor);
        }
        if (tf_model->output_tensors) {
            for (uint32_t i = 0; i < tf_model->nb_output; ++i) {
                if (tf_model->output_tensors[i]) {
                    TF_DeleteTensor(tf_model->output_tensors[i]);
                    tf_model->output_tensors[i] = NULL;
                }
            }
        }
        av_freep(&tf_model->output_tensors);
        av_freep(&tf_model);
        av_freep(model);
    }
}