#include <linux/types.h>
#include <linux/limits.h>
+#include <linux/math.h>
#include <linux/math64.h>
-#include <linux/printbuf.h>
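+/* floor(sqrt(U64_MAX)) == U32_MAX: the largest u64 whose square still fits in a u64 */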
#define SQRT_U64_MAX 4294967295ULL
+/*
+ * u128_u: a portable u128 wrapper, because not all architectures support a
+ * real int128 type
+ */
-#if defined(CONFIG_ARCH_SUPPORTS_INT128) && defined(__SIZEOF_INT128__)
+#ifdef __SIZEOF_INT128__
-typedef unsigned __int128 u128;
+typedef struct {
+ unsigned __int128 v;
+} __aligned(16) u128_u;
-static inline u128 u64_to_u128(u64 a)
+static inline u128_u u64_to_u128(u64 a)
{
- return (u128)a;
+ return (u128_u) { .v = a };
}
-static inline u64 u128_to_u64(u128 a)
+static inline u64 u128_lo(u128_u a)
{
- return (u64)a;
+ return a.v;
}
-static inline u64 u128_shr64_to_u64(u128 a)
+static inline u64 u128_hi(u128_u a)
{
- return (u64)(a >> 64);
+ return a.v >> 64;
}
-static inline u128 u128_add(u128 a, u128 b)
+static inline u128_u u128_add(u128_u a, u128_u b)
{
- return a + b;
+ a.v += b.v;
+ return a;
}
-static inline u128 u128_sub(u128 a, u128 b)
+static inline u128_u u128_sub(u128_u a, u128_u b)
{
- return a - b;
+ a.v -= b.v;
+ return a;
}
-static inline u128 u128_shl(u128 i, s8 shift)
+static inline u128_u u128_shl(u128_u a, s8 shift)
{
- return i << shift;
+ a.v <<= shift;
+ return a;
}
-static inline u128 u128_shl64_add(u64 a, u64 b)
+static inline u128_u u128_square(u64 a)
{
- return ((u128)a << 64) + b;
-}
+ u128_u b = u64_to_u128(a);
-static inline u128 u128_square(u64 i)
-{
- return i*i;
+ b.v *= b.v;
+ return b;
}
#else
typedef struct {
u64 hi, lo;
-} u128;
+} __aligned(16) u128_u;
-static inline u128 u64_to_u128(u64 a)
+/* conversions */
+
+static inline u128_u u64_to_u128(u64 a)
{
- return (u128){ .lo = a };
+ return (u128_u) { .lo = a };
}
-static inline u64 u128_to_u64(u128 a)
+static inline u64 u128_lo(u128_u a)
{
return a.lo;
}
-static inline u64 u128_shr64_to_u64(u128 a)
+static inline u64 u128_hi(u128_u a)
{
return a.hi;
}
-static inline u128 u128_add(u128 a, u128 b)
+/* arithmetic */
+
+static inline u128_u u128_add(u128_u a, u128_u b)
{
- u128 c;
+ u128_u c;
c.lo = a.lo + b.lo;
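+ /* (c.lo < a.lo) is 1 iff the low-word add wrapped: the carry into the high word */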
c.hi = a.hi + b.hi + (c.lo < a.lo);
return c;
}
-static inline u128 u128_sub(u128 a, u128 b)
+static inline u128_u u128_sub(u128_u a, u128_u b)
{
- u128 c;
+ u128_u c;
c.lo = a.lo - b.lo;
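+ /* (c.lo > a.lo) is 1 iff the low-word subtract wrapped: the borrow from the high word */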
c.hi = a.hi - b.hi - (c.lo > a.lo);
return c;
}
-static inline u128 u128_shl(u128 i, s8 shift)
+static inline u128_u u128_shl(u128_u i, s8 shift)
{
- u128 r;
+ u128_u r;
r.lo = i.lo << shift;
if (shift < 64)
r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
else {
r.hi = i.lo << (shift - 64);
r.lo = 0;
}
return r;
}
-static inline u128 u128_shl64_add(u64 a, u64 b)
+static inline u128_u u128_square(u64 i)
{
- return u128_add(u128_shl(u64_to_u128(a), 64), u64_to_u128(b));
-}
-
-static inline u128 u128_square(u64 i)
-{
- u128 r;
- u64 h = i >> 32, l = i & (u64)U32_MAX;
+ u128_u r;
+ u64 h = i >> 32, l = i & U32_MAX;
r = u128_shl(u64_to_u128(h*h), 64);
r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
r = u128_add(r, u64_to_u128(l*l));
return r;
}
#endif
-static inline u128 u128_div(u128 n, u64 d)
+static inline u128_u u64s_to_u128(u64 hi, u64 lo)
{
- u128 r;
- u64 rem;
- u64 hi = u128_shr64_to_u64(n);
- u64 lo = u128_to_u64(n);
- u64 h = hi & ((u64)U32_MAX << 32);
- u64 l = (hi & (u64)U32_MAX) << 32;
+ u128_u c = u64_to_u128(hi);
- r = u128_shl(u64_to_u128(div64_u64_rem(h, d, &rem)), 64);
- r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l + (rem << 32), d, &rem)), 32));
- r = u128_add(r, u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem)));
- return r;
+ c = u128_shl(c, 64);
+ c = u128_add(c, u64_to_u128(lo));
+ return c;
}
+u128_u u128_div(u128_u n, u64 d);
+
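+/*
+ * A minimal usage sketch (n, d, hi, lo and q are illustrative locals):
+ *
+ * u128_u n = u64s_to_u128(hi, lo);
+ * u64 q = u128_lo(u128_div(n, d));
+ */
+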
struct mean_and_variance {
- s64 n;
- s64 sum;
- u128 sum_squares;
+ s64 n;
+ s64 sum;
+ u128_u sum_squares;
};
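+
+/*
+ * mean = sum / n; variance follows from the standard identity
+ * Var(X) = E[X^2] - E[X]^2 = sum_squares / n - (sum / n)^2.
+ */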
/* exponentially weighted variant */
struct mean_and_variance_weighted {
- bool init;
- u8 w;
- s64 mean;
- u64 variance;
+ bool init;
+ u8 weight; /* base 2 logarithm */
+ s64 mean;
+ u64 variance;
};
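+
+/*
+ * Illustration, assuming the conventional EWMA recurrence
+ * mean += (sample - mean) >> weight: each new sample carries roughly a
+ * 1 / (1 << weight) fraction of the mean, so weight == 3 gives each
+ * sample ~1/8 influence.
+ */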
-s64 fast_divpow2(s64 n, u8 d);
+/**
+ * fast_divpow2() - fast approximation for n / (1 << d)
+ * @n: numerator
+ * @d: log2 of the denominator (the denominator is 1 << d).
+ *
+ * note: this rounds towards 0.
+ */
+static inline s64 fast_divpow2(s64 n, u8 d)
+{
+ return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
+}
+
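+/*
+ * Worked example: fast_divpow2(-7, 2) evaluates (-7 + 3) >> 2 == -1,
+ * i.e. -7 / 4 rounded towards zero, whereas a plain arithmetic shift
+ * (-7 >> 2 == -2) rounds towards negative infinity.
+ */
+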
+/**
+ * mean_and_variance_update() - update a mean_and_variance struct @s with a new sample @v
+ * @s: the mean_and_variance to update.
+ * @v: the new sample.
+ *
+ * see linked pdf equation 12.
+ */
+static inline void
+mean_and_variance_update(struct mean_and_variance *s, s64 v)
+{
+ s->n++;
+ s->sum += v;
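+ /* v^2 == |v|^2, so pass the unsigned magnitude to u128_square() */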
+ s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v)));
+}
+
+s64 mean_and_variance_get_mean(struct mean_and_variance s);
+u64 mean_and_variance_get_variance(struct mean_and_variance s1);
+u32 mean_and_variance_get_stddev(struct mean_and_variance s);
-struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1);
- s64 mean_and_variance_get_mean(struct mean_and_variance s);
- u64 mean_and_variance_get_variance(struct mean_and_variance s1);
- u32 mean_and_variance_get_stddev(struct mean_and_variance s);
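+
+/*
+ * A minimal usage sketch (mv and the sample values are illustrative):
+ *
+ * struct mean_and_variance mv = {};
+ *
+ * mean_and_variance_update(&mv, 10);
+ * mean_and_variance_update(&mv, 12);
+ * s64 mean = mean_and_variance_get_mean(mv); // (10 + 12) / 2 == 11
+ */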
+void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v);
-struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1, s64 v1);
- s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
- u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
- u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
+s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
+u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
+u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
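+
+/*
+ * The weighted variant, under the same caveats (ewma and sample are
+ * illustrative; weight is set by the caller as a base 2 log):
+ *
+ * struct mean_and_variance_weighted ewma = { .weight = 3 };
+ *
+ * mean_and_variance_weighted_update(&ewma, sample);
+ * s64 m = mean_and_variance_weighted_get_mean(ewma);
+ */
+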
#endif // MEAN_AND_VAIRANCE_H_