]> git.sesse.net Git - ffmpeg/blob - libavfilter/dnn/dnn_backend_native_layer_mathbinary.c
dnn/native: add native support for minimum
[ffmpeg] / libavfilter / dnn / dnn_backend_native_layer_mathbinary.c
1 /*
2  * Copyright (c) 2020
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21 /**
22  * @file
23  * DNN native backend implementation.
24  */
25
26 #include "dnn_backend_native.h"
27 #include "libavutil/avassert.h"
28 #include "dnn_backend_native_layer_mathbinary.h"
29
30 int dnn_load_layer_math_binary(Layer *layer, AVIOContext *model_file_context, int file_size)
31 {
32     DnnLayerMathBinaryParams *params;
33     int dnn_size = 0;
34     int input_index = 0;
35     params = av_malloc(sizeof(*params));
36     if (!params)
37         return 0;
38
39     params->bin_op = (int32_t)avio_rl32(model_file_context);
40     dnn_size += 4;
41
42     params->input0_broadcast = (int32_t)avio_rl32(model_file_context);
43     dnn_size += 4;
44     if (params->input0_broadcast) {
45         params->v = av_int2float(avio_rl32(model_file_context));
46     } else {
47         layer->input_operand_indexes[input_index] = (int32_t)avio_rl32(model_file_context);
48         input_index++;
49     }
50     dnn_size += 4;
51
52     params->input1_broadcast = (int32_t)avio_rl32(model_file_context);
53     dnn_size += 4;
54     if (params->input1_broadcast) {
55         params->v = av_int2float(avio_rl32(model_file_context));
56     } else {
57         layer->input_operand_indexes[input_index] = (int32_t)avio_rl32(model_file_context);
58         input_index++;
59     }
60     dnn_size += 4;
61
62     layer->output_operand_index = (int32_t)avio_rl32(model_file_context);
63     dnn_size += 4;
64     layer->params = params;
65
66     return dnn_size;
67 }
68
69 int dnn_execute_layer_math_binary(DnnOperand *operands, const int32_t *input_operand_indexes,
70                                  int32_t output_operand_index, const void *parameters)
71 {
72     const DnnOperand *input = &operands[input_operand_indexes[0]];
73     DnnOperand *output = &operands[output_operand_index];
74     const DnnLayerMathBinaryParams *params = (const DnnLayerMathBinaryParams *)parameters;
75     int dims_count;
76     const float *src;
77     float *dst;
78
79     for (int i = 0; i < 4; ++i)
80         output->dims[i] = input->dims[i];
81
82     output->data_type = input->data_type;
83     output->length = calculate_operand_data_length(output);
84     output->data = av_realloc(output->data, output->length);
85     if (!output->data)
86         return DNN_ERROR;
87
88     dims_count = calculate_operand_dims_count(output);
89     src = input->data;
90     dst = output->data;
91
92     switch (params->bin_op) {
93     case DMBO_SUB:
94         if (params->input0_broadcast) {
95             for (int i = 0; i < dims_count; ++i) {
96                 dst[i] = params->v - src[i];
97             }
98         } else if (params->input1_broadcast) {
99             for (int i = 0; i < dims_count; ++i) {
100                 dst[i] = src[i] - params->v;
101             }
102         } else {
103             const DnnOperand *input1 = &operands[input_operand_indexes[1]];
104             const float *src1 = input1->data;
105             for (int i = 0; i < dims_count; ++i) {
106                 dst[i] = src[i] - src1[i];
107             }
108         }
109         return 0;
110     case DMBO_ADD:
111         if (params->input0_broadcast || params->input1_broadcast) {
112             for (int i = 0; i < dims_count; ++i) {
113                 dst[i] = params->v + src[i];
114             }
115         } else {
116             const DnnOperand *input1 = &operands[input_operand_indexes[1]];
117             const float *src1 = input1->data;
118             for (int i = 0; i < dims_count; ++i) {
119                 dst[i] = src[i] + src1[i];
120             }
121         }
122         return 0;
123     case DMBO_MUL:
124         if (params->input0_broadcast || params->input1_broadcast) {
125             for (int i = 0; i < dims_count; ++i) {
126                 dst[i] = params->v * src[i];
127             }
128         } else {
129             const DnnOperand *input1 = &operands[input_operand_indexes[1]];
130             const float *src1 = input1->data;
131             for (int i = 0; i < dims_count; ++i) {
132                 dst[i] = src[i] * src1[i];
133             }
134         }
135         return 0;
136     case DMBO_REALDIV:
137         if (params->input0_broadcast) {
138             for (int i = 0; i < dims_count; ++i) {
139                 dst[i] = params->v / src[i];
140             }
141         } else if (params->input1_broadcast) {
142             for (int i = 0; i < dims_count; ++i) {
143                 dst[i] = src[i] / params->v;
144             }
145         } else {
146             const DnnOperand *input1 = &operands[input_operand_indexes[1]];
147             const float *src1 = input1->data;
148             for (int i = 0; i < dims_count; ++i) {
149                 dst[i] = src[i] / src1[i];
150             }
151         }
152         return 0;
153     case DMBO_MINIMUM:
154         if (params->input0_broadcast || params->input1_broadcast) {
155             for (int i = 0; i < dims_count; ++i) {
156                 dst[i] = FFMIN(params->v, src[i]);
157             }
158         } else {
159             const DnnOperand *input1 = &operands[input_operand_indexes[1]];
160             const float *src1 = input1->data;
161             for (int i = 0; i < dims_count; ++i) {
162                 dst[i] = FFMIN(src[i], src1[i]);
163             }
164         }
165         return 0;
166     default:
167         return -1;
168     }
169 }