4 * This file is part of FFmpeg.
6 * FFmpeg is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
11 * FFmpeg is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with FFmpeg; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
23 * DNN native backend implementation.
#include <math.h>

#include "dnn_backend_native.h"
#include "libavutil/avassert.h"
#include "libavutil/mem.h"
#include "dnn_backend_native_layer_mathbinary.h"
/** Signature shared by every element-wise binary kernel in this file. */
typedef float (*FunType)(float lhs, float rhs);

/** Element-wise subtraction: lhs - rhs (order matters). */
static float sub(float lhs, float rhs)
{
    return lhs - rhs;
}

/** Element-wise addition: lhs + rhs. */
static float add(float lhs, float rhs)
{
    return lhs + rhs;
}

/** Element-wise multiplication: lhs * rhs. */
static float mul(float lhs, float rhs)
{
    return lhs * rhs;
}

/** Element-wise floating-point division: lhs / rhs (order matters). */
static float realdiv(float lhs, float rhs)
{
    return lhs / rhs;
}
/**
 * Element-wise minimum of two values.
 *
 * Written as the same comparison FFMIN() expands to, so when either operand
 * is NaN the first operand is returned, exactly as FFMIN(a, b) would.
 */
static float minimum(float a, float b)
{
    return a > b ? b : a;
}
52 static float floormod(float src0, float src1)
54 return (float)((int)(src0) % (int)(src1));
57 static void math_binary_commutative(FunType pfun, const DnnLayerMathBinaryParams *params, const DnnOperand *input, DnnOperand *output, DnnOperand *operands, const int32_t *input_operand_indexes)
62 dims_count = ff_calculate_operand_dims_count(output);
65 if (params->input0_broadcast || params->input1_broadcast) {
66 for (int i = 0; i < dims_count; ++i) {
67 dst[i] = pfun(params->v, src[i]);
70 const DnnOperand *input1 = &operands[input_operand_indexes[1]];
71 const float *src1 = input1->data;
72 for (int i = 0; i < dims_count; ++i) {
73 dst[i] = pfun(src[i], src1[i]);
77 static void math_binary_not_commutative(FunType pfun, const DnnLayerMathBinaryParams *params, const DnnOperand *input, DnnOperand *output, DnnOperand *operands, const int32_t *input_operand_indexes)
82 dims_count = ff_calculate_operand_dims_count(output);
85 if (params->input0_broadcast) {
86 for (int i = 0; i < dims_count; ++i) {
87 dst[i] = pfun(params->v, src[i]);
89 } else if (params->input1_broadcast) {
90 for (int i = 0; i < dims_count; ++i) {
91 dst[i] = pfun(src[i], params->v);
94 const DnnOperand *input1 = &operands[input_operand_indexes[1]];
95 const float *src1 = input1->data;
96 for (int i = 0; i < dims_count; ++i) {
97 dst[i] = pfun(src[i], src1[i]);
101 int ff_dnn_load_layer_math_binary(Layer *layer, AVIOContext *model_file_context, int file_size, int operands_num)
103 DnnLayerMathBinaryParams *params;
106 params = av_malloc(sizeof(*params));
110 params->bin_op = (int32_t)avio_rl32(model_file_context);
113 params->input0_broadcast = (int32_t)avio_rl32(model_file_context);
115 if (params->input0_broadcast) {
116 params->v = av_int2float(avio_rl32(model_file_context));
118 layer->input_operand_indexes[input_index] = (int32_t)avio_rl32(model_file_context);
119 if (layer->input_operand_indexes[input_index] >= operands_num) {
126 params->input1_broadcast = (int32_t)avio_rl32(model_file_context);
128 if (params->input1_broadcast) {
129 params->v = av_int2float(avio_rl32(model_file_context));
131 layer->input_operand_indexes[input_index] = (int32_t)avio_rl32(model_file_context);
132 if (layer->input_operand_indexes[input_index] >= operands_num) {
139 layer->output_operand_index = (int32_t)avio_rl32(model_file_context);
141 layer->params = params;
143 if (layer->output_operand_index >= operands_num) {
150 int ff_dnn_execute_layer_math_binary(DnnOperand *operands, const int32_t *input_operand_indexes,
151 int32_t output_operand_index, const void *parameters, NativeContext *ctx)
153 const DnnOperand *input = &operands[input_operand_indexes[0]];
154 DnnOperand *output = &operands[output_operand_index];
155 const DnnLayerMathBinaryParams *params = (const DnnLayerMathBinaryParams *)parameters;
157 for (int i = 0; i < 4; ++i)
158 output->dims[i] = input->dims[i];
160 output->data_type = input->data_type;
161 output->length = ff_calculate_operand_data_length(output);
162 if (output->length <= 0) {
163 av_log(ctx, AV_LOG_ERROR, "The output data length overflow\n");
166 output->data = av_realloc(output->data, output->length);
168 av_log(ctx, AV_LOG_ERROR, "Failed to reallocate memory for output\n");
172 switch (params->bin_op) {
174 math_binary_not_commutative(sub, params, input, output, operands, input_operand_indexes);
177 math_binary_commutative(add, params, input, output, operands, input_operand_indexes);
180 math_binary_commutative(mul, params, input, output, operands, input_operand_indexes);
183 math_binary_not_commutative(realdiv, params, input, output, operands, input_operand_indexes);
186 math_binary_commutative(minimum, params, input, output, operands, input_operand_indexes);
189 math_binary_not_commutative(floormod, params, input, output, operands, input_operand_indexes);
192 av_log(ctx, AV_LOG_ERROR, "Unmatch math binary operator\n");