]> git.sesse.net Git - ffmpeg/blob - tests/dnn/dnn-layer-mathbinary-test.c
avcodec: Remove deprecated AVCodecContext.coded_frame
[ffmpeg] / tests / dnn / dnn-layer-mathbinary-test.c
1 /*
2  * Copyright (c) 2020
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21 #include <stdio.h>
22 #include <string.h>
23 #include <math.h>
24 #include "libavfilter/dnn/dnn_backend_native_layer_mathbinary.h"
25 #include "libavutil/avassert.h"
26
27 #define EPSON 0.00005
28
29 static float get_expected(float f1, float f2, DNNMathBinaryOperation op)
30 {
31     switch (op)
32     {
33     case DMBO_SUB:
34         return f1 - f2;
35     case DMBO_ADD:
36         return f1 + f2;
37     case DMBO_MUL:
38         return f1 * f2;
39     case DMBO_REALDIV:
40         return f1 / f2;
41     case DMBO_MINIMUM:
42         return (f1 < f2) ? f1 : f2;
43     case DMBO_FLOORMOD:
44         return (float)((int)(f1) % (int)(f2));
45     default:
46         av_assert0(!"not supported yet");
47         return 0.f;
48     }
49 }
50
51 static int test_broadcast_input0(DNNMathBinaryOperation op)
52 {
53     DnnLayerMathBinaryParams params;
54     DnnOperand operands[2];
55     int32_t input_indexes[1];
56     float input[1*1*2*3] = {
57         -3, 2.5, 2, -2.1, 7.8, 100
58     };
59     float *output;
60
61     params.bin_op = op;
62     params.input0_broadcast = 1;
63     params.input1_broadcast = 0;
64     params.v = 7.28;
65
66     operands[0].data = input;
67     operands[0].dims[0] = 1;
68     operands[0].dims[1] = 1;
69     operands[0].dims[2] = 2;
70     operands[0].dims[3] = 3;
71     operands[1].data = NULL;
72
73     input_indexes[0] = 0;
74     ff_dnn_execute_layer_math_binary(operands, input_indexes, 1, &params, NULL);
75
76     output = operands[1].data;
77     for (int i = 0; i < sizeof(input) / sizeof(float); i++) {
78         float expected_output = get_expected(params.v, input[i], op);
79         if (fabs(output[i] - expected_output) > EPSON) {
80             printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n",
81                     op, i, output[i], expected_output, __FILE__, __LINE__);
82             av_freep(&output);
83             return 1;
84         }
85     }
86
87     av_freep(&output);
88     return 0;
89 }
90
91 static int test_broadcast_input1(DNNMathBinaryOperation op)
92 {
93     DnnLayerMathBinaryParams params;
94     DnnOperand operands[2];
95     int32_t input_indexes[1];
96     float input[1*1*2*3] = {
97         -3, 2.5, 2, -2.1, 7.8, 100
98     };
99     float *output;
100
101     params.bin_op = op;
102     params.input0_broadcast = 0;
103     params.input1_broadcast = 1;
104     params.v = 7.28;
105
106     operands[0].data = input;
107     operands[0].dims[0] = 1;
108     operands[0].dims[1] = 1;
109     operands[0].dims[2] = 2;
110     operands[0].dims[3] = 3;
111     operands[1].data = NULL;
112
113     input_indexes[0] = 0;
114     ff_dnn_execute_layer_math_binary(operands, input_indexes, 1, &params, NULL);
115
116     output = operands[1].data;
117     for (int i = 0; i < sizeof(input) / sizeof(float); i++) {
118         float expected_output = get_expected(input[i], params.v, op);
119         if (fabs(output[i] - expected_output) > EPSON) {
120             printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n",
121                     op, i, output[i], expected_output, __FILE__, __LINE__);
122             av_freep(&output);
123             return 1;
124         }
125     }
126
127     av_freep(&output);
128     return 0;
129 }
130
131 static int test_no_broadcast(DNNMathBinaryOperation op)
132 {
133     DnnLayerMathBinaryParams params;
134     DnnOperand operands[3];
135     int32_t input_indexes[2];
136     float input0[1*1*2*3] = {
137         -3, 2.5, 2, -2.1, 7.8, 100
138     };
139     float input1[1*1*2*3] = {
140         -1, 2, 3, -21, 8, 10.0
141     };
142     float *output;
143
144     params.bin_op = op;
145     params.input0_broadcast = 0;
146     params.input1_broadcast = 0;
147
148     operands[0].data = input0;
149     operands[0].dims[0] = 1;
150     operands[0].dims[1] = 1;
151     operands[0].dims[2] = 2;
152     operands[0].dims[3] = 3;
153     operands[1].data = input1;
154     operands[1].dims[0] = 1;
155     operands[1].dims[1] = 1;
156     operands[1].dims[2] = 2;
157     operands[1].dims[3] = 3;
158     operands[2].data = NULL;
159
160     input_indexes[0] = 0;
161     input_indexes[1] = 1;
162     ff_dnn_execute_layer_math_binary(operands, input_indexes, 2, &params, NULL);
163
164     output = operands[2].data;
165     for (int i = 0; i < sizeof(input0) / sizeof(float); i++) {
166         float expected_output = get_expected(input0[i], input1[i], op);
167         if (fabs(output[i] - expected_output) > EPSON) {
168             printf("op %d, at index %d, output: %f, expected_output: %f (%s:%d)\n",
169                     op, i, output[i], expected_output, __FILE__, __LINE__);
170             av_freep(&output);
171             return 1;
172         }
173     }
174
175     av_freep(&output);
176     return 0;
177 }
178
179 static int test(DNNMathBinaryOperation op)
180 {
181     if (test_broadcast_input0(op))
182         return 1;
183
184     if (test_broadcast_input1(op))
185         return 1;
186
187     if (test_no_broadcast(op))
188         return 1;
189
190     return 0;
191 }
192
193 int main(int argc, char **argv)
194 {
195     if (test(DMBO_SUB))
196         return 1;
197
198     if (test(DMBO_ADD))
199         return 1;
200
201     if (test(DMBO_MUL))
202         return 1;
203
204     if (test(DMBO_REALDIV))
205         return 1;
206
207     if (test(DMBO_MINIMUM))
208         return 1;
209
210     if (test(DMBO_FLOORMOD))
211         return 1;
212
213     return 0;
214 }