#include "dnn_filter_common.h"
-int ff_dnn_init(DnnContext *ctx, AVFilterContext *filter_ctx)
+int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
{
if (!ctx->model_filename) {
av_log(filter_ctx, AV_LOG_ERROR, "model file for network is not specified\n");
return AVERROR(EINVAL);
}
- ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, filter_ctx);
+ ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, func_type, ctx->backend_options, filter_ctx);
if (!ctx->model) {
av_log(filter_ctx, AV_LOG_ERROR, "could not load DNN model\n");
return AVERROR(EINVAL);
return 0;
}
+int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc)
+{
+ ctx->model->frame_pre_proc = pre_proc;
+ ctx->model->frame_post_proc = post_proc;
+ return 0;
+}
+
+int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc)
+{
+ ctx->model->detect_post_proc = post_proc;
+ return 0;
+}
+
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
{
return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname);