]> git.sesse.net Git - bcachefs-tools-debian/blobdiff - include/linux/mean_and_variance.h
Update bcachefs sources to 070ec8d07b bcachefs: Snapshot depth, skiplist fields
[bcachefs-tools-debian] / include / linux / mean_and_variance.h
index 756eb3d1ca641a2acebf5d52b05ebb551eaad5f2..647505010b3974b713823f96ddeab6a6aa8fe5df 100644 (file)
 #ifndef MEAN_AND_VARIANCE_H_
 #define MEAN_AND_VARIANCE_H_
 
-#include <linux/kernel.h>
 #include <linux/types.h>
 #include <linux/limits.h>
+#include <linux/math.h>
 #include <linux/math64.h>
 
 #define SQRT_U64_MAX 4294967295ULL
 
-/**
- * abs - return absolute value of an argument
- * @x: the value.  If it is unsigned type, it is converted to signed type first.
- *     char is treated as if it was signed (regardless of whether it really is)
- *     but the macro's return type is preserved as char.
- *
- * Return: an absolute value of x.
+/*
+ * u128_u: u128 user mode, because not all architectures support a real int128
+ * type
  */
-#define abs(x) __abs_choose_expr(x, long long,                         \
-               __abs_choose_expr(x, long,                              \
-               __abs_choose_expr(x, int,                               \
-               __abs_choose_expr(x, short,                             \
-               __abs_choose_expr(x, char,                              \
-               __builtin_choose_expr(                                  \
-                       __builtin_types_compatible_p(typeof(x), char),  \
-                       (char)({ signed char __x = (x); __x<0?-__x:__x; }), \
-                       ((void)0)))))))
 
-#define __abs_choose_expr(x, type, other) __builtin_choose_expr(       \
-       __builtin_types_compatible_p(typeof(x),   signed type) ||       \
-       __builtin_types_compatible_p(typeof(x), unsigned type),         \
-       ({ signed type __x = (x); __x < 0 ? -__x : __x; }), other)
+#ifdef __SIZEOF_INT128__
 
-#if defined(CONFIG_ARCH_SUPPORTS_INT128) && defined(__SIZEOF_INT128__)
-
-typedef unsigned __int128 u128;
+typedef struct {
+       unsigned __int128 v;
+} __aligned(16) u128_u;
 
-static inline u128 u64_to_u128(u64 a)
+static inline u128_u u64_to_u128(u64 a)
 {
-       return (u128)a;
+       return (u128_u) { .v = a };
 }
 
-static inline u64 u128_to_u64(u128 a)
+static inline u64 u128_lo(u128_u a)
 {
-       return (u64)a;
+       return a.v;
 }
 
-static inline u64 u128_shr64_to_u64(u128 a)
+static inline u64 u128_hi(u128_u a)
 {
-       return (u64)(a >> 64);
+       return a.v >> 64;
 }
 
-static inline u128 u128_add(u128 a, u128 b)
+static inline u128_u u128_add(u128_u a, u128_u b)
 {
-       return a + b;
+       a.v += b.v;
+       return a;
 }
 
-static inline u128 u128_sub(u128 a, u128 b)
+static inline u128_u u128_sub(u128_u a, u128_u b)
 {
-       return a - b;
+       a.v -= b.v;
+       return a;
 }
 
-static inline u128 u128_shl(u128 i, s8 shift)
+static inline u128_u u128_shl(u128_u a, s8 shift)
 {
-       return i << shift;
+       a.v <<= shift;
+       return a;
 }
 
-static inline u128 u128_shl64_add(u64 a, u64 b)
+static inline u128_u u128_square(u64 a)
 {
-       return ((u128)a << 64) + b;
-}
+       u128_u b = u64_to_u128(a);
 
-static inline u128 u128_square(u64 i)
-{
-       return i*i;
+       b.v *= b.v;
+       return b;
 }
 
 #else
 
 typedef struct {
        u64 hi, lo;
-} u128;
+} __aligned(16) u128_u;
+
+/* conversions */
 
-static inline u128 u64_to_u128(u64 a)
+static inline u128_u u64_to_u128(u64 a)
 {
-       return (u128){ .lo = a };
+       return (u128_u) { .lo = a };
 }
 
-static inline u64 u128_to_u64(u128 a)
+static inline u64 u128_lo(u128_u a)
 {
        return a.lo;
 }
 
-static inline u64 u128_shr64_to_u64(u128 a)
+static inline u64 u128_hi(u128_u a)
 {
        return a.hi;
 }
 
-static inline u128 u128_add(u128 a, u128 b)
+/* arithmetic */
+
+static inline u128_u u128_add(u128_u a, u128_u b)
 {
-       u128 c;
+       u128_u c;
 
        c.lo = a.lo + b.lo;
        c.hi = a.hi + b.hi + (c.lo < a.lo);
        return c;
 }
 
-static inline u128 u128_sub(u128 a, u128 b)
+static inline u128_u u128_sub(u128_u a, u128_u b)
 {
-       u128 c;
+       u128_u c;
 
        c.lo = a.lo - b.lo;
        c.hi = a.hi - b.hi - (c.lo > a.lo);
        return c;
 }
 
