]> git.sesse.net Git - bcachefs-tools-debian/blob - include/linux/mean_and_variance.h
756eb3d1ca641a2acebf5d52b05ebb551eaad5f2
[bcachefs-tools-debian] / include / linux / mean_and_variance.h
1 /* SPDX-License-Identifier: GPL-2.0 */
2 #ifndef MEAN_AND_VARIANCE_H_
3 #define MEAN_AND_VARIANCE_H_
4
5 #include <linux/kernel.h>
6 #include <linux/types.h>
7 #include <linux/limits.h>
8 #include <linux/math64.h>
9
10 #define SQRT_U64_MAX 4294967295ULL
11
12 /**
13  * abs - return absolute value of an argument
14  * @x: the value.  If it is unsigned type, it is converted to signed type first.
15  *     char is treated as if it was signed (regardless of whether it really is)
16  *     but the macro's return type is preserved as char.
17  *
18  * Return: an absolute value of x.
19  */
20 #define abs(x)  __abs_choose_expr(x, long long,                         \
21                 __abs_choose_expr(x, long,                              \
22                 __abs_choose_expr(x, int,                               \
23                 __abs_choose_expr(x, short,                             \
24                 __abs_choose_expr(x, char,                              \
25                 __builtin_choose_expr(                                  \
26                         __builtin_types_compatible_p(typeof(x), char),  \
27                         (char)({ signed char __x = (x); __x<0?-__x:__x; }), \
28                         ((void)0)))))))
29
30 #define __abs_choose_expr(x, type, other) __builtin_choose_expr(        \
31         __builtin_types_compatible_p(typeof(x),   signed type) ||       \
32         __builtin_types_compatible_p(typeof(x), unsigned type),         \
33         ({ signed type __x = (x); __x < 0 ? -__x : __x; }), other)
34
35 #if defined(CONFIG_ARCH_SUPPORTS_INT128) && defined(__SIZEOF_INT128__)
36
37 typedef unsigned __int128 u128;
38
39 static inline u128 u64_to_u128(u64 a)
40 {
41         return (u128)a;
42 }
43
44 static inline u64 u128_to_u64(u128 a)
45 {
46         return (u64)a;
47 }
48
49 static inline u64 u128_shr64_to_u64(u128 a)
50 {
51         return (u64)(a >> 64);
52 }
53
54 static inline u128 u128_add(u128 a, u128 b)
55 {
56         return a + b;
57 }
58
59 static inline u128 u128_sub(u128 a, u128 b)
60 {
61         return a - b;
62 }
63
64 static inline u128 u128_shl(u128 i, s8 shift)
65 {
66         return i << shift;
67 }
68
69 static inline u128 u128_shl64_add(u64 a, u64 b)
70 {
71         return ((u128)a << 64) + b;
72 }
73
74 static inline u128 u128_square(u64 i)
75 {
76         return i*i;
77 }
78
79 #else
80
81 typedef struct {
82         u64 hi, lo;
83 } u128;
84
85 static inline u128 u64_to_u128(u64 a)
86 {
87         return (u128){ .lo = a };
88 }
89
90 static inline u64 u128_to_u64(u128 a)
91 {
92         return a.lo;
93 }
94
95 static inline u64 u128_shr64_to_u64(u128 a)
96 {
97         return a.hi;
98 }
99
100 static inline u128 u128_add(u128 a, u128 b)
101 {
102         u128 c;
103
104         c.lo = a.lo + b.lo;
105         c.hi = a.hi + b.hi + (c.lo < a.lo);
106         return c;
107 }
108
109 static inline u128 u128_sub(u128 a, u128 b)
110 {
111         u128 c;
112
113         c.lo = a.lo - b.lo;
114         c.hi = a.hi - b.hi - (c.lo > a.lo);
115         return c;
116 }
117
118 static inline u128 u128_shl(u128 i, s8 shift)
119 {
120         u128 r;
121
122         r.lo = i.lo << shift;
123         if (shift < 64)
124                 r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
125         else {
126                 r.hi = i.lo << (shift - 64);
127                 r.lo = 0;
128         }
129         return r;
130 }
131
132 static inline u128 u128_shl64_add(u64 a, u64 b)
133 {
134         return u128_add(u128_shl(u64_to_u128(a), 64), u64_to_u128(b));
135 }
136
137 static inline u128 u128_square(u64 i)
138 {
139         u128 r;
140         u64  h = i >> 32, l = i & (u64)U32_MAX;
141
142         r =             u128_shl(u64_to_u128(h*h), 64);
143         r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
144         r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
145         r = u128_add(r,          u64_to_u128(l*l));
146         return r;
147 }
148
149 #endif
150
151 static inline u128 u128_div(u128 n, u64 d)
152 {
153         u128 r;
154         u64 rem;
155         u64 hi = u128_shr64_to_u64(n);
156         u64 lo = u128_to_u64(n);
157         u64  h =  hi & ((u64)U32_MAX  << 32);
158         u64  l = (hi &  (u64)U32_MAX) << 32;
159
160         r =             u128_shl(u64_to_u128(div64_u64_rem(h,                d, &rem)), 64);
161         r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l  + (rem << 32), d, &rem)), 32));
162         r = u128_add(r,          u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem)));
163         return r;
164 }
165
166 struct mean_and_variance {
167         s64 n;
168         s64 sum;
169         u128 sum_squares;
170 };
171
172 /* expontentially weighted variant */
173 struct mean_and_variance_weighted {
174         bool init;
175         u8 w;
176         s64 mean;
177         u64 variance;
178 };
179
180 s64 fast_divpow2(s64 n, u8 d);
181
182 static inline struct mean_and_variance
183 mean_and_variance_update_inlined(struct mean_and_variance s1, s64 v1)
184 {
185         struct mean_and_variance s2;
186         u64 v2 = abs(v1);
187
188         s2.n           = s1.n + 1;
189         s2.sum         = s1.sum + v1;
190         s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2));
191         return s2;
192 }
193
194 static inline struct mean_and_variance_weighted
195 mean_and_variance_weighted_update_inlined(struct mean_and_variance_weighted s1, s64 x)
196 {
197         struct mean_and_variance_weighted s2;
198         // previous weighted variance.
199         u64 var_w0 = s1.variance;
200         u8 w = s2.w = s1.w;
201         // new value weighted.
202         s64 x_w = x << w;
203         s64 diff_w = x_w - s1.mean;
204         s64 diff = fast_divpow2(diff_w, w);
205         // new mean weighted.
206         s64 u_w1     = s1.mean + diff;
207
208         BUG_ON(w % 2 != 0);
209
210         if (!s1.init) {
211                 s2.mean = x_w;
212                 s2.variance = 0;
213         } else {
214                 s2.mean = u_w1;
215                 s2.variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
216         }
217         s2.init = true;
218
219         return s2;
220 }
221
222 struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1);
223        s64               mean_and_variance_get_mean(struct mean_and_variance s);
224        u64               mean_and_variance_get_variance(struct mean_and_variance s1);
225        u32               mean_and_variance_get_stddev(struct mean_and_variance s);
226
227 struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1, s64 v1);
228        s64                        mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
229        u64                        mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
230        u32                        mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
231
232 #endif // MEAN_AND_VAIRANCE_H_