]> git.sesse.net Git - ffmpeg/blob - libavfilter/dnn/dnn_backend_tf.c
avfilter/dnn/dnn_backend_tf: simplify the code with ff_hex_to_data
[ffmpeg] / libavfilter / dnn / dnn_backend_tf.c
1 /*
2  * Copyright (c) 2018 Sergey Lavrushkin
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21 /**
22  * @file
23  * DNN tensorflow backend implementation.
24  */
25
26 #include "dnn_backend_tf.h"
27 #include "dnn_backend_native.h"
28 #include "dnn_backend_native_layer_conv2d.h"
29 #include "dnn_backend_native_layer_depth2space.h"
30 #include "libavformat/avio.h"
31 #include "libavformat/internal.h"
32 #include "libavutil/avassert.h"
33 #include "../internal.h"
34 #include "dnn_backend_native_layer_pad.h"
35 #include "dnn_backend_native_layer_maximum.h"
36 #include "dnn_io_proc.h"
37
38 #include <tensorflow/c/c_api.h>
39
40 typedef struct TFOptions{
41     char *sess_config;
42 } TFOptions;
43
44 typedef struct TFContext {
45     const AVClass *class;
46     TFOptions options;
47 } TFContext;
48
49 typedef struct TFModel{
50     TFContext ctx;
51     DNNModel *model;
52     TF_Graph *graph;
53     TF_Session *session;
54     TF_Status *status;
55 } TFModel;
56
/* Option table helpers: every option is stored inside TFContext. */
#define OFFSET(x) offsetof(TFContext, x)
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM
/* Backend options, parsed from the "key=value&key=value" string handed to
 * ff_dnn_load_model_tf(). */
static const AVOption dnn_tensorflow_options[] = {
    { "sess_config", "config for SessionOptions", OFFSET(options.sess_config), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
    { NULL }
};

/* Defines dnn_tensorflow_class, the AVClass assigned to TFContext.class. */
AVFILTER_DEFINE_CLASS(dnn_tensorflow);

/*
 * Forward declaration: get_output_tf() calls this before its definition.
 * With do_ioproc == 0 no pixel data is converted and only the output
 * width/height are written to out_frame.
 */
static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame,
                                      int do_ioproc);
69
/*
 * Deallocator installed on the TF_Buffer built by read_graph(): releases the
 * av_malloc()ed graph bytes. The length parameter only exists to satisfy the
 * callback signature.
 */
