#include "dnn_backend_native_layer_depth2space.h"
#include "libavformat/avio.h"
#include "libavutil/avassert.h"
+#include "../internal.h"
#include "dnn_backend_native_layer_pad.h"
#include "dnn_backend_native_layer_maximum.h"
#include "dnn_io_proc.h"
#include <tensorflow/c/c_api.h>
+typedef struct TFOptions{
+ char *sess_config;
+} TFOptions;
+
typedef struct TFContext {
const AVClass *class;
+ TFOptions options;
} TFContext;
typedef struct TFModel{
TF_Status *status;
} TFModel;
-static const AVClass dnn_tensorflow_class = {
- .class_name = "dnn_tensorflow",
- .item_name = av_default_item_name,
- .option = NULL,
- .version = LIBAVUTIL_VERSION_INT,
- .category = AV_CLASS_CATEGORY_FILTER,
+#define OFFSET(x) offsetof(TFContext, x)
+#define FLAGS AV_OPT_FLAG_FILTERING_PARAM
+static const AVOption dnn_tensorflow_options[] = {
+ { "sess_config", "config for SessionOptions", OFFSET(options.sess_config), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
+ { NULL }
};
+AVFILTER_DEFINE_CLASS(dnn_tensorflow);
+
static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame,
int do_ioproc);
TF_ImportGraphDefOptions *graph_opts;
TF_SessionOptions *sess_opts;
const TF_Operation *init_op;
+ uint8_t *sess_config = NULL;
+ int sess_config_length = 0;
+
+ // prepare the sess config data
+ if (tf_model->ctx.options.sess_config != NULL) {
+ /*
+ tf_model->ctx.options.sess_config is hex to present the serialized proto
+ required by TF_SetConfig below, so we need to first generate the serialized
+ proto in a python script, the following is a script example to generate
+ serialized proto which specifies one GPU, we can change the script to add
+ more options.
+
+ import tensorflow as tf
+ gpu_options = tf.GPUOptions(visible_device_list='0')
+ config = tf.ConfigProto(gpu_options=gpu_options)
+ s = config.SerializeToString()
+ b = ''.join("%02x" % int(ord(b)) for b in s[::-1])
+ print('0x%s' % b)
+
+ the script output looks like: 0xab...cd, and then pass 0xab...cd to sess_config.
+ */
+ char tmp[3];
+ tmp[2] = '\0';
+
+ if (strncmp(tf_model->ctx.options.sess_config, "0x", 2) != 0) {
+ av_log(ctx, AV_LOG_ERROR, "sess_config should start with '0x'\n");
+ return DNN_ERROR;
+ }
+
+ sess_config_length = strlen(tf_model->ctx.options.sess_config);
+ if (sess_config_length % 2 != 0) {
+ av_log(ctx, AV_LOG_ERROR, "the length of sess_config is not even (%s), "
+ "please re-generate the config.\n",
+ tf_model->ctx.options.sess_config);
+ return DNN_ERROR;
+ }
+
+ sess_config_length -= 2; //ignore the first '0x'
+ sess_config_length /= 2; //get the data length in byte
+
+ sess_config = av_malloc(sess_config_length);
+ if (!sess_config) {
+ av_log(ctx, AV_LOG_ERROR, "failed to allocate memory\n");
+ return DNN_ERROR;
+ }
+
+ for (int i = 0; i < sess_config_length; i++) {
+ int index = 2 + (sess_config_length - 1 - i) * 2;
+ tmp[0] = tf_model->ctx.options.sess_config[index];
+ tmp[1] = tf_model->ctx.options.sess_config[index + 1];
+ sess_config[i] = strtol(tmp, NULL, 16);
+ }
+ }
graph_def = read_graph(model_filename);
if (!graph_def){
av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename);
+ av_freep(&sess_config);
return DNN_ERROR;
}
tf_model->graph = TF_NewGraph();
TF_DeleteGraph(tf_model->graph);
TF_DeleteStatus(tf_model->status);
av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n");
+ av_freep(&sess_config);
return DNN_ERROR;
}
init_op = TF_GraphOperationByName(tf_model->graph, "init");
sess_opts = TF_NewSessionOptions();
+
+ if (sess_config) {
+ TF_SetConfig(sess_opts, sess_config, sess_config_length,tf_model->status);
+ av_freep(&sess_config);
+ if (TF_GetCode(tf_model->status) != TF_OK) {
+ av_log(ctx, AV_LOG_ERROR, "Failed to set config for sess options with %s\n",
+ tf_model->ctx.options.sess_config);
+ return DNN_ERROR;
+ }
+ }
+
tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status);
TF_DeleteSessionOptions(sess_opts);
if (TF_GetCode(tf_model->status) != TF_OK)
tf_model->ctx.class = &dnn_tensorflow_class;
tf_model->model = model;
+ //parse options
+ av_opt_set_defaults(&tf_model->ctx);
+ if (av_opt_set_from_string(&tf_model->ctx, options, NULL, "=", "&") < 0) {
+ av_log(&tf_model->ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n", options);
+ av_freep(&tf_model);
+ av_freep(&model);
+ return NULL;
+ }
+
if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){
if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){
av_freep(&tf_model);