]> git.sesse.net Git - wloh/blobdiff - bayeswf.cpp
Back up all tables, not just the ones we need themselves.
[wloh] / bayeswf.cpp
index 6441801e23c32c17bab637f4c8ad09b7b609b23c..71686c999cb944f2af6226d23b9fcf7e07d9115e 100644 (file)
@@ -201,6 +201,9 @@ void construct_hessian(const float *mu, int num_players)
 {
        memset(hessian, 0, sizeof(hessian));
 
+       for (int i = 0; i < num_players; ++i) {
+               hessian[i][i] += 1.0f / (prior_sigma * prior_sigma);
+       }
        for (unsigned i = 0; i < all_matches.size(); ++i) {
                const match &m = all_matches[i];
 
@@ -224,6 +227,14 @@ void compute_mu_uncertainty(const float *mu, int num_players)
 {
        memset(mu_stddev, 0, sizeof(mu_stddev));
 
+       // Temporarily use mu_stddev to store the diagonal of the Hessian.
+
+       // Prior.
+       for (int i = 0; i < num_players; ++i) {
+               mu_stddev[i] += 1.0f / (prior_sigma * prior_sigma);
+       }
+
+       // Matches.
        for (unsigned i = 0; i < all_matches.size(); ++i) {
                const match &m = all_matches[i];
 
@@ -233,10 +244,11 @@ void compute_mu_uncertainty(const float *mu, int num_players)
                double sigma_sq = global_sigma * global_sigma;
                float w = m.weight;
 
-               // Temporarily use mu_stddev to store the diagonal of the Hessian.
                mu_stddev[p1] += w / sigma_sq;
                mu_stddev[p2] += w / sigma_sq;
        }
+
+       // Now convert to standard deviation.
        for (int i = 0; i < num_players; ++i) {
                mu_stddev[i] = 1.0f / sqrt(mu_stddev[i]);
        }
@@ -337,7 +349,7 @@ int main(int argc, char **argv)
                sumdiff += (global_sigma - old_global_sigma) * (global_sigma - old_global_sigma);
                if (sumdiff < EPSILON) {
                        //fprintf(stderr, "Converged after %d iterations. Stopping.\n", j);
-                       printf("%d -1\n", j + 1);
+                       printf("%d -1\n", j + 1);
                        break;
                }
        }
@@ -348,11 +360,11 @@ int main(int argc, char **argv)
        compute_mu_uncertainty(mu, num_players);
        dump_scores(players, mu, mu_stddev, num_players);
        //fprintf(stderr, "Optimal sigma: %f (two-player: %f)\n", sigma[0], sigma[0] * sqrt(2.0f));
-       printf("%f -2\n", global_sigma / sqrt(2.0f));
-       printf("%f -3\n", prior_sigma);
+       printf("%f -2\n", global_sigma / sqrt(2.0f));
+       printf("%f -3\n", prior_sigma);
 
        float total_logl = compute_total_logl(mu, num_players);
-       printf("%f -4\n", total_logl);
+       printf("%f -4\n", total_logl);
 
 //     construct_hessian(mu, sigma, num_players);
 #endif