static void free_buffer(void *data, size_t length)
{
    (void)length;
    av_freep(&data);
}
74
75 static TF_Buffer *read_graph(const char *model_filename)
76 {
77     TF_Buffer *graph_buf;
78     unsigned char *graph_data = NULL;
79     AVIOContext *model_file_context;
80     long size, bytes_read;
81
82     if (avio_open(&model_file_context, model_filename, AVIO_FLAG_READ) < 0){
83         return NULL;
84     }
85
86     size = avio_size(model_file_context);
87
88     graph_data = av_malloc(size);
89     if (!graph_data){
90         avio_closep(&model_file_context);
91         return NULL;
92     }
93     bytes_read = avio_read(model_file_context, graph_data, size);
94     avio_closep(&model_file_context);
95     if (bytes_read != size){
96         av_freep(&graph_data);
97         return NULL;
98     }
99
100     graph_buf = TF_NewBuffer();
101     graph_buf->data = graph_data;
102     graph_buf->length = size;
103     graph_buf->data_deallocator = free_buffer;
104
105     return graph_buf;
106 }
107
108 static TF_Tensor *allocate_input_tensor(const DNNData *input)
109 {
110     TF_DataType dt;
111     size_t size;
112     int64_t input_dims[] = {1, input->height, input->width, input->channels};
113     switch (input->dt) {
114     case DNN_FLOAT:
115         dt = TF_FLOAT;
116         size = sizeof(float);
117         break;
118     case DNN_UINT8:
119         dt = TF_UINT8;
120         size = 1;
121         break;
122     default:
123         av_assert0(!"should not reach here");
124     }
125
126     return TF_AllocateTensor(dt, input_dims, 4,
127                              input_dims[1] * input_dims[2] * input_dims[3] * size);
128 }
129
130 static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input_name)
131 {
132     TFModel *tf_model = model;
133     TFContext *ctx = &tf_model->ctx;
134     TF_Status *status;
135     int64_t dims[4];
136
137     TF_Output tf_output;
138     tf_output.oper = TF_GraphOperationByName(tf_model->graph, input_name);
139     if (!tf_output.oper) {
140         av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
141         return DNN_ERROR;
142     }
143
144     tf_output.index = 0;
145     input->dt = TF_OperationOutputType(tf_output);
146
147     status = TF_NewStatus();
148     TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status);
149     if (TF_GetCode(status) != TF_OK){
150         TF_DeleteStatus(status);
151         av_log(ctx, AV_LOG_ERROR, "Failed to get input tensor shape: number of dimension incorrect\n");
152         return DNN_ERROR;
153     }
154     TF_DeleteStatus(status);
155
156     // currently only NHWC is supported
157     av_assert0(dims[0] == 1);
158     input->height = dims[1];
159     input->width = dims[2];
160     input->channels = dims[3];
161
162     return DNN_SUCCESS;
163 }
164
165 static DNNReturnType get_output_tf(void *model, const char *input_name, int input_width, int input_height,
166                                    const char *output_name, int *output_width, int *output_height)
167 {
168     DNNReturnType ret;
169     TFModel *tf_model = model;
170     TFContext *ctx = &tf_model->ctx;
171     AVFrame *in_frame = av_frame_alloc();
172     AVFrame *out_frame = NULL;
173
174     if (!in_frame) {
175         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input frame\n");
176         return DNN_ERROR;
177     }
178
179     out_frame = av_frame_alloc();
180     if (!out_frame) {
181         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for output frame\n");
182         av_frame_free(&in_frame);
183         return DNN_ERROR;
184     }
185
186     in_frame->width = input_width;
187     in_frame->height = input_height;
188
189     ret = execute_model_tf(tf_model->model, input_name, in_frame, &output_name, 1, out_frame, 0);
190     *output_width = out_frame->width;
191     *output_height = out_frame->height;
192
193     av_frame_free(&out_frame);
194     av_frame_free(&in_frame);
195     return ret;
196 }
197
198 static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
199 {
200     TFContext *ctx = &tf_model->ctx;
201     TF_Buffer *graph_def;
202     TF_ImportGraphDefOptions *graph_opts;
203     TF_SessionOptions *sess_opts;
204     const TF_Operation *init_op;
205     uint8_t *sess_config = NULL;
206     int sess_config_length = 0;
207
208     // prepare the sess config data
209     if (tf_model->ctx.options.sess_config != NULL) {
210         const char *config;
211         /*
212         tf_model->ctx.options.sess_config is hex to present the serialized proto
213         required by TF_SetConfig below, so we need to first generate the serialized
214         proto in a python script, tools/python/tf_sess_config.py is a script example
215         to generate the configs of sess_config.
216         */
217         if (strncmp(tf_model->ctx.options.sess_config, "0x", 2) != 0) {
218             av_log(ctx, AV_LOG_ERROR, "sess_config should start with '0x'\n");
219             return DNN_ERROR;
220         }
221         config = tf_model->ctx.options.sess_config + 2;
222         sess_config_length = ff_hex_to_data(NULL, config);
223
224         sess_config = av_mallocz(sess_config_length + AV_INPUT_BUFFER_PADDING_SIZE);
225         if (!sess_config) {
226             av_log(ctx, AV_LOG_ERROR, "failed to allocate memory\n");
227             return DNN_ERROR;
228         }
229         ff_hex_to_data(sess_config, config);
230     }
231
232     graph_def = read_graph(model_filename);
233     if (!graph_def){
234         av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename);
235         av_freep(&sess_config);
236         return DNN_ERROR;
237     }
238     tf_model->graph = TF_NewGraph();
239     tf_model->status = TF_NewStatus();
240     graph_opts = TF_NewImportGraphDefOptions();
241     TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
242     TF_DeleteImportGraphDefOptions(graph_opts);
243     TF_DeleteBuffer(graph_def);
244     if (TF_GetCode(tf_model->status) != TF_OK){
245         TF_DeleteGraph(tf_model->graph);
246         TF_DeleteStatus(tf_model->status);
247         av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n");
248         av_freep(&sess_config);
249         return DNN_ERROR;
250     }
251
252     init_op = TF_GraphOperationByName(tf_model->graph, "init");
253     sess_opts = TF_NewSessionOptions();
254
255     if (sess_config) {
256         TF_SetConfig(sess_opts, sess_config, sess_config_length,tf_model->status);
257         av_freep(&sess_config);
258         if (TF_GetCode(tf_model->status) != TF_OK) {
259             TF_DeleteGraph(tf_model->graph);
260             TF_DeleteStatus(tf_model->status);
261             TF_DeleteSessionOptions(sess_opts);
262             av_log(ctx, AV_LOG_ERROR, "Failed to set config for sess options with %s\n",
263                                       tf_model->ctx.options.sess_config);
264             return DNN_ERROR;
265         }
266     }
267
268     tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status);
269     TF_DeleteSessionOptions(sess_opts);
270     if (TF_GetCode(tf_model->status) != TF_OK)
271     {
272         TF_DeleteGraph(tf_model->graph);
273         TF_DeleteStatus(tf_model->status);
274         av_log(ctx, AV_LOG_ERROR, "Failed to create new session with model graph\n");
275         return DNN_ERROR;
276     }
277
278     // Run initialization operation with name "init" if it is present in graph
279     if (init_op){
280         TF_SessionRun(tf_model->session, NULL,
281                       NULL, NULL, 0,
282                       NULL, NULL, 0,
283                       &init_op, 1, NULL, tf_model->status);
284         if (TF_GetCode(tf_model->status) != TF_OK)
285         {
286             TF_DeleteSession(tf_model->session, tf_model->status);
287             TF_DeleteGraph(tf_model->graph);
288             TF_DeleteStatus(tf_model->status);
289             av_log(ctx, AV_LOG_ERROR, "Failed to run session when initializing\n");
290             return DNN_ERROR;
291         }
292     }
293
294     return DNN_SUCCESS;
295 }
296
/* Size of the stack buffers used to build per-layer TF operation names. */
#define NAME_BUFFER_SIZE 256
298
299 static DNNReturnType add_conv_layer(TFModel *tf_model, TF_Operation *transpose_op, TF_Operation **cur_op,
300                                     ConvolutionalParams* params, const int layer)
301 {
302     TFContext *ctx = &tf_model->ctx;
303     TF_Operation *op;
304     TF_OperationDescription *op_desc;
305     TF_Output input;
306     int64_t strides[] = {1, 1, 1, 1};
307     TF_Tensor *kernel_tensor = NULL, *biases_tensor = NULL;
308     int64_t dims[4];
309     int dims_len;
310     char name_buffer[NAME_BUFFER_SIZE];
311     int32_t size;
312
313     size = params->input_num * params->output_num * params->kernel_size * params->kernel_size;
314     input.index = 0;
315
316     snprintf(name_buffer, NAME_BUFFER_SIZE, "conv_kernel%d", layer);
317     op_desc = TF_NewOperation(tf_model->graph, "Const", name_buffer);
318     TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
319     dims[0] = params->output_num;
320     dims[1] = params->kernel_size;
321     dims[2] = params->kernel_size;
322     dims[3] = params->input_num;
323     dims_len = 4;
324     kernel_tensor = TF_AllocateTensor(TF_FLOAT, dims, dims_len, size * sizeof(float));
325     memcpy(TF_TensorData(kernel_tensor), params->kernel, size * sizeof(float));
326     TF_SetAttrTensor(op_desc, "value", kernel_tensor, tf_model->status);
327     if (TF_GetCode(tf_model->status) != TF_OK){
328         goto err;
329     }
330     op = TF_FinishOperation(op_desc, tf_model->status);
331     if (TF_GetCode(tf_model->status) != TF_OK){
332         goto err;
333     }
334
335     snprintf(name_buffer, NAME_BUFFER_SIZE, "transpose%d", layer);
336     op_desc = TF_NewOperation(tf_model->graph, "Transpose", name_buffer);
337     input.oper = op;
338     TF_AddInput(op_desc, input);
339     input.oper = transpose_op;
340     TF_AddInput(op_desc, input);
341     TF_SetAttrType(op_desc, "T", TF_FLOAT);
342     TF_SetAttrType(op_desc, "Tperm", TF_INT32);
343     op = TF_FinishOperation(op_desc, tf_model->status);
344     if (TF_GetCode(tf_model->status) != TF_OK){
345         goto err;
346     }
347
348     snprintf(name_buffer, NAME_BUFFER_SIZE, "conv2d%d", layer);
349     op_desc = TF_NewOperation(tf_model->graph, "Conv2D", name_buffer);
350     input.oper = *cur_op;
351     TF_AddInput(op_desc, input);
352     input.oper = op;
353     TF_AddInput(op_desc, input);
354     TF_SetAttrType(op_desc, "T", TF_FLOAT);
355     TF_SetAttrIntList(op_desc, "strides", strides, 4);
356     TF_SetAttrString(op_desc, "padding", "VALID", 5);
357     *cur_op = TF_FinishOperation(op_desc, tf_model->status);
358     if (TF_GetCode(tf_model->status) != TF_OK){
359         goto err;
360     }
361
362     snprintf(name_buffer, NAME_BUFFER_SIZE, "conv_biases%d", layer);
363     op_desc = TF_NewOperation(tf_model->graph, "Const", name_buffer);
364     TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
365     dims[0] = params->output_num;
366     dims_len = 1;
367     biases_tensor = TF_AllocateTensor(TF_FLOAT, dims, dims_len, params->output_num * sizeof(float));
368     memcpy(TF_TensorData(biases_tensor), params->biases, params->output_num * sizeof(float));
369     TF_SetAttrTensor(op_desc, "value", biases_tensor, tf_model->status);
370     if (TF_GetCode(tf_model->status) != TF_OK){
371         goto err;
372     }
373     op = TF_FinishOperation(op_desc, tf_model->status);
374     if (TF_GetCode(tf_model->status) != TF_OK){
375         goto err;
376     }
377
378     snprintf(name_buffer, NAME_BUFFER_SIZE, "bias_add%d", layer);
379     op_desc = TF_NewOperation(tf_model->graph, "BiasAdd", name_buffer);
380     input.oper = *cur_op;
381     TF_AddInput(op_desc, input);
382     input.oper = op;
383     TF_AddInput(op_desc, input);
384     TF_SetAttrType(op_desc, "T", TF_FLOAT);
385     *cur_op = TF_FinishOperation(op_desc, tf_model->status);
386     if (TF_GetCode(tf_model->status) != TF_OK){
387         goto err;
388     }
389
390     snprintf(name_buffer, NAME_BUFFER_SIZE, "activation%d", layer);
391     switch (params->activation){
392     case RELU:
393         op_desc = TF_NewOperation(tf_model->graph, "Relu", name_buffer);
394         break;
395     case TANH:
396         op_desc = TF_NewOperation(tf_model->graph, "Tanh", name_buffer);
397         break;
398     case SIGMOID:
399         op_desc = TF_NewOperation(tf_model->graph, "Sigmoid", name_buffer);
400         break;
401     default:
402         avpriv_report_missing_feature(ctx, "convolutional activation function %d", params->activation);
403         return DNN_ERROR;
404     }
405     input.oper = *cur_op;
406     TF_AddInput(op_desc, input);
407     TF_SetAttrType(op_desc, "T", TF_FLOAT);
408     *cur_op = TF_FinishOperation(op_desc, tf_model->status);
409     if (TF_GetCode(tf_model->status) != TF_OK){
410         goto err;
411     }
412
413     return DNN_SUCCESS;
414 err:
415     TF_DeleteTensor(kernel_tensor);
416     TF_DeleteTensor(biases_tensor);
417     av_log(ctx, AV_LOG_ERROR, "Failed to add conv layer %d\n", layer);
418     return DNN_ERROR;
419 }
420
421 static DNNReturnType add_depth_to_space_layer(TFModel *tf_model, TF_Operation **cur_op,
422                                               DepthToSpaceParams *params, const int layer)
423 {
424     TFContext *ctx = &tf_model->ctx;
425     TF_OperationDescription *op_desc;
426     TF_Output input;
427     char name_buffer[NAME_BUFFER_SIZE];
428
429     snprintf(name_buffer, NAME_BUFFER_SIZE, "depth_to_space%d", layer);
430     op_desc = TF_NewOperation(tf_model->graph, "DepthToSpace", name_buffer);
431     input.oper = *cur_op;
432     input.index = 0;
433     TF_AddInput(op_desc, input);
434     TF_SetAttrType(op_desc, "T", TF_FLOAT);
435     TF_SetAttrInt(op_desc, "block_size", params->block_size);
436     *cur_op = TF_FinishOperation(op_desc, tf_model->status);
437     if (TF_GetCode(tf_model->status) != TF_OK){
438         av_log(ctx, AV_LOG_ERROR, "Failed to add depth_to_space to layer %d\n", layer);
439         return DNN_ERROR;
440     }
441
442     return DNN_SUCCESS;
443 }
444
445 static DNNReturnType add_pad_layer(TFModel *tf_model, TF_Operation **cur_op,
446                                               LayerPadParams *params, const int layer)
447 {
448     TFContext *ctx = &tf_model->ctx;
449     TF_Operation *op;
450     TF_Tensor *tensor;
451     TF_OperationDescription *op_desc;
452     TF_Output input;
453     int32_t *pads;
454     int64_t pads_shape[] = {4, 2};
455
456     char name_buffer[NAME_BUFFER_SIZE];
457     snprintf(name_buffer, NAME_BUFFER_SIZE, "pad%d", layer);
458
459     op_desc = TF_NewOperation(tf_model->graph, "Const", name_buffer);
460     TF_SetAttrType(op_desc, "dtype", TF_INT32);
461     tensor = TF_AllocateTensor(TF_INT32, pads_shape, 2, 4 * 2 * sizeof(int32_t));
462     pads = (int32_t *)TF_TensorData(tensor);
463     pads[0] = params->paddings[0][0];
464     pads[1] = params->paddings[0][1];
465     pads[2] = params->paddings[1][0];
466     pads[3] = params->paddings[1][1];
467     pads[4] = params->paddings[2][0];
468     pads[5] = params->paddings[2][1];
469     pads[6] = params->paddings[3][0];
470     pads[7] = params->paddings[3][1];
471     TF_SetAttrTensor(op_desc, "value", tensor, tf_model->status);
472     if (TF_GetCode(tf_model->status) != TF_OK){
473         TF_DeleteTensor(tensor);
474         av_log(ctx, AV_LOG_ERROR, "Failed to set value for pad of layer %d\n", layer);
475         return DNN_ERROR;
476     }
477     op = TF_FinishOperation(op_desc, tf_model->status);
478     if (TF_GetCode(tf_model->status) != TF_OK){
479         TF_DeleteTensor(tensor);
480         av_log(ctx, AV_LOG_ERROR, "Failed to add pad to layer %d\n", layer);
481         return DNN_ERROR;
482     }
483
484     op_desc = TF_NewOperation(tf_model->graph, "MirrorPad", "mirror_pad");
485     input.oper = *cur_op;
486     input.index = 0;
487     TF_AddInput(op_desc, input);
488     input.oper = op;
489     TF_AddInput(op_desc, input);
490     TF_SetAttrType(op_desc, "T", TF_FLOAT);
491     TF_SetAttrType(op_desc, "Tpaddings", TF_INT32);
492     TF_SetAttrString(op_desc, "mode", "SYMMETRIC", 9);
493     *cur_op = TF_FinishOperation(op_desc, tf_model->status);
494     if (TF_GetCode(tf_model->status) != TF_OK){
495         TF_DeleteTensor(tensor);
496         av_log(ctx, AV_LOG_ERROR, "Failed to add mirror_pad to layer %d\n", layer);
497         return DNN_ERROR;
498     }
499
500     return DNN_SUCCESS;
501 }
502
503 static DNNReturnType add_maximum_layer(TFModel *tf_model, TF_Operation **cur_op,
504                                        DnnLayerMaximumParams *params, const int layer)
505 {
506     TFContext *ctx = &tf_model->ctx;
507     TF_Operation *op;
508     TF_Tensor *tensor;
509     TF_OperationDescription *op_desc;
510     TF_Output input;
511     float *y;
512
513     char name_buffer[NAME_BUFFER_SIZE];
514     snprintf(name_buffer, NAME_BUFFER_SIZE, "maximum/y%d", layer);
515
516     op_desc = TF_NewOperation(tf_model->graph, "Const", name_buffer);
517     TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
518     tensor = TF_AllocateTensor(TF_FLOAT, NULL, 0, TF_DataTypeSize(TF_FLOAT));
519     y = (float *)TF_TensorData(tensor);
520     *y = params->val.y;
521     TF_SetAttrTensor(op_desc, "value", tensor, tf_model->status);
522     if (TF_GetCode(tf_model->status) != TF_OK){
523         TF_DeleteTensor(tensor);
524         av_log(ctx, AV_LOG_ERROR, "Failed to set value for maximum/y of layer %d", layer);
525         return DNN_ERROR;
526     }
527     op = TF_FinishOperation(op_desc, tf_model->status);
528     if (TF_GetCode(tf_model->status) != TF_OK){
529         TF_DeleteTensor(tensor);
530         av_log(ctx, AV_LOG_ERROR, "Failed to add maximum/y to layer %d\n", layer);
531         return DNN_ERROR;
532     }
533
534     snprintf(name_buffer, NAME_BUFFER_SIZE, "maximum%d", layer);
535     op_desc = TF_NewOperation(tf_model->graph, "Maximum", name_buffer);
536     input.oper = *cur_op;
537     input.index = 0;
538     TF_AddInput(op_desc, input);
539     input.oper = op;
540     TF_AddInput(op_desc, input);
541     TF_SetAttrType(op_desc, "T", TF_FLOAT);
542     *cur_op = TF_FinishOperation(op_desc, tf_model->status);
543     if (TF_GetCode(tf_model->status) != TF_OK){
544         TF_DeleteTensor(tensor);
545         av_log(ctx, AV_LOG_ERROR, "Failed to add maximum to layer %d\n", layer);
546         return DNN_ERROR;
547     }
548
549     return DNN_SUCCESS;
550 }
551
552 static DNNReturnType load_native_model(TFModel *tf_model, const char *model_filename)
553 {
554     TFContext *ctx = &tf_model->ctx;
555     int32_t layer;
556     TF_OperationDescription *op_desc;
557     TF_Operation *op;
558     TF_Operation *transpose_op;
559     TF_Tensor *tensor = NULL;
560     TF_Output input;
561     int32_t *transpose_perm;
562     int64_t transpose_perm_shape[] = {4};
563     int64_t input_shape[] = {1, -1, -1, -1};
564     DNNReturnType layer_add_res;
565     DNNModel *model = NULL;
566     NativeModel *native_model;
567
568     model = ff_dnn_load_model_native(model_filename, DFT_PROCESS_FRAME, NULL, NULL);
569     if (!model){
570         av_log(ctx, AV_LOG_ERROR, "Failed to load native model\n");
571         return DNN_ERROR;
572     }
573
574     native_model = model->model;
575     tf_model->graph = TF_NewGraph();
576     tf_model->status = TF_NewStatus();
577
578 #define CLEANUP_ON_ERROR(tf_model) \
579     { \
580         TF_DeleteTensor(tensor); \
581         TF_DeleteGraph(tf_model->graph); \
582         TF_DeleteStatus(tf_model->status); \
583         av_log(ctx, AV_LOG_ERROR, "Failed to set value or add operator to layer\n"); \
584         return DNN_ERROR; \
585     }
586
587     op_desc = TF_NewOperation(tf_model->graph, "Placeholder", "x");
588     TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
589     TF_SetAttrShape(op_desc, "shape", input_shape, 4);
590     op = TF_FinishOperation(op_desc, tf_model->status);
591     if (TF_GetCode(tf_model->status) != TF_OK){
592         CLEANUP_ON_ERROR(tf_model);
593     }
594
595     op_desc = TF_NewOperation(tf_model->graph, "Const", "transpose_perm");
596     TF_SetAttrType(op_desc, "dtype", TF_INT32);
597     tensor = TF_AllocateTensor(TF_INT32, transpose_perm_shape, 1, 4 * sizeof(int32_t));
598     transpose_perm = (int32_t *)TF_TensorData(tensor);
599     transpose_perm[0] = 1;
600     transpose_perm[1] = 2;
601     transpose_perm[2] = 3;
602     transpose_perm[3] = 0;
603     TF_SetAttrTensor(op_desc, "value", tensor, tf_model->status);
604     if (TF_GetCode(tf_model->status) != TF_OK){
605         CLEANUP_ON_ERROR(tf_model);
606     }
607     transpose_op = TF_FinishOperation(op_desc, tf_model->status);
608     if (TF_GetCode(tf_model->status) != TF_OK){
609         CLEANUP_ON_ERROR(tf_model);
610     }
611
612     for (layer = 0; layer < native_model->layers_num; ++layer){
613         switch (native_model->layers[layer].type){
614         case DLT_INPUT:
615             layer_add_res = DNN_SUCCESS;
616             break;
617         case DLT_CONV2D:
618             layer_add_res = add_conv_layer(tf_model, transpose_op, &op,
619                                            (ConvolutionalParams *)native_model->layers[layer].params, layer);
620             break;
621         case DLT_DEPTH_TO_SPACE:
622             layer_add_res = add_depth_to_space_layer(tf_model, &op,
623                                                      (DepthToSpaceParams *)native_model->layers[layer].params, layer);
624             break;
625         case DLT_MIRROR_PAD:
626             layer_add_res = add_pad_layer(tf_model, &op,
627                                           (LayerPadParams *)native_model->layers[layer].params, layer);
628             break;
629         case DLT_MAXIMUM:
630             layer_add_res = add_maximum_layer(tf_model, &op,
631                                           (DnnLayerMaximumParams *)native_model->layers[layer].params, layer);
632             break;
633         default:
634             CLEANUP_ON_ERROR(tf_model);
635         }
636
637         if (layer_add_res != DNN_SUCCESS){
638             CLEANUP_ON_ERROR(tf_model);
639         }
640     }
641
642     op_desc = TF_NewOperation(tf_model->graph, "Identity", "y");
643     input.oper = op;
644     input.index = 0;
645     TF_AddInput(op_desc, input);
646     TF_FinishOperation(op_desc, tf_model->status);
647     if (TF_GetCode(tf_model->status) != TF_OK){
648         CLEANUP_ON_ERROR(tf_model);
649     }
650
651     ff_dnn_free_model_native(&model);
652
653     return DNN_SUCCESS;
654 }
655
656 DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
657 {
658     DNNModel *model = NULL;
659     TFModel *tf_model = NULL;
660
661     model = av_mallocz(sizeof(DNNModel));
662     if (!model){
663         return NULL;
664     }
665
666     tf_model = av_mallocz(sizeof(TFModel));
667     if (!tf_model){
668         av_freep(&model);
669         return NULL;
670     }
671     tf_model->ctx.class = &dnn_tensorflow_class;
672     tf_model->model = model;
673
674     //parse options
675     av_opt_set_defaults(&tf_model->ctx);
676     if (av_opt_set_from_string(&tf_model->ctx, options, NULL, "=", "&") < 0) {
677         av_log(&tf_model->ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n", options);
678         av_freep(&tf_model);
679         av_freep(&model);
680         return NULL;
681     }
682
683     if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){
684         if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){
685             av_freep(&tf_model);
686             av_freep(&model);
687
688             return NULL;
689         }
690     }
691
692     model->model = tf_model;
693     model->get_input = &get_input_tf;
694     model->get_output = &get_output_tf;
695     model->options = options;
696     model->filter_ctx = filter_ctx;
697     model->func_type = func_type;
698
699     return model;
700 }
701
702 static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
703                                       const char **output_names, uint32_t nb_output, AVFrame *out_frame,
704                                       int do_ioproc)
705 {
706     TF_Output *tf_outputs;
707     TFModel *tf_model = model->model;
708     TFContext *ctx = &tf_model->ctx;
709     DNNData input, output;
710     TF_Tensor **output_tensors;
711     TF_Output tf_input;
712     TF_Tensor *input_tensor;
713
714     if (get_input_tf(tf_model, &input, input_name) != DNN_SUCCESS)
715         return DNN_ERROR;
716     input.height = in_frame->height;
717     input.width = in_frame->width;
718
719     tf_input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
720     if (!tf_input.oper){
721         av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
722         return DNN_ERROR;
723     }
724     tf_input.index = 0;
725     input_tensor = allocate_input_tensor(&input);
726     if (!input_tensor){
727         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n");
728         return DNN_ERROR;
729     }
730     input.data = (float *)TF_TensorData(input_tensor);
731
732     if (do_ioproc) {
733         if (tf_model->model->frame_pre_proc != NULL) {
734             tf_model->model->frame_pre_proc(in_frame, &input, tf_model->model->filter_ctx);
735         } else {
736             ff_proc_from_frame_to_dnn(in_frame, &input, tf_model->model->func_type, ctx);
737         }
738     }
739
740     if (nb_output != 1) {
741         // currently, the filter does not need multiple outputs,
742         // so we just pending the support until we really need it.
743         TF_DeleteTensor(input_tensor);
744         avpriv_report_missing_feature(ctx, "multiple outputs");
745         return DNN_ERROR;
746     }
747
748     tf_outputs = av_malloc_array(nb_output, sizeof(*tf_outputs));
749     if (tf_outputs == NULL) {
750         TF_DeleteTensor(input_tensor);
751         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for *tf_outputs\n"); \
752         return DNN_ERROR;
753     }
754
755     output_tensors = av_mallocz_array(nb_output, sizeof(*output_tensors));
756     if (!output_tensors) {
757         TF_DeleteTensor(input_tensor);
758         av_freep(&tf_outputs);
759         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for output tensor\n"); \
760         return DNN_ERROR;
761     }
762
763     for (int i = 0; i < nb_output; ++i) {
764         tf_outputs[i].oper = TF_GraphOperationByName(tf_model->graph, output_names[i]);
765         if (!tf_outputs[i].oper) {
766             TF_DeleteTensor(input_tensor);
767             av_freep(&tf_outputs);
768             av_freep(&output_tensors);
769             av_log(ctx, AV_LOG_ERROR, "Could not find output \"%s\" in model\n", output_names[i]); \
770             return DNN_ERROR;
771         }
772         tf_outputs[i].index = 0;
773     }
774
775     TF_SessionRun(tf_model->session, NULL,
776                   &tf_input, &input_tensor, 1,
777                   tf_outputs, output_tensors, nb_output,
778                   NULL, 0, NULL, tf_model->status);
779     if (TF_GetCode(tf_model->status) != TF_OK) {
780         TF_DeleteTensor(input_tensor);
781         av_freep(&tf_outputs);
782         av_freep(&output_tensors);
783         av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
784         return DNN_ERROR;
785     }
786
787     for (uint32_t i = 0; i < nb_output; ++i) {
788         output.height = TF_Dim(output_tensors[i], 1);
789         output.width = TF_Dim(output_tensors[i], 2);
790         output.channels = TF_Dim(output_tensors[i], 3);
791         output.data = TF_TensorData(output_tensors[i]);
792         output.dt = TF_TensorType(output_tensors[i]);
793
794         if (do_ioproc) {
795             if (tf_model->model->frame_post_proc != NULL) {
796                 tf_model->model->frame_post_proc(out_frame, &output, tf_model->model->filter_ctx);
797             } else {
798                 ff_proc_from_dnn_to_frame(out_frame, &output, ctx);
799             }
800         } else {
801             out_frame->width = output.width;
802             out_frame->height = output.height;
803         }
804     }
805
806     for (uint32_t i = 0; i < nb_output; ++i) {
807         if (output_tensors[i]) {
808             TF_DeleteTensor(output_tensors[i]);
809         }
810     }
811     TF_DeleteTensor(input_tensor);
812     av_freep(&output_tensors);
813     av_freep(&tf_outputs);
814     return DNN_SUCCESS;
815 }
816
817 DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
818                                       const char **output_names, uint32_t nb_output, AVFrame *out_frame)
819 {
820     TFModel *tf_model = model->model;
821     TFContext *ctx = &tf_model->ctx;
822
823     if (!in_frame) {
824         av_log(ctx, AV_LOG_ERROR, "in frame is NULL when execute model.\n");
825         return DNN_ERROR;
826     }
827
828     if (!out_frame) {
829         av_log(ctx, AV_LOG_ERROR, "out frame is NULL when execute model.\n");
830         return DNN_ERROR;
831     }
832
833     return execute_model_tf(model, input_name, in_frame, output_names, nb_output, out_frame, 1);
834 }
835
836 void ff_dnn_free_model_tf(DNNModel **model)
837 {
838     TFModel *tf_model;
839
840     if (*model){
841         tf_model = (*model)->model;
842         if (tf_model->graph){
843             TF_DeleteGraph(tf_model->graph);
844         }
845         if (tf_model->session){
846             TF_CloseSession(tf_model->session, tf_model->status);
847             TF_DeleteSession(tf_model->session, tf_model->status);
848         }
849         if (tf_model->status){
850             TF_DeleteStatus(tf_model->status);
851         }
852         av_freep(&tf_model);
853         av_freep(model);
854     }
855 }