git.sesse.net Git - sigmoidsmooth/blob - opt.cc
Add a descriptive comment to fit.sh.
[sigmoidsmooth] / opt.cc
// Fits the general sigmoid 1/(1 + exp(-a(t-m))) to a set of points by nonlinear
// least squares, optimizing the steepness a and the midpoint m.
2
#include <stdio.h>
#include <stdlib.h>
#include <iostream>
#include <vector>
#include "ceres/ceres.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
8
9 using namespace ceres;
10 using std::vector;
11
// A single observation to fit the sigmoid against.
struct Point {
        float t;              // Sample position (main() uses the 0-based input line index).
        float adoption_rate;  // Observed value at t (main() stores the input percentage * 0.01).
};
16
17 class DistanceFromSigmoidCostFunction {
18  public:
19         DistanceFromSigmoidCostFunction(Point p) : p_(p) {}
20         
21         template<class T>
22         bool operator() (const T * const a, const T * const m, T* e) const {
23                 *e = T(p_.adoption_rate) - T(1.0) / (T(1.0) + ceres::exp(- *a * (T(p_.t) - *m)));
24                 return true;
25         }
26
27  private:
28         Point p_;
29 };
30
31 int main(int argc, char** argv) {
32         google::ParseCommandLineFlags(&argc, &argv, true);
33         google::InitGoogleLogging(argv[0]);
34
35         Problem problem;
36
37         FILE *fp = fopen(argv[1], "r");
38         if (fp == NULL) {
39                 perror(argv[1]);
40                 exit(1);
41         }
42
43         // The two parameters to be optimized.
44         double a = 1.0;
45         double m = 0.5;
46
47         int line_num = 0;
48         while (!feof(fp)) {
49                 char buf[256];
50                 if (fgets(buf, 256, fp) == NULL) {
51                         break;
52                 }
53                 Point p;
54                 p.t = line_num;
55                 p.adoption_rate = atof(buf) * 0.01;
56                 problem.AddResidualBlock(
57                     new AutoDiffCostFunction<DistanceFromSigmoidCostFunction, 1, 1, 1>(
58                         new DistanceFromSigmoidCostFunction(p)),
59                     NULL,
60                     &a, &m);
61                 ++line_num;
62         }
63         fclose(fp);
64
65         a = 1.0 / line_num;
66
67         // Run the solver!
68         Solver::Options options;
69         options.max_num_iterations = 1000000;
70         options.linear_solver_type = ceres::DENSE_QR;
71         options.minimizer_progress_to_stdout = true;
72
73         Solver::Summary summary;
74         Solve(options, &problem, &summary);
75
76         std::cout << summary.BriefReport() << "\n";
77         std::cout << "a=" << a << std::endl;
78         std::cout << "m=" << m << std::endl;
79         
80         return 0;
81 }