]> git.sesse.net Git - bcachefs-tools-debian/blobdiff - linux/mean_and_variance.c
Disable pristine-tar option in gbp.conf, since there is no pristine-tar branch.
[bcachefs-tools-debian] / linux / mean_and_variance.c
index 643e3113500b20a5dece4bc6a64eb3a253271d44..21ec6afc6788413a35c815cef827ef8a25891329 100644 (file)
 #include <linux/math64.h>
 #include <linux/mean_and_variance.h>
 #include <linux/module.h>
-#include <linux/printbuf.h>
 
-
-/**
- * fast_divpow2() - fast approximation for n / (1 << d)
- * @n: numerator
- * @d: the power of 2 denominator.
- *
- * note: this rounds towards 0.
- */
-inline s64 fast_divpow2(s64 n, u8 d)
-{
-       return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
-}
-
-/**
- * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
- * and return it.
- * @s1: the mean_and_variance to update.
- * @v1: the new sample.
- *
- * see linked pdf equation 12.
- */
-struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1)
+u128_u u128_div(u128_u n, u64 d)
 {
-       struct mean_and_variance s2;
-       u64 v2 = abs(v1);
-
-       s2.n           = s1.n + 1;
-       s2.sum         = s1.sum + v1;
-       s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2));
-       return s2;
+       u128_u r;
+       u64 rem;
+       u64 hi = u128_hi(n);
+       u64 lo = u128_lo(n);
+       u64  h =  hi & ((u64) U32_MAX  << 32);
+       u64  l = (hi &  (u64) U32_MAX) << 32;
+
+       r =             u128_shl(u64_to_u128(div64_u64_rem(h,                d, &rem)), 64);
+       r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l  + (rem << 32), d, &rem)), 32));
+       r = u128_add(r,          u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem)));
+       return r;
 }
-EXPORT_SYMBOL_GPL(mean_and_variance_update);
+EXPORT_SYMBOL_GPL(u128_div);
 
 /**
  * mean_and_variance_get_mean() - get mean from @s
+ * @s: mean and variance number of samples and their sums
  */
 s64 mean_and_variance_get_mean(struct mean_and_variance s)
 {
-       return div64_u64(s.sum, s.n);
+       return s.n ? div64_u64(s.sum, s.n) : 0;
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_get_mean);
 
 /**
  * mean_and_variance_get_variance() -  get variance from @s1
+ * @s1: mean and variance number of samples and sums
  *
  * see linked pdf equation 12.
  */
 u64 mean_and_variance_get_variance(struct mean_and_variance s1)
 {
-       u128 s2 = u128_div(s1.sum_squares, s1.n);
-       u64  s3 = abs(mean_and_variance_get_mean(s1));
+       if (s1.n) {
+               u128_u s2 = u128_div(s1.sum_squares, s1.n);
+               u64  s3 = abs(mean_and_variance_get_mean(s1));
 
-       return u128_to_u64(u128_sub(s2, u128_square(s3)));
+               return u128_lo(u128_sub(s2, u128_square(s3)));
+       } else {
+               return 0;
+       }
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_get_variance);
 
 /**
  * mean_and_variance_get_stddev() - get standard deviation from @s
+ * @s: mean and variance number of samples and their sums
  */
 u32 mean_and_variance_get_stddev(struct mean_and_variance s)
 {
@@ -111,66 +100,71 @@ EXPORT_SYMBOL_GPL(mean_and_variance_get_stddev);
 
 /**
  * mean_and_variance_weighted_update() - exponentially weighted variant of mean_and_variance_update()
- * @s1: ..
- * @s2: ..
+ * @s: mean and variance number of samples and their sums
+ * @x: new value to include in the &mean_and_variance_weighted
+ * @initted: caller must track whether this is the first use or not
+ * @weight: ewma weight
  *
  * see linked pdf: function derived from equations 140-143 where alpha = 2^w.
  * values are stored bitshifted for performance and added precision.
  */
-struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1,
-                                                                   s64 x)
+void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s,
+               s64 x, bool initted, u8 weight)
 {
-       struct mean_and_variance_weighted s2;
        // previous weighted variance.
-       u64 var_w0 = s1.variance;
-       u8 w = s2.w = s1.w;
+       u8 w            = weight;
+       u64 var_w0      = s->variance;
        // new value weighted.
-       s64 x_w = x << w;
-       s64 diff_w = x_w - s1.mean;
-       s64 diff = fast_divpow2(diff_w, w);
+       s64 x_w         = x << w;
+       s64 diff_w      = x_w - s->mean;
+       s64 diff        = fast_divpow2(diff_w, w);
        // new mean weighted.
-       s64 u_w1     = s1.mean + diff;
+       s64 u_w1        = s->mean + diff;
 
-       BUG_ON(w % 2 != 0);
-
-       if (!s1.init) {
-               s2.mean = x_w;
-               s2.variance = 0;
+       if (!initted) {
+               s->mean = x_w;
+               s->variance = 0;
        } else {
-               s2.mean = u_w1;
-               s2.variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
+               s->mean = u_w1;
+               s->variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
        }
-       s2.init = true;
-
-       return s2;
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_weighted_update);
 
 /**
  * mean_and_variance_weighted_get_mean() - get mean from @s
+ * @s: mean and variance number of samples and their sums
+ * @weight: ewma weight
  */
-s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s)
+s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s,
+               u8 weight)
 {
-       return fast_divpow2(s.mean, s.w);
+       return fast_divpow2(s.mean, weight);
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_mean);
 
 /**
  * mean_and_variance_weighted_get_variance() -- get variance from @s
+ * @s: mean and variance number of samples and their sums
+ * @weight: ewma weight
  */
-u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s)
+u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s,
+               u8 weight)
 {
        // always positive don't need fast divpow2
-       return s.variance >> s.w;
+       return s.variance >> weight;
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_variance);
 
 /**
  * mean_and_variance_weighted_get_stddev() - get standard deviation from @s
+ * @s: mean and variance number of samples and their sums
+ * @weight: ewma weight
  */
-u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s)
+u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s,
+               u8 weight)
 {
-       return int_sqrt64(mean_and_variance_weighted_get_variance(s));
+       return int_sqrt64(mean_and_variance_weighted_get_variance(s, weight));
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_stddev);