]> git.sesse.net Git - bcachefs-tools-debian/blobdiff - linux/mean_and_variance.c
New upstream release
[bcachefs-tools-debian] / linux / mean_and_variance.c
index aa95db1277716d563961be30a528cde29fa70578..eb5f2ba03b7fbdd920b0080600989a0dc76f8e61 100644 (file)
 #include <linux/math64.h>
 #include <linux/mean_and_variance.h>
 #include <linux/module.h>
-#include <linux/printbuf.h>
 
-
-/**
- * fast_divpow2() - fast approximation for n / (1 << d)
- * @n: numerator
- * @d: the power of 2 denominator.
- *
- * note: this rounds towards 0.
- */
-s64 fast_divpow2(s64 n, u8 d)
-{
-       return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
-}
-
-/**
- * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
- * and return it.
- * @s1: the mean_and_variance to update.
- * @v1: the new sample.
- *
- * see linked pdf equation 12.
- */
-struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1)
+u128_u u128_div(u128_u n, u64 d)
 {
-       struct mean_and_variance s2;
-       u64 v2 = abs(v1);
-
-       s2.n           = s1.n + 1;
-       s2.sum         = s1.sum + v1;
-       s2.sum_squares = u128_add(s1.sum_squares, u128_square(v2));
-       return s2;
+       u128_u r;
+       u64 rem;
+       u64 hi = u128_hi(n);
+       u64 lo = u128_lo(n);
+       u64  h =  hi & ((u64) U32_MAX  << 32);
+       u64  l = (hi &  (u64) U32_MAX) << 32;
+
+       r =             u128_shl(u64_to_u128(div64_u64_rem(h,                d, &rem)), 64);
+       r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l  + (rem << 32), d, &rem)), 32));
+       r = u128_add(r,          u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem)));
+       return r;
 }
-EXPORT_SYMBOL_GPL(mean_and_variance_update);
+EXPORT_SYMBOL_GPL(u128_div);
 
 /**
  * mean_and_variance_get_mean() - get mean from @s
  */
 s64 mean_and_variance_get_mean(struct mean_and_variance s)
 {
-       return div64_u64(s.sum, s.n);
+       return s.n ? div64_u64(s.sum, s.n) : 0;
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_get_mean);
 
@@ -93,10 +75,14 @@ EXPORT_SYMBOL_GPL(mean_and_variance_get_mean);
  */
 u64 mean_and_variance_get_variance(struct mean_and_variance s1)
 {
-       u128 s2 = u128_div(s1.sum_squares, s1.n);
-       u64  s3 = abs(mean_and_variance_get_mean(s1));
+       if (s1.n) {
+               u128_u s2 = u128_div(s1.sum_squares, s1.n);
+               u64  s3 = abs(mean_and_variance_get_mean(s1));
 
-       return u128_to_u64(u128_sub(s2, u128_square(s3)));
+               return u128_lo(u128_sub(s2, u128_square(s3)));
+       } else {
+               return 0;
+       }
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_get_variance);
 
@@ -117,32 +103,26 @@ EXPORT_SYMBOL_GPL(mean_and_variance_get_stddev);
  * see linked pdf: function derived from equations 140-143 where alpha = 2^w.
  * values are stored bitshifted for performance and added precision.
  */
-struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1,
-                                                                   s64 x)
+void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 x)
 {
-       struct mean_and_variance_weighted s2;
        // previous weighted variance.
-       u64 var_w0 = s1.variance;
-       u8 w = s2.w = s1.w;
+       u8 w            = s->weight;
+       u64 var_w0      = s->variance;
        // new value weighted.
-       s64 x_w = x << w;
-       s64 diff_w = x_w - s1.mean;
-       s64 diff = fast_divpow2(diff_w, w);
+       s64 x_w         = x << w;
+       s64 diff_w      = x_w - s->mean;
+       s64 diff        = fast_divpow2(diff_w, w);
        // new mean weighted.
-       s64 u_w1     = s1.mean + diff;
+       s64 u_w1        = s->mean + diff;
 
-       BUG_ON(w % 2 != 0);
-
-       if (!s1.init) {
-               s2.mean = x_w;
-               s2.variance = 0;
+       if (!s->init) {
+               s->mean = x_w;
+               s->variance = 0;
        } else {
-               s2.mean = u_w1;
-               s2.variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
+               s->mean = u_w1;
+               s->variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
        }
-       s2.init = true;
-
-       return s2;
+       s->init = true;
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_weighted_update);
 
@@ -151,7 +131,7 @@ EXPORT_SYMBOL_GPL(mean_and_variance_weighted_update);
  */
 s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s)
 {
-       return fast_divpow2(s.mean, s.w);
+       return fast_divpow2(s.mean, s.weight);
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_mean);
 
@@ -161,7 +141,7 @@ EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_mean);
 u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s)
 {
        // always positive don't need fast divpow2
-       return s.variance >> s.w;
+       return s.variance >> s.weight;
 }
 EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_variance);