]> git.sesse.net Git - bcachefs-tools-debian/blob - include/linux/mean_and_variance.h
9ed79f42a40439472a4367ef8e5c5d566890a9f3
[bcachefs-tools-debian] / include / linux / mean_and_variance.h
1 /* SPDX-License-Identifier: GPL-2.0 */
2 #ifndef MEAN_AND_VARIANCE_H_
3 #define MEAN_AND_VARIANCE_H_
4
5 #include <linux/types.h>
6 #include <linux/kernel.h>
7 #include <linux/limits.h>
8 #include <linux/math64.h>
9 #include <stdlib.h>
10
11 #define SQRT_U64_MAX 4294967295ULL
12
13 /*
14  * u128_u: u128 user mode, because not all architectures support a real int128
15  * type
16  */
17
18 #ifdef __SIZEOF_INT128__
19
20 typedef struct {
21         unsigned __int128 v;
22 } __aligned(16) u128_u;
23
24 static inline u128_u u64_to_u128(u64 a)
25 {
26         return (u128_u) { .v = a };
27 }
28
29 static inline u64 u128_lo(u128_u a)
30 {
31         return a.v;
32 }
33
34 static inline u64 u128_hi(u128_u a)
35 {
36         return a.v >> 64;
37 }
38
39 static inline u128_u u128_add(u128_u a, u128_u b)
40 {
41         a.v += b.v;
42         return a;
43 }
44
45 static inline u128_u u128_sub(u128_u a, u128_u b)
46 {
47         a.v -= b.v;
48         return a;
49 }
50
51 static inline u128_u u128_shl(u128_u a, s8 shift)
52 {
53         a.v <<= shift;
54         return a;
55 }
56
57 static inline u128_u u128_square(u64 a)
58 {
59         u128_u b = u64_to_u128(a);
60
61         b.v *= b.v;
62         return b;
63 }
64
65 #else
66
67 typedef struct {
68         u64 hi, lo;
69 } __aligned(16) u128_u;
70
71 /* conversions */
72
73 static inline u128_u u64_to_u128(u64 a)
74 {
75         return (u128_u) { .lo = a };
76 }
77
78 static inline u64 u128_lo(u128_u a)
79 {
80         return a.lo;
81 }
82
83 static inline u64 u128_hi(u128_u a)
84 {
85         return a.hi;
86 }
87
88 /* arithmetic */
89
90 static inline u128_u u128_add(u128_u a, u128_u b)
91 {
92         u128_u c;
93
94         c.lo = a.lo + b.lo;
95         c.hi = a.hi + b.hi + (c.lo < a.lo);
96         return c;
97 }
98
99 static inline u128_u u128_sub(u128_u a, u128_u b)
100 {
101         u128_u c;
102
103         c.lo = a.lo - b.lo;
104         c.hi = a.hi - b.hi - (c.lo > a.lo);
105         return c;
106 }
107
108 static inline u128_u u128_shl(u128_u i, s8 shift)
109 {
110         u128_u r;
111
112         r.lo = i.lo << shift;
113         if (shift < 64)
114                 r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
115         else {
116                 r.hi = i.lo << (shift - 64);
117                 r.lo = 0;
118         }
119         return r;
120 }
121
122 static inline u128_u u128_square(u64 i)
123 {
124         u128_u r;
125         u64  h = i >> 32, l = i & U32_MAX;
126
127         r =             u128_shl(u64_to_u128(h*h), 64);
128         r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
129         r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
130         r = u128_add(r,          u64_to_u128(l*l));
131         return r;
132 }
133
134 #endif
135
136 static inline u128_u u64s_to_u128(u64 hi, u64 lo)
137 {
138         u128_u c = u64_to_u128(hi);
139
140         c = u128_shl(c, 64);
141         c = u128_add(c, u64_to_u128(lo));
142         return c;
143 }
144
145 u128_u u128_div(u128_u n, u64 d);
146
147 struct mean_and_variance {
148         s64     n;
149         s64     sum;
150         u128_u  sum_squares;
151 };
152
153 /* expontentially weighted variant */
154 struct mean_and_variance_weighted {
155         bool    init;
156         u8      weight; /* base 2 logarithim */
157         s64     mean;
158         u64     variance;
159 };
160
161 /**
162  * fast_divpow2() - fast approximation for n / (1 << d)
163  * @n: numerator
164  * @d: the power of 2 denominator.
165  *
166  * note: this rounds towards 0.
167  */
168 static inline s64 fast_divpow2(s64 n, u8 d)
169 {
170         return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
171 }
172
173 /**
174  * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
175  * and return it.
176  * @s1: the mean_and_variance to update.
177  * @v1: the new sample.
178  *
179  * see linked pdf equation 12.
180  */
181 static inline struct mean_and_variance
182 mean_and_variance_update(struct mean_and_variance s, s64 v)
183 {
184         return (struct mean_and_variance) {
185                 .n           = s.n + 1,
186                 .sum         = s.sum + v,
187                 .sum_squares = u128_add(s.sum_squares, u128_square(abs(v))),
188         };
189 }
190
191 s64 mean_and_variance_get_mean(struct mean_and_variance s);
192 u64 mean_and_variance_get_variance(struct mean_and_variance s1);
193 u32 mean_and_variance_get_stddev(struct mean_and_variance s);
194
195 void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v);
196
197 s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
198 u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
199 u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
200
201 #endif // MEAN_AND_VAIRANCE_H_