]> git.sesse.net Git - nms/blobdiff - tsp/tsp.cpp
Cache the distances we need instead of calculating them over and over again (70%...
[nms] / tsp / tsp.cpp
index 587573ba6fa93966ab341e80ec17ec733a38752b..000449219b2f3504886a32fe27f0ee479f97df45 100644 (file)
@@ -4,6 +4,30 @@
 #include <set>
 #include <algorithm>
 
+#define MIN_ROW 1
+#define MAX_ROW 75
+#define MIN_SWITCH 1
+#define MAX_SWITCH 6
+
+static const unsigned num_cache_elem = (MAX_ROW * MAX_SWITCH * 2) * (MAX_ROW * MAX_SWITCH * 2);
+static unsigned short dist_cache[(MAX_ROW * MAX_SWITCH * 2) * (MAX_ROW * MAX_SWITCH * 2)], opt_dist_cache[MAX_ROW * MAX_SWITCH * MAX_ROW * MAX_SWITCH];
+
+inline unsigned short &cache(
+       unsigned row_from, unsigned switch_from, unsigned side_from,
+       unsigned row_to, unsigned switch_to, unsigned side_to)
+{
+       return dist_cache[(row_from * MAX_SWITCH * 2 + switch_from * 2 + side_from) * (MAX_ROW * MAX_SWITCH * 2) +
+               row_to * MAX_SWITCH * 2 + switch_to * 2 + side_to];
+}
+
+inline unsigned short &opt_cache(
+       unsigned row_from, unsigned switch_from,
+       unsigned row_to, unsigned switch_to)
+{
+       return opt_dist_cache[(row_from * MAX_SWITCH + switch_from) * (MAX_ROW * MAX_SWITCH) +
+               row_to * MAX_SWITCH + switch_to];
+}
+
 struct order {
        unsigned row, num;
        int side;
@@ -114,7 +138,7 @@ int prim_mst(std::set<std::pair<unsigned, unsigned> > &set1)
                
                for (std::set<std::pair<unsigned, unsigned> >::iterator i = set1.begin(); i != set1.end(); ++i) {
                        for (std::set<std::pair<unsigned, unsigned> >::iterator j = set2.begin(); j != set2.end(); ++j) {
-                               unsigned d = optimistic_distance(i->first, i->second, j->first, j->second);
+                               unsigned d = opt_cache(i->first, i->second, j->first, j->second);
                                if (d < best_this_cost) {
                                        best_this_cost = d;
                                        best_set1 = i;
@@ -198,13 +222,13 @@ unsigned do_tsp(std::vector<std::pair<unsigned, unsigned> > &points, std::set<st
                temp[toi].row = i->first;
                temp[toi].num = i->second;
                temp[toi].side = 0;
-               temp[toi].cost = distance(last_row, last_switch, last_side, i->first, i->second, 0);
+               temp[toi].cost = cache(last_row, last_switch, last_side, i->first, i->second, 0);
                ++toi;
 
                temp[toi].row = i->first;
                temp[toi].num = i->second;
                temp[toi].side = 1;
-               temp[toi].cost = distance(last_row, last_switch, last_side, i->first, i->second, 1);
+               temp[toi].cost = cache(last_row, last_switch, last_side, i->first, i->second, 1);
                ++toi;
        }
 
@@ -241,11 +265,36 @@ int main()
                if (scanf("%u-%u", &row, &sw) != 2)
                        break;
 
+               if (row < MIN_ROW || row > MAX_ROW || sw < MIN_SWITCH || sw > MAX_SWITCH) {
+                       fprintf(stderr, "%u-%u is out of bounds!\n", row, sw);
+                       exit(1);
+               }
+
                points.push_back(std::make_pair(row, sw));
                if (points.size() != 1)
                        points_left.insert(std::make_pair(row, sw));
        }
 
+       // precalculate all distances
+       for (unsigned i = 0; i < points.size(); ++i) {
+               for (unsigned j = 0; j < points.size(); ++j) {
+                       cache(points[i].first, points[i].second, 0, points[j].first, points[j].second, 0) =
+                               distance(points[i].first, points[i].second, 0, points[j].first, points[j].second, 0);
+                       
+                       cache(points[i].first, points[i].second, 0, points[j].first, points[j].second, 1) =
+                               distance(points[i].first, points[i].second, 0, points[j].first, points[j].second, 1);
+                       
+                       cache(points[i].first, points[i].second, 1, points[j].first, points[j].second, 0) =
+                               distance(points[i].first, points[i].second, 1, points[j].first, points[j].second, 0);
+                       
+                       cache(points[i].first, points[i].second, 1, points[j].first, points[j].second, 1) =
+                               distance(points[i].first, points[i].second, 1, points[j].first, points[j].second, 1);
+                       
+                       opt_cache(points[i].first, points[i].second, points[j].first, points[j].second) =
+                               optimistic_distance(points[i].first, points[i].second, points[j].first, points[j].second);
+               }
+       }
+
        order *ord = new order[points.size()];
        best_tour = new order[points.size()];
        order *temp = new order[points.size() * points.size() * 4];