]> git.sesse.net Git - foosball/commitdiff
Generalize the Gauss-Jordan solver to arbitrary sizes. In the process it's
authorSteinar H. Gunderson <sesse@debian.org>
Sun, 21 Oct 2007 12:33:34 +0000 (14:33 +0200)
committerSteinar H. Gunderson <sesse@debian.org>
Sun, 21 Oct 2007 12:33:34 +0000 (14:33 +0200)
becoming really crufty old code =)

foosrank.cpp

index bdb54f2324b3e6a95e626cf035474fb9c1f97b1f..25298e7ddfe0bdc16590ab98e806dd8150401a1e 100644 (file)
@@ -177,70 +177,41 @@ static void mat_mul_trans(double *matA, unsigned ah, unsigned aw,
        }
 }
 
-// solves Ax = B by Gauss-Jordan elimination, where A is a 3x3 matrix,
-// x is a column vector of length 3 and B is a row vector of length 3.
+// solves Ax = B by Gauss-Jordan elimination, where A is an NxN matrix,
+// x is a column vector of length N and B is a row vector of length N.
 // Destroys its input in the process.
-static void solve3x3(double *A, double *x, double *B)
+template<int N>
+static void solve_matrix(double *A, double *x, double *B)
 {
-       // row 1 -= row 0 * (a1/a0)
-       {
-               double f = A[1] / A[0];
-               A[1] = 0.0;
-               A[4] -= A[3] * f;
-               A[7] -= A[6] * f;
-
-               B[1] -= B[0] * f;
-       }
-
-       // row 2 -= row 0 * (a2/a0)
-       {
-               double f = A[2] / A[0];
-               A[2] = 0.0;
-               A[5] -= A[3] * f;
-               A[8] -= A[6] * f;
-
-               B[2] -= B[0] * f;
-       }
-
-       // row 2 -= row 1 * (a5/a4)
-       {
-               double f = A[5] / A[4];
-               A[5] = 0.0;
-               A[8] -= A[7] * f;
-               
-               B[2] -= B[1] * f;
-       }
-
-       // back substitute:
-
-       // row 1 -= row 2 * (a7/a8)
-       {
-               double f = A[7] / A[8];
-               A[7] = 0.0;
-
-               B[1] -= B[2] * f;
-       }
-
-       // row 0 -= row 2 * (a6/a8)
-       {
-               double f = A[6] / A[8];
-               A[6] = 0.0;
+       for (int i = 0; i < N; ++i) {
+               for (int j = i+1; j < N; ++j) {
+                       // row j -= row i * (a[i,j] / a[i,i])
+                       double f = A[j+i*N] / A[i+i*N];
+
+                       A[j+i*N] = 0.0;
+                       for (int k = i+1; k < N; ++k) {
+                               A[j+k*N] -= A[i+k*N] * f;
+                       }
 
-               B[0] -= B[2] * f;
+                       B[j] -= B[i] * f;
+               }
        }
 
-       // row 0 -= row 1 * (a3/a4)
-       {
-               double f = A[3] / A[4];
-               A[3] = 0.0;
-
-               B[0] -= B[1] * f;
+       // back-substitute
+       for (int i = N; i --> 0; ) {
+               for (int j = i; j --> 0; ) {
+                       // row j -= row i * (a[j,j] / a[j,i])
+                       double f = A[i+j*N] / A[j+j*N];
+                       
+                       // A[j+i*N] = 0.0;
+                       B[j] -= B[i] * f;
+               }
        }
 
        // normalize
-       x[0] = B[0] / A[0];
-       x[1] = B[1] / A[4];
-       x[2] = B[2] / A[8];
+       for (int i = 0; i < N; ++i) {
+               x[i] = B[i] / A[i+i*N];
+       }
 }
 
 // Give an OK starting estimate for the least squares, by numerical integration
@@ -338,7 +309,7 @@ static void least_squares(vector<pair<double, double> > &curve, double mu1, doub
                mat_mul_trans(matA, curve.size(), 3, dbeta, curve.size(), 1, matATdb);
 
                // solve
-               solve3x3(matATA, dlambda, matATdb);
+               solve_matrix<3>(matATA, dlambda, matATdb);
 
                A += dlambda[0];
                mu += dlambda[1];