-static inline u128 u128_shl(u128 i, s8 shift)
+static inline u128_u u128_shl(u128_u i, s8 shift)
 {
-       u128 r;
+       u128_u r;
 
        r.lo = i.lo << shift;
        if (shift < 64)
@@ -129,15 +118,10 @@ static inline u128 u128_shl(u128 i, s8 shift)
        return r;
 }
 
-static inline u128 u128_shl64_add(u64 a, u64 b)
-{
-       return u128_add(u128_shl(u64_to_u128(a), 64), u64_to_u128(b));
-}
-
-static inline u128 u128_square(u64 i)
+static inline u128_u u128_square(u64 i)
 {
-       u128 r;
-       u64  h = i >> 32, l = i & (u64)U32_MAX;
+       u128_u r;
+       u64  h = i >> 32, l = i & U32_MAX;
 
        r =             u128_shl(u64_to_u128(h*h), 64);
        r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
@@ -148,85 +132,67 @@ static inline u128 u128_square(u64 i)
 
 #endif
 
-static inline u128 u128_div(u128 n, u64 d)
+static inline u128_u u64s_to_u128(u64 hi, u64 lo)
 {
-       u128 r;
-       u64 rem;
-       u64 hi = u128_shr64_to_u64(n);
-       u64 lo = u128_to_u64(n);
-       u64  h =  hi & ((u64)U32_MAX  << 32);
-       u64  l = (hi &  (u64)U32_MAX) << 32;
+       u128_u c = u64_to_u128(hi);
 
-       r =             u128_shl(u64_to_u128(div64_u64_rem(h,                d, &rem)), 64);
-       r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l  + (rem << 32), d, &rem)), 32));
-       r = u128_add(r,          u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem)));
-       return r;
+       c = u128_shl(c, 64);
+       c = u128_add(c, u64_to_u128(lo));
+       return c;
 }
 
+u128_u u128_div(u128_u n, u64 d);
+
 struct mean_and_variance {
-       s64 n;
-       s64 sum;
-       u128 sum_squares;
+       s64     n;
+       s64     sum;
+       u128_u  sum_squares;
 };
 
 /* expontentially weighted variant */
 struct mean_and_variance_weighted {
-       bool init;
-       u8 w;
-       s64 mean;
-       u64 variance;
+       bool    init;
+       u8      weight; /* base 2 logarithim */
+       s64     mean;
+       u64     variance;
 };
 
-s64 fast_divpow2(s64 n, u8 d);
-
-static inline struct mean_and_variance
-mean_and_variance_update_inlined(struct mean_and_variance s1, s64 v1)
+/**
+ * fast_divpow2() - fast approximation for n / (1 << d)
+ * @n: numerator
+ * @d: the power of 2 denominator.
+ *
+ * note: this rounds towards 0.
+ */
+static inline s64 fast_divpow2(s64 n, u8 d)
 {
-       struct mean_and_variance s2;
-       u64 v2 = abs(v1);
-
-       s2.n           = s1.n + 1;
-       s2.sum         = s1.sum + v1;
-       s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2));
-       return s2;
+       return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
 }
 
-static inline struct mean_and_variance_weighted
-mean_and_variance_weighted_update_inlined(struct mean_and_variance_weighted s1, s64 x)
+/**
+ * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
+ * and return it.
+ * @s1: the mean_and_variance to update.
+ * @v1: the new sample.
+ *
+ * see linked pdf equation 12.
+ */
+static inline void
+mean_and_variance_update(struct mean_and_variance *s, s64 v)
 {
-       struct mean_and_variance_weighted s2;
-       // previous weighted variance.
-       u64 var_w0 = s1.variance;
-       u8 w = s2.w = s1.w;
-       // new value weighted.
-       s64 x_w = x << w;
-       s64 diff_w = x_w - s1.mean;
-       s64 diff = fast_divpow2(diff_w, w);
-       // new mean weighted.
-       s64 u_w1     = s1.mean + diff;
-
-       BUG_ON(w % 2 != 0);
-
-       if (!s1.init) {
-               s2.mean = x_w;
-               s2.variance = 0;
-       } else {
-               s2.mean = u_w1;
-               s2.variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
-       }
-       s2.init = true;
-
-       return s2;
+       s->n++;
+       s->sum += v;
+       s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v)));
 }
 
-struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1);
-       s64              mean_and_variance_get_mean(struct mean_and_variance s);
-       u64              mean_and_variance_get_variance(struct mean_and_variance s1);
-       u32              mean_and_variance_get_stddev(struct mean_and_variance s);
+s64 mean_and_variance_get_mean(struct mean_and_variance s);
+u64 mean_and_variance_get_variance(struct mean_and_variance s1);
+u32 mean_and_variance_get_stddev(struct mean_and_variance s);
+
+void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v);
 
-struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1, s64 v1);
-       s64                       mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
-       u64                       mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
-       u32                       mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
+s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
+u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
+u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
 
 #endif // MEAN_AND_VAIRANCE_H_