1 /* SPDX-License-Identifier: GPL-2.0 */
2 #ifndef MEAN_AND_VARIANCE_H_
3 #define MEAN_AND_VARIANCE_H_
5 #include <linux/types.h>
6 #include <linux/kernel.h>
7 #include <linux/limits.h>
8 #include <linux/math64.h>
11 #define SQRT_U64_MAX 4294967295ULL /* floor(sqrt(U64_MAX)) == 2^32 - 1: largest x with x * x <= U64_MAX */
14 * u128_u: u128 user mode, because not all architectures support a real int128
18 #ifdef __SIZEOF_INT128__
22 } __aligned(16) u128_u;
24 static inline u128_u u64_to_u128(u64 a)
26 return (u128_u) { .v = a };
29 static inline u64 u128_lo(u128_u a)
34 static inline u64 u128_hi(u128_u a)
39 static inline u128_u u128_add(u128_u a, u128_u b)
45 static inline u128_u u128_sub(u128_u a, u128_u b)
51 static inline u128_u u128_shl(u128_u a, s8 shift)
57 static inline u128_u u128_square(u64 a)
59 u128_u b = u64_to_u128(a);
69 } __aligned(16) u128_u;
73 static inline u128_u u64_to_u128(u64 a)
75 return (u128_u) { .lo = a };
78 static inline u64 u128_lo(u128_u a)
83 static inline u64 u128_hi(u128_u a)
90 static inline u128_u u128_add(u128_u a, u128_u b)
95 c.hi = a.hi + b.hi + (c.lo < a.lo);
99 static inline u128_u u128_sub(u128_u a, u128_u b)
104 c.hi = a.hi - b.hi - (c.lo > a.lo);
108 static inline u128_u u128_shl(u128_u i, s8 shift)
112 r.lo = i.lo << shift;
114 r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
116 r.hi = i.lo << (shift - 64);
122 static inline u128_u u128_square(u64 i)
125 u64 h = i >> 32, l = i & U32_MAX;
127 r = u128_shl(u64_to_u128(h*h), 64);
128 r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
129 r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
130 r = u128_add(r, u64_to_u128(l*l));
136 static inline u128_u u64s_to_u128(u64 hi, u64 lo)
138 u128_u c = u64_to_u128(hi);
141 c = u128_add(c, u64_to_u128(lo));
145 u128_u u128_div(u128_u n, u64 d); /* no inline body in this header; presumably the 128-by-64-bit quotient n / d (d != 0 assumed) - confirm against the .c definition */
147 struct mean_and_variance {
153 /* exponentially weighted variant */
154 struct mean_and_variance_weighted {
156 u8 weight; /* base 2 logarithm */
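162 * fast_divpow2() - shift-based computation of n / (1 << d) for signed n
164 * @d: the power of 2 denominator. For negative @n, @d must be < 31, because the rounding bias is computed as the int expression (1 << d).
166 * note: like C integer division, this rounds towards 0 (negative @n is biased by (1 << d) - 1 before the arithmetic shift).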
168 static inline s64 fast_divpow2(s64 n, u8 d)
170 return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
174 * mean_and_variance_update() - return a copy of mean_and_variance @s updated with the new sample @v
176 * @s: the mean_and_variance to update (passed by value; the caller's copy is not modified).
177 * @v: the new sample.
179 * see linked pdf equation 12.
181 static inline struct mean_and_variance
182 mean_and_variance_update(struct mean_and_variance s, s64 v)
184 return (struct mean_and_variance) {
187 .sum_squares = u128_add(s.sum_squares, u128_square(abs(v))),
191 s64 mean_and_variance_get_mean(struct mean_and_variance s); /* accessors in this group are declared here and defined out of line */
192 u64 mean_and_variance_get_variance(struct mean_and_variance s);
193 u32 mean_and_variance_get_stddev(struct mean_and_variance s); /* presumably sqrt of the u64 variance, which always fits in u32 (see SQRT_U64_MAX) - confirm */
195 void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v); /* returns void and takes a pointer, so the result lands in *s, unlike the by-value mean_and_variance_update() */
197 s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
198 u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
199 u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
201 #endif // MEAN_AND_VARIANCE_H_