/*
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */
/**
 * @file
 * DNN native backend implementation: element-wise math binary layer
 * (sub, add, mul, realdiv, minimum, floormod).
 */
#include <math.h>

#include "dnn_backend_native.h"
#include "libavutil/avassert.h"
#include "dnn_backend_native_layer_mathbinary.h"
/** Scalar operator applied element-wise by the math binary helpers: returns src0 <op> src1. */
typedef float (*FunType)(float src0, float src1);

/** Difference src0 - src1. */
static float sub(float src0, float src1)
{
    const float result = src0 - src1;
    return result;
}

/** Sum src0 + src1. */
static float add(float src0, float src1)
{
    const float result = src0 + src1;
    return result;
}

/** Product src0 * src1. */
static float mul(float src0, float src1)
{
    const float result = src0 * src1;
    return result;
}

/** Quotient src0 / src1 (plain float division, no zero-divisor check). */
static float realdiv(float src0, float src1)
{
    const float result = src0 / src1;
    return result;
}
/**
 * Smaller of the two operands. When src0 > src1 is false (equal values or a
 * NaN operand) src0 is returned, exactly like FFMIN(src0, src1).
 */
static float minimum(float src0, float src1)
{
    return src1 < src0 ? src1 : src0;
}
52 static float floormod(float src0, float src1)
54 return (float)((int)(src0) % (int)(src1));
57 static void math_binary_commutative(FunType pfun, const DnnLayerMathBinaryParams *params, const DnnOperand *input, DnnOperand *output, DnnOperand *operands, const int32_t *input_operand_indexes)
62 dims_count = ff_calculate_operand_dims_count(output);
65 if (params->input0_broadcast || params->input1_broadcast) {
66 for (int i = 0; i < dims_count; ++i) {
67 dst[i] = pfun(params->v, src[i]);
70 const DnnOperand *input1 = &operands[input_operand_indexes[1]];
71 const float *src1 = input1->data;
72 for (int i = 0; i < dims_count; ++i) {
73 dst[i] = pfun(src[i], src1[i]);
77 static void math_binary_not_commutative(FunType pfun, const DnnLayerMathBinaryParams *params, const DnnOperand *input, DnnOperand *output, DnnOperand *operands, const int32_t *input_operand_indexes)
82 dims_count = ff_calculate_operand_dims_count(output);
85 if (params->input0_broadcast) {
86 for (int i = 0; i < dims_count; ++i) {
87 dst[i] = pfun(params->v, src[i]);
89 } else if (params->input1_broadcast) {
90 for (int i = 0; i < dims_count; ++i) {
91 dst[i] = pfun(src[i], params->v);
94 const DnnOperand *input1 = &operands[input_operand_indexes[1]];
95 const float *src1 = input1->data;
96 for (int i = 0; i < dims_count; ++i) {
97 dst[i] = pfun(src[i], src1[i]);
101 int ff_dnn_load_layer_math_binary(Layer *layer, AVIOContext *model_file_context, int file_size, int operands_num)
103 DnnLayerMathBinaryParams params = { 0 };
107 params.bin_op = (int32_t)avio_rl32(model_file_context);
110 params.input0_broadcast = (int32_t)avio_rl32(model_file_context);
112 if (params.input0_broadcast) {
113 params.v = av_int2float(avio_rl32(model_file_context));
115 layer->input_operand_indexes[input_index] = (int32_t)avio_rl32(model_file_context);
116 if (layer->input_operand_indexes[input_index] >= operands_num) {
123 params.input1_broadcast = (int32_t)avio_rl32(model_file_context);
125 if (params.input1_broadcast) {
126 params.v = av_int2float(avio_rl32(model_file_context));
128 layer->input_operand_indexes[input_index] = (int32_t)avio_rl32(model_file_context);
129 if (layer->input_operand_indexes[input_index] >= operands_num) {
136 layer->output_operand_index = (int32_t)avio_rl32(model_file_context);
139 if (layer->output_operand_index >= operands_num) {
142 layer->params = av_memdup(¶ms, sizeof(params));
149 int ff_dnn_execute_layer_math_binary(DnnOperand *operands, const int32_t *input_operand_indexes,
150 int32_t output_operand_index, const void *parameters, NativeContext *ctx)
152 const DnnOperand *input = &operands[input_operand_indexes[0]];
153 DnnOperand *output = &operands[output_operand_index];
154 const DnnLayerMathBinaryParams *params = parameters;
156 for (int i = 0; i < 4; ++i)
157 output->dims[i] = input->dims[i];
159 output->data_type = input->data_type;
160 output->length = ff_calculate_operand_data_length(output);
161 if (output->length <= 0) {
162 av_log(ctx, AV_LOG_ERROR, "The output data length overflow\n");
165 output->data = av_realloc(output->data, output->length);
167 av_log(ctx, AV_LOG_ERROR, "Failed to reallocate memory for output\n");
171 switch (params->bin_op) {
173 math_binary_not_commutative(sub, params, input, output, operands, input_operand_indexes);
176 math_binary_commutative(add, params, input, output, operands, input_operand_indexes);
179 math_binary_commutative(mul, params, input, output, operands, input_operand_indexes);
182 math_binary_not_commutative(realdiv, params, input, output, operands, input_operand_indexes);
185 math_binary_commutative(minimum, params, input, output, operands, input_operand_indexes);
188 math_binary_not_commutative(floormod, params, input, output, operands, input_operand_indexes);
191 av_log(ctx, AV_LOG_ERROR, "Unmatch math binary operator\n");