#include <linux/math64.h>
#include <linux/mean_and_variance.h>
#include <linux/module.h>
-#include <linux/printbuf.h>
-
-/**
- * fast_divpow2() - fast approximation for n / (1 << d)
- * @n: numerator
- * @d: the power of 2 denominator.
- *
- * note: this rounds towards 0.
- */
-s64 fast_divpow2(s64 n, u8 d)
+u128_u u128_div(u128_u n, u64 d)
{
- return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
-}
+ u128_u r;
+ u64 rem;
+ u64 hi = u128_hi(n);
+ u64 lo = u128_lo(n);
+ u64 h = hi & ((u64) U32_MAX << 32);
+ u64 l = (hi & (u64) U32_MAX) << 32;
-/**
- * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
- * and return it.
- * @s1: the mean_and_variance to update.
- * @v1: the new sample.
- *
- * see linked pdf equation 12.
- */
-struct mean_and_variance mean_and_variance_update(struct mean_and_variance s1, s64 v1)
-{
- return mean_and_variance_update_inlined(s1, v1);
+ r = u128_shl(u64_to_u128(div64_u64_rem(h, d, &rem)), 64);
+ r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l + (rem << 32), d, &rem)), 32));
+ r = u128_add(r, u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem)));
+ return r;
}
-EXPORT_SYMBOL_GPL(mean_and_variance_update);
+EXPORT_SYMBOL_GPL(u128_div);
/**
* mean_and_variance_get_mean() - get mean from @s
+ * @s: mean and variance number of samples and their sums
*/
s64 mean_and_variance_get_mean(struct mean_and_variance s)
{
- return div64_u64(s.sum, s.n);
+ return s.n ? div64_u64(s.sum, s.n) : 0;
}
EXPORT_SYMBOL_GPL(mean_and_variance_get_mean);
/**
* mean_and_variance_get_variance() - get variance from @s1
+ * @s1: mean and variance number of samples and sums
*
* see linked pdf equation 12.
*/
u64 mean_and_variance_get_variance(struct mean_and_variance s1)
{
- u128 s2 = u128_div(s1.sum_squares, s1.n);
- u64 s3 = abs(mean_and_variance_get_mean(s1));
+ if (s1.n) {
+ u128_u s2 = u128_div(s1.sum_squares, s1.n);
+ u64 s3 = abs(mean_and_variance_get_mean(s1));
- return u128_to_u64(u128_sub(s2, u128_square(s3)));
+ return u128_lo(u128_sub(s2, u128_square(s3)));
+ } else {
+ return 0;
+ }
}
EXPORT_SYMBOL_GPL(mean_and_variance_get_variance);
/**
* mean_and_variance_get_stddev() - get standard deviation from @s
+ * @s: mean and variance number of samples and their sums
*/
u32 mean_and_variance_get_stddev(struct mean_and_variance s)
{
/**
* mean_and_variance_weighted_update() - exponentially weighted variant of mean_and_variance_update()
- * @s1: ..
- * @s2: ..
+ * @s: mean and variance number of samples and their sums
+ * @x: new value to include in the &mean_and_variance_weighted
+ * @initted: caller must track whether this is the first use or not
+ * @weight: ewma weight
*
* see linked pdf: function derived from equations 140-143 where alpha = 2^w.
* values are stored bitshifted for performance and added precision.
*/
-struct mean_and_variance_weighted mean_and_variance_weighted_update(struct mean_and_variance_weighted s1,
- s64 x)
+void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s,
+ s64 x, bool initted, u8 weight)
{
- return mean_and_variance_weighted_update_inlined(s1, x);
+ // previous weighted variance.
+ u8 w = weight;
+ u64 var_w0 = s->variance;
+ // new value weighted.
+ s64 x_w = x << w;
+ s64 diff_w = x_w - s->mean;
+ s64 diff = fast_divpow2(diff_w, w);
+ // new mean weighted.
+ s64 u_w1 = s->mean + diff;
+
+ if (!initted) {
+ s->mean = x_w;
+ s->variance = 0;
+ } else {
+ s->mean = u_w1;
+ s->variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
+ }
}
EXPORT_SYMBOL_GPL(mean_and_variance_weighted_update);
/**
* mean_and_variance_weighted_get_mean() - get mean from @s
+ * @s: mean and variance number of samples and their sums
+ * @weight: ewma weight
*/
-s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s)
+s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s,
+ u8 weight)
{
- return fast_divpow2(s.mean, s.w);
+ return fast_divpow2(s.mean, weight);
}
EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_mean);
/**
* mean_and_variance_weighted_get_variance() -- get variance from @s
+ * @s: mean and variance number of samples and their sums
+ * @weight: ewma weight
*/
-u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s)
+u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s,
+ u8 weight)
{
// always positive don't need fast divpow2
- return s.variance >> s.w;
+ return s.variance >> weight;
}
EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_variance);
/**
* mean_and_variance_weighted_get_stddev() - get standard deviation from @s
+ * @s: mean and variance number of samples and their sums
+ * @weight: ewma weight
*/
-u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s)
+u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s,
+ u8 weight)
{
- return int_sqrt64(mean_and_variance_weighted_get_variance(s));
+ return int_sqrt64(mean_and_variance_weighted_get_variance(s, weight));
}
EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_stddev);