]> git.sesse.net Git - bcachefs-tools-debian/blob - include/linux/mean_and_variance.h
Update bcachefs sources to 5963d1b1a4 bcacehfs: Fix bch2_get_alloc_in_memory_pos()
[bcachefs-tools-debian] / include / linux / mean_and_variance.h
1 /* SPDX-License-Identifier: GPL-2.0 */
2 #ifndef MEAN_AND_VARIANCE_H_
3 #define MEAN_AND_VARIANCE_H_
4
5 #include <linux/types.h>
6 #include <linux/limits.h>
7 #include <linux/math64.h>
8 #include <linux/printbuf.h>
9
10 #define SQRT_U64_MAX 4294967295ULL
11
12
13 #if defined(CONFIG_ARCH_SUPPORTS_INT128) && defined(__SIZEOF_INT128__)
14
15 typedef unsigned __int128 u128;
16
17 static inline u128 u64_to_u128(u64 a)
18 {
19         return (u128)a;
20 }
21
22 static inline u64 u128_to_u64(u128 a)
23 {
24         return (u64)a;
25 }
26
27 static inline u64 u128_shr64_to_u64(u128 a)
28 {
29         return (u64)(a >> 64);
30 }
31
32 static inline u128 u128_add(u128 a, u128 b)
33 {
34         return a + b;
35 }
36
37 static inline u128 u128_sub(u128 a, u128 b)
38 {
39         return a - b;
40 }
41
42 static inline u128 u128_shl(u128 i, s8 shift)
43 {
44         return i << shift;
45 }
46
47 static inline u128 u128_shl64_add(u64 a, u64 b)
48 {
49         return ((u128)a << 64) + b;
50 }
51
52 static inline u128 u128_square(u64 i)
53 {
54         return i*i;
55 }
56
57 #else
58
59 typedef struct {
60         u64 hi, lo;
61 } u128;
62
63 static inline u128 u64_to_u128(u64 a)
64 {
65         return (u128){ .lo = a };
66 }
67
68 static inline u64 u128_to_u64(u128 a)
69 {
70         return a.lo;
71 }
72
73 static inline u64 u128_shr64_to_u64(u128 a)
74 {
75         return a.hi;
76 }
77
78 static inline u128 u128_add(u128 a, u128 b)
79 {
80         u128 c;
81
82         c.lo = a.lo + b.lo;
83         c.hi = a.hi + b.hi + (c.lo < a.lo);
84         return c;
85 }
86
87 static inline u128 u128_sub(u128 a, u128 b)
88 {
89         u128 c;
90
91         c.lo = a.lo - b.lo;
92         c.hi = a.hi - b.hi - (c.lo > a.lo);
93         return c;
94 }
95
96 static inline u128 u128_shl(u128 i, s8 shift)
97 {
98         u128 r;
99
100         r.lo = i.lo << shift;
101         if (shift < 64)
102                 r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
103         else {
104                 r.hi = i.lo << (shift - 64);
105                 r.lo = 0;
106         }
107         return r;
108 }
109
110 static inline u128 u128_shl64_add(u64 a, u64 b)
111 {
112         return u128_add(u128_shl(u64_to_u128(a), 64), u64_to_u128(b));
113 }
114
115 static inline u128 u128_square(u64 i)
116 {
117         u128 r;
118         u64  h = i >> 32, l = i & (u64)U32_MAX;
119
120         r =             u128_shl(u64_to_u128(h*h), 64);
121         r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
122         r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
123         r = u128_add(r,          u64_to_u128(l*l));
124         return r;
125 }
126
127 #endif
128
129 static inline u128 u128_div(u128 n, u64 d)
130 {
131         u128 r;
132         u64 rem;
133         u64 hi = u128_shr64_to_u64(n);
134         u64 lo = u128_to_u64(n);
135         u64  h =  hi & ((u64)U32_MAX  << 32);
136         u64  l = (hi &  (u64)U32_MAX) << 32;
137
138         r =             u128_shl(u64_to_u128(div64_u64_rem(h,                d, &rem)), 64);
139         r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l  + (rem << 32), d, &rem)), 32));
140         r = u128_add(r,          u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem)));
141         return r;
142 }
143
144 struct mean_and_variance {
145         s64 n;
146         s64 sum;
147         u128 sum_squares;
148 };
149
150 /* expontentially weighted variant */
151 struct mean_and_variance_weighted {
152         bool init;
153         u8 w;
154         s64 mean;
155         u64 variance;
156 };
157
158 s64 fast_divpow2(s64 n, u8 d);
159
160 static inline struct mean_and_variance
161 mean_and_variance_update_inlined(struct mean_and_variance s1, s64 v1)
162 {
163         struct mean_and_variance s2;
164         u64 v2 = abs(v1);
165
166         s2.n           = s1.n + 1;
167         s2.sum         = s1.sum + v1;
168         s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2));
169         return s2;
170 }
171
172 static inline struct mean_and_variance_weighted
173 mean_and_variance_weighted_update_inlined(struct mean_and_variance_weighted s1, s64 x)
174 {
175         struct mean_and_variance_weighted s2;
176         // previous weighted variance.
177         u64 var_w0 = s1.variance;
178         u8 w = s2.w = s1.w;
179         // new value weighted.
180         s64 x_w = x << w;
181         s64 diff_w = x_w - s1.mean;
182         s64 diff = fast_divpow2(diff_w, w);
183         // new mean weighted.
184         s64 u_w1     = s1.mean + diff;
185
186         BUG_ON(w % 2 != 0);
187
188         if (!s1.init) {
189                 s2.mean = x_w;
190                 s2.variance = 0;
191         } else {
192                 s2.mean = u_w1;
193                 s2.variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
194         }
195         s2.init = true;
196
197         return s2;
198 }
199
200 struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1);
201        s64               mean_and_variance_get_mean(struct mean_and_variance s);
202        u64               mean_and_variance_get_variance(struct mean_and_variance s1);
203        u32               mean_and_variance_get_stddev(struct mean_and_variance s);
204
205 struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1, s64 v1);
206        s64                        mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
207        u64                        mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
208        u32                        mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
209
210 #endif // MEAN_AND_VAIRANCE_H_