1 /* SPDX-License-Identifier: GPL-2.0 */
2 #ifndef MEAN_AND_VARIANCE_H_
3 #define MEAN_AND_VARIANCE_H_
5 #include <linux/kernel.h>
6 #include <linux/types.h>
7 #include <linux/limits.h>
8 #include <linux/math64.h>
/*
 * Largest integer whose square still fits in 64 unsigned bits:
 * 2^32 - 1, because (2^32 - 1)^2 = 2^64 - 2^33 + 1 < 2^64.
 */
#define SQRT_U64_MAX 0xffffffffULL
13 * abs - return absolute value of an argument
 * @x: the value. If it is of an unsigned type, it is converted to the corresponding signed type first.
15 * char is treated as if it was signed (regardless of whether it really is)
16 * but the macro's return type is preserved as char.
18 * Return: an absolute value of x.
20 #define abs(x) __abs_choose_expr(x, long long, \
21 __abs_choose_expr(x, long, \
22 __abs_choose_expr(x, int, \
23 __abs_choose_expr(x, short, \
24 __abs_choose_expr(x, char, \
25 __builtin_choose_expr( \
26 __builtin_types_compatible_p(typeof(x), char), \
27 (char)({ signed char __x = (x); __x<0?-__x:__x; }), \
/*
 * __abs_choose_expr - one step of the type dispatch used by abs().
 * @x:     expression whose absolute value is wanted
 * @type:  base integer type tested at this step (e.g. long, int, char)
 * @other: expression used when typeof(@x) is neither signed nor unsigned @type
 *
 * __builtin_choose_expr() selects at compile time: if typeof(@x) is
 * compatible with "signed type" or "unsigned type", the result is |x|,
 * computed after converting @x to "signed type"; otherwise the result is
 * @other, which abs() chains to the next candidate type.
 *
 * NOTE(review): for int, long and long long, -__x overflows when __x is the
 * most negative value of the type (undefined in ISO C) — confirm this is
 * acceptable for callers.
 * NOTE(review): newer kernels define an identical __abs_choose_expr()/abs()
 * in <linux/math.h> (reachable via <linux/kernel.h>). A macro may only be
 * redefined with a token-identical replacement list, so keep these tokens
 * unchanged — TODO confirm the target kernel and consider dropping the copy.
 */
#define __abs_choose_expr(x, type, other) __builtin_choose_expr(	\
	__builtin_types_compatible_p(typeof(x),   signed type) ||	\
	__builtin_types_compatible_p(typeof(x), unsigned type),		\
	({ signed type __x = (x); __x < 0 ? -__x : __x; }), other)
35 #if defined(CONFIG_ARCH_SUPPORTS_INT128) && defined(__SIZEOF_INT128__)
37 typedef unsigned __int128 u128;
39 static inline u128 u64_to_u128(u64 a)
44 static inline u64 u128_to_u64(u128 a)
49 static inline u64 u128_shr64_to_u64(u128 a)
51 return (u64)(a >> 64);
54 static inline u128 u128_add(u128 a, u128 b)
59 static inline u128 u128_sub(u128 a, u128 b)
64 static inline u128 u128_shl(u128 i, s8 shift)
69 static inline u128 u128_shl64_add(u64 a, u64 b)
71 return ((u128)a << 64) + b;
74 static inline u128 u128_square(u64 i)
85 static inline u128 u64_to_u128(u64 a)
87 return (u128){ .lo = a };
90 static inline u64 u128_to_u64(u128 a)
95 static inline u64 u128_shr64_to_u64(u128 a)
100 static inline u128 u128_add(u128 a, u128 b)
105 c.hi = a.hi + b.hi + (c.lo < a.lo);
109 static inline u128 u128_sub(u128 a, u128 b)
114 c.hi = a.hi - b.hi - (c.lo > a.lo);
118 static inline u128 u128_shl(u128 i, s8 shift)
122 r.lo = i.lo << shift;
124 r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
126 r.hi = i.lo << (shift - 64);
132 static inline u128 u128_shl64_add(u64 a, u64 b)
134 return u128_add(u128_shl(u64_to_u128(a), 64), u64_to_u128(b));
137 static inline u128 u128_square(u64 i)
140 u64 h = i >> 32, l = i & (u64)U32_MAX;
142 r = u128_shl(u64_to_u128(h*h), 64);
143 r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
144 r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
145 r = u128_add(r, u64_to_u128(l*l));
151 static inline u128 u128_div(u128 n, u64 d)
155 u64 hi = u128_shr64_to_u64(n);
156 u64 lo = u128_to_u64(n);
157 u64 h = hi & ((u64)U32_MAX << 32);
158 u64 l = (hi & (u64)U32_MAX) << 32;
160 r = u128_shl(u64_to_u128(div64_u64_rem(h, d, &rem)), 64);
161 r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l + (rem << 32), d, &rem)), 32));
162 r = u128_add(r, u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem)));
166 struct mean_and_variance {
/* exponentially weighted variant */
173 struct mean_and_variance_weighted {
180 s64 fast_divpow2(s64 n, u8 d);
182 static inline struct mean_and_variance
183 mean_and_variance_update_inlined(struct mean_and_variance s1, s64 v1)
185 struct mean_and_variance s2;
189 s2.sum = s1.sum + v1;
190 s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2));
194 static inline struct mean_and_variance_weighted
195 mean_and_variance_weighted_update_inlined(struct mean_and_variance_weighted s1, s64 x)
197 struct mean_and_variance_weighted s2;
198 // previous weighted variance.
199 u64 var_w0 = s1.variance;
201 // new value weighted.
203 s64 diff_w = x_w - s1.mean;
204 s64 diff = fast_divpow2(diff_w, w);
205 // new mean weighted.
206 s64 u_w1 = s1.mean + diff;
215 s2.variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
222 struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1);
223 s64 mean_and_variance_get_mean(struct mean_and_variance s);
224 u64 mean_and_variance_get_variance(struct mean_and_variance s1);
225 u32 mean_and_variance_get_stddev(struct mean_and_variance s);
227 struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1, s64 v1);
228 s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
229 u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
230 u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
#endif // MEAN_AND_VARIANCE_H_