]> git.sesse.net Git - nms/blob - tsp/tsp.cpp
Implement a new version of the MST finder (disabled by default).
[nms] / tsp / tsp.cpp
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <queue>
#include <set>
#include <utility>
#include <vector>
6
/* Inclusive coordinate ranges accepted for rows and for switches within a row. */
#define MIN_ROW 1
#define MAX_ROW 75
#define MIN_SWITCH 1
#define MAX_SWITCH 6
/* Set to 1 to build the priority-queue based MST finder instead of the naive one. */
#define HEAP_MST 0

/* Table dimensions derived from the coordinate ranges above. */
static const unsigned num_rows = MAX_ROW - MIN_ROW + 1;
static const unsigned num_switches = MAX_SWITCH - MIN_SWITCH + 1;
static const unsigned num_slots = num_rows * num_switches;      /* (row, switch) pairs */
static const unsigned num_sided_slots = num_slots * 2;          /* (row, switch, side) triples */

static const unsigned num_cache_elem = num_sided_slots * num_sided_slots;

/*
 * dist_cache holds one entry per ordered pair of (row, switch, side) triples,
 * opt_dist_cache one entry per ordered pair of (row, switch) pairs.
 */
static unsigned short dist_cache[num_cache_elem], opt_dist_cache[num_slots * num_slots];

/*
 * Returns a reference to the dist_cache entry for going from
 * (row_from, switch_from, side_from) to (row_to, switch_to, side_to).
 *
 * Rows must lie in [MIN_ROW, MAX_ROW], switches in [MIN_SWITCH, MAX_SWITCH]
 * and sides must be 0 or 1. The coordinates are rebased to zero before
 * indexing; without that, the highest rows/switches would address memory
 * past the end of the table.
 */
inline unsigned short &cache(
        unsigned row_from, unsigned switch_from, unsigned side_from,
        unsigned row_to, unsigned switch_to, unsigned side_to)
{
        const unsigned from = ((row_from - MIN_ROW) * num_switches + (switch_from - MIN_SWITCH)) * 2 + side_from;
        const unsigned to = ((row_to - MIN_ROW) * num_switches + (switch_to - MIN_SWITCH)) * 2 + side_to;
        return dist_cache[from * num_sided_slots + to];
}

/*
 * Returns a reference to the opt_dist_cache entry for going from
 * (row_from, switch_from) to (row_to, switch_to), ignoring sides.
 *
 * Same coordinate requirements and zero-rebasing as cache().
 */
inline unsigned short &opt_cache(
        unsigned row_from, unsigned switch_from,
        unsigned row_to, unsigned switch_to)
{
        const unsigned from = (row_from - MIN_ROW) * num_switches + (switch_from - MIN_SWITCH);
        const unsigned to = (row_to - MIN_ROW) * num_switches + (switch_to - MIN_SWITCH);
        return opt_dist_cache[from * num_slots + to];
}
31
/*
 * One stop of a tour: which switch is visited (row, num), on which side
 * (0 is printed as "left side", anything else as "right side"), and the
 * cost of the leg leading to it. Stops compare by that cost only, so
 * sorting a list of candidates puts the cheapest one first.
 */
struct order {
        unsigned row;
        unsigned num;
        int side;
        int cost;

        bool operator< (const order &rhs) const
        {
                return rhs.cost > cost;
        }
};
42
43 static unsigned best_so_far = UINT_MAX;
44 order *best_tour;
45
/*
 * Walking distance between two switch positions along one row.
 * The comments on the constants give 9.6m per switch as 96, i.e. the
 * result is in tenths of a metre.
 *
 * The absolute difference is computed explicitly: calling abs() on an
 * unsigned subtraction depends on wraparound followed by a conversion to
 * int, and is an ambiguous overload on current standard libraries.
 */
int distance_switch(unsigned from, unsigned to)
{
        const unsigned apart = (from > to) ? from - to : to - from;
        int dist = static_cast<int>(apart) * 96;    /* 9.6m per switch */

        /*
         * Different sides of the middle (switches 1-3 vs. 4-6)? The measured
         * 25.8m from sw3->sw4 would mean a 16.2m extra gap; that was judged
         * implausible, so 3m is used instead.
         */
        if ((from > 3) != (to > 3))
                dist += 30;

        return dist;
}
57
/*
 * Distance from switch position `sw` to one of the three crossing points
 * selected by `middle`: 2 is the central one; 3 is handled as the mirror
 * image of the remaining case (the original comments call these "top" and
 * "bottom").
 */
int distance_middle(unsigned sw, unsigned middle)
{
        /* position as seen from the opposite end: 4-5-6 mirror 3-2-1 */
        const unsigned mirrored = 7 - sw;

        if (middle == 2) {
                /* fold the far half onto the near half, then count switches to sw3;
                   estimate 25.8m/2 = 12.9m from sw3 to the middle */
                const unsigned folded = (sw > 3) ? mirrored : sw;
                return 129 + (3 - folded) * 96;
        }

        /* getting from 1-6 to the top is like getting from 6-1 to the bottom */
        const unsigned pos = (middle == 3) ? mirrored : sw;

        /* guesstimate 4.8m extra to get to the bottom */
        const unsigned along_row = 48 + (pos - 1) * 96;

        /* positions past sw3 additionally pay the 162 gap */
        return (pos > 3) ? along_row + 162 : along_row;
}
81
/*
 * Distance for moving between two rows. Gaps are not modelled yet; every
 * double row is estimated at 4.1m (41).
 *
 * The absolute difference is computed explicitly instead of calling abs()
 * on an unsigned subtraction, which depends on wraparound and is an
 * ambiguous overload on current standard libraries.
 */
int distance_row(unsigned from, unsigned to)
{
        const unsigned rows_apart = (from > to) ? from - to : to - from;
        return 41 * static_cast<int>(rows_apart);
}
87
88 int distance(int row_from, int switch_from, int side_from, int row_to, int switch_to, int side_to)
89 {
90         /* can we just walk directly? */
91         if (row_from == row_to && side_from == side_to) {
92                 return distance_switch(switch_from, switch_to);
93         }
94
95         /* can we just switch sides? */
96         if (row_from + 1 == row_to && side_from == 1 && side_to == 0) {
97                 return distance_switch(switch_from, switch_to);
98         }
99         if (row_from == row_to + 1 && side_from == 0 && side_to == 1) {
100                 return distance_switch(switch_from, switch_to);
101         }
102         
103         /* we'll need to go to one of the three middles */
104         int best2 = distance_middle(switch_from, 2) + distance_middle(switch_to, 2);
105         int distrow = distance_row(row_from, row_to);
106         if ((switch_from > 3) != (switch_to > 3))
107                 return best2 + distrow;
108         if (switch_from > 3) {
109                 int best3 = distance_middle(switch_from, 3) + distance_middle(switch_to, 3);
110                 return std::min(best2, best3) + distrow;
111         } else {
112                 int best1 = distance_middle(switch_from, 1) + distance_middle(switch_to, 1);
113                 return std::min(best2, best1) + distrow;
114         }
115 }
116
117 int optimistic_distance(int row_from, int switch_from, int row_to, int switch_to)
118 {
119         if (abs(row_from - row_to) == 1)
120                 return distance_switch(switch_from, switch_to);
121         else
122                 return distance(row_from, switch_from, 0, row_to, switch_to, 0);
123 }
124
125 #if HEAP_MST
126 // this is, surprisingly enough, _slower_ than the naive variant below, so it's not enabled
/*
 * Candidate edge for the heap-based Prim: the (row, switch) node it leads
 * to and its cost. The comparison is deliberately inverted (a costlier edge
 * is "less"), so std::priority_queue hands out the cheapest edge first.
 */
struct prim_queue_val {
        std::pair<unsigned, unsigned> dst;
        int cost;

        bool operator< (const prim_queue_val &rhs) const
        {
                return (rhs.cost < cost);
        }
};
136
137 // standard O(V^2 log v) prim
138 int prim_mst(std::set<std::pair<unsigned, unsigned> > &in)
139 {
140         std::set<std::pair<unsigned, unsigned> > set2;
141         std::priority_queue<prim_queue_val> queue;
142
143         // pick the first one
144         std::set<std::pair<unsigned, unsigned> >::iterator start = in.begin();
145         
146         unsigned row = start->first;
147         unsigned num = start->second;
148
149         set2.insert(*start);
150         
151         // find all the edges out from it
152         for (std::set<std::pair<unsigned, unsigned> >::iterator j = in.begin(); j != in.end(); ++j) {
153                 if (set2.count(*j))
154                         continue;
155                 
156                 unsigned d = opt_cache(row, num, j->first, j->second);
157                 prim_queue_val val = { *j, d };
158                 queue.push(val);
159         }
160
161         unsigned total_cost = 0;
162         while (set2.size() != in.size()) {
163 invalid:
164                 prim_queue_val val = queue.top();
165                 queue.pop();
166                 
167                 // check if dst is already moved
168                 if (set2.count(val.dst))
169                         goto invalid;
170         
171                 unsigned row = val.dst.first;
172                 unsigned num = val.dst.second;
173                 set2.insert(val.dst);
174
175                 total_cost += val.cost;
176
177                 // find all the edges from this new node
178                 for (std::set<std::pair<unsigned, unsigned> >::iterator j = in.begin(); j != in.end(); ++j) {
179                         if (set2.count(*j))
180                                 continue;
181                         
182                         unsigned d = opt_cache(row, num, j->first, j->second);
183                         prim_queue_val val = { *j, d };
184                         queue.push(val);
185                 }
186         }
187
188         return total_cost;
189 }
190 #else
191 // extremely primitive O(V^3) prim
192 int prim_mst(std::set<std::pair<unsigned, unsigned> > &set1)
193 {
194         std::set<std::pair<unsigned, unsigned> > set2;
195
196         // pick the first one
197         std::set<std::pair<unsigned, unsigned> >::iterator start = set1.begin();
198         set2.insert(*start);
199         set1.erase(start);
200
201         unsigned total_cost = 0;
202         while (set1.size() > 0) {
203                 unsigned best_this_cost = UINT_MAX;
204                 std::set<std::pair<unsigned, unsigned> >::iterator best_set1;
205                 
206                 for (std::set<std::pair<unsigned, unsigned> >::iterator i = set1.begin(); i != set1.end(); ++i) {
207                         for (std::set<std::pair<unsigned, unsigned> >::iterator j = set2.begin(); j != set2.end(); ++j) {
208                                 unsigned d = opt_cache(i->first, i->second, j->first, j->second);
209                                 if (d < best_this_cost) {
210                                         best_this_cost = d;
211                                         best_set1 = i;
212                                 }
213                         }
214                 }
215
216                 set2.insert(*best_set1);
217                 set1.erase(best_set1);
218                 total_cost += best_this_cost;
219         }
220
221         return total_cost;
222 }
223 #endif
224
225 void print_tour(std::vector<std::pair<unsigned, unsigned> > &points)
226 {
227         std::set<std::pair<unsigned, unsigned> > points_left;
228         for (unsigned i = 0; i < points.size(); ++i) {
229                 points_left.insert(points[i]);
230         }
231         
232         for (unsigned i = 0; i < points.size(); ++i) {
233                 if (best_tour[i].side == 0)
234                         printf("%2u-%u (left side)  ", best_tour[i].row, best_tour[i].num);
235                 else
236                         printf("%2u-%u (right side) ", best_tour[i].row, best_tour[i].num);
237                 if (i == 0) {
238                         printf("           ");
239                 } else {
240                         printf("cost=%4u  ", best_tour[i].cost);
241                 }
242
243                 // let's see how good the MST heuristics are
244                 if (i != points.size() - 1) {
245                         std::set<std::pair<unsigned, unsigned> > mst_tree = points_left;
246                         printf("mst_bound=%5u, ", prim_mst(mst_tree));
247
248                         unsigned rest_cost = 0;
249                         for (unsigned j = i + 1; j < points.size(); ++j) {
250                                 rest_cost += best_tour[j].cost;
251                         }
252                         
253                         printf("rest_cost=%5u", rest_cost);
254                 }
255
256                 printf("\n");
257                 
258                 std::set<std::pair<unsigned, unsigned> >::iterator j = points_left.find(std::make_pair(best_tour[i].row, best_tour[i].num));
259                 points_left.erase(j);
260         }
261 }
262
263 unsigned do_tsp(std::vector<std::pair<unsigned, unsigned> > &points, std::set<std::pair<unsigned, unsigned> > &points_left, order *ord, order *temp, unsigned ind, unsigned cost_so_far)
264 {
265         if (cost_so_far >= best_so_far)
266                 return UINT_MAX;
267         if (ind == points.size()) {
268                 memcpy(best_tour, ord, sizeof(order) * points.size());
269                 printf("\nNew best tour found! cost=%u\n", cost_so_far);
270                 print_tour(points);
271                 best_so_far = cost_so_far;
272                 return 0;
273         }
274
275         /* 
276          * Simple heuristic: always search for the closest points from this one first; that
277          * will give us a sizable cutoff.
278          */
279         unsigned toi = 0;
280         unsigned last_row = ord[ind-1].row;
281         unsigned last_switch = ord[ind-1].num;
282         unsigned last_side = ord[ind-1].side;
283         
284         std::set<std::pair<unsigned, unsigned> > mst_set = points_left;
285         mst_set.insert(std::make_pair(last_row, last_switch));
286         
287         for (std::set<std::pair<unsigned, unsigned> >::iterator i = points_left.begin(); i != points_left.end(); ++i) {
288                 /* try both sides */
289                 temp[toi].row = i->first;
290                 temp[toi].num = i->second;
291                 temp[toi].side = 0;
292                 temp[toi].cost = cache(last_row, last_switch, last_side, i->first, i->second, 0);
293                 ++toi;
294
295                 temp[toi].row = i->first;
296                 temp[toi].num = i->second;
297                 temp[toi].side = 1;
298                 temp[toi].cost = cache(last_row, last_switch, last_side, i->first, i->second, 1);
299                 ++toi;
300         }
301
302         unsigned min_rest_cost = prim_mst(mst_set);
303         if (cost_so_far + min_rest_cost >= best_so_far) {
304                 return UINT_MAX;
305         }
306         
307         std::sort(temp, temp + toi);
308
309         unsigned best_this_cost = UINT_MAX;
310         for (unsigned i = 0; i < toi; ++i)
311         {
312                 ord[ind] = temp[i];
313                 
314                 std::set<std::pair<unsigned, unsigned> >::iterator j = points_left.find(std::make_pair(temp[i].row, temp[i].num));
315                 points_left.erase(j);
316                 unsigned cost_rest = do_tsp(points, points_left, ord, temp + points.size() * 2, ind + 1, cost_so_far + temp[i].cost);
317                 points_left.insert(std::make_pair(temp[i].row, temp[i].num));
318                 
319                 best_this_cost = std::min(best_this_cost, cost_rest);
320         }
321
322         return best_this_cost;
323 }
324
325 int main()
326 {
327         std::vector<std::pair<unsigned, unsigned> > points;
328         std::set<std::pair<unsigned, unsigned> > points_left;
329
330         for ( ;; ) {
331                 unsigned row, sw;
332                 if (scanf("%u-%u", &row, &sw) != 2)
333                         break;
334
335                 if (row < MIN_ROW || row > MAX_ROW || sw < MIN_SWITCH || sw > MAX_SWITCH) {
336                         fprintf(stderr, "%u-%u is out of bounds!\n", row, sw);
337                         exit(1);
338                 }
339
340                 points.push_back(std::make_pair(row, sw));
341                 if (points.size() != 1)
342                         points_left.insert(std::make_pair(row, sw));
343         }
344
345         // precalculate all distances
346         for (unsigned i = 0; i < points.size(); ++i) {
347                 for (unsigned j = 0; j < points.size(); ++j) {
348                         cache(points[i].first, points[i].second, 0, points[j].first, points[j].second, 0) =
349                                 distance(points[i].first, points[i].second, 0, points[j].first, points[j].second, 0);
350                         
351                         cache(points[i].first, points[i].second, 0, points[j].first, points[j].second, 1) =
352                                 distance(points[i].first, points[i].second, 0, points[j].first, points[j].second, 1);
353                         
354                         cache(points[i].first, points[i].second, 1, points[j].first, points[j].second, 0) =
355                                 distance(points[i].first, points[i].second, 1, points[j].first, points[j].second, 0);
356                         
357                         cache(points[i].first, points[i].second, 1, points[j].first, points[j].second, 1) =
358                                 distance(points[i].first, points[i].second, 1, points[j].first, points[j].second, 1);
359                         
360                         opt_cache(points[i].first, points[i].second, points[j].first, points[j].second) =
361                                 optimistic_distance(points[i].first, points[i].second, points[j].first, points[j].second);
362                 }
363         }
364
365         order *ord = new order[points.size()];
366         best_tour = new order[points.size()];
367         order *temp = new order[points.size() * points.size() * 4];
368         
369         /* always start at the first one, left side (hack) */
370         ord[0].row = points[0].first;
371         ord[0].num = points[0].second;
372         ord[0].side = 0;
373         
374         do_tsp(points, points_left, ord, temp, 1, 0);
375         printf("All done.\n");
376 }
377
378