]> git.sesse.net Git - nms/blob - tsp/tsp.cpp
Correct the estimates from last commit; unfortunately, they were overly pessimistic.
[nms] / tsp / tsp.cpp
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <queue>
#include <set>
#include <vector>
6
#define MIN_ROW 1
#define MAX_ROW 75
#define MIN_SWITCH 1
#define MAX_SWITCH 6
#define HEAP_MST 0

/*
 * Precomputed distance tables. Rows are numbered MIN_ROW..MAX_ROW and switches
 * MIN_SWITCH..MAX_SWITCH. Each table has one slot per ordered pair of
 * endpoints, where an endpoint is a (row, switch, side) triple for dist_cache
 * and a (row, switch) pair for opt_dist_cache / pess_dist_cache.
 */
static const unsigned num_rows = MAX_ROW - MIN_ROW + 1;
static const unsigned num_switches = MAX_SWITCH - MIN_SWITCH + 1;
static const unsigned num_sided_points = num_rows * num_switches * 2;
static const unsigned num_points = num_rows * num_switches;

static const unsigned num_cache_elem = num_sided_points * num_sided_points;
static unsigned short dist_cache[num_sided_points * num_sided_points],
        opt_dist_cache[num_points * num_points],
        pess_dist_cache[num_points * num_points];

/*
 * Flat index of a (row, switch) point. Row and switch numbers start at
 * MIN_ROW / MIN_SWITCH, so those are subtracted first; indexing with the raw
 * numbers would run past the end of the tables for the highest rows.
 */
inline unsigned point_index(unsigned row, unsigned sw)
{
        return (row - MIN_ROW) * num_switches + (sw - MIN_SWITCH);
}

/* Flat index of a (row, switch, side) endpoint; side is 0 or 1. */
inline unsigned sided_point_index(unsigned row, unsigned sw, unsigned side)
{
        return point_index(row, sw) * 2 + side;
}

/* Slot holding the side-aware distance between two (row, switch, side) endpoints. */
inline unsigned short &cache(
        unsigned row_from, unsigned switch_from, unsigned side_from,
        unsigned row_to, unsigned switch_to, unsigned side_to)
{
        return dist_cache[sided_point_index(row_from, switch_from, side_from) * num_sided_points +
                sided_point_index(row_to, switch_to, side_to)];
}

/* Slot holding the optimistic (side-independent) distance between two points. */
inline unsigned short &opt_cache(
        unsigned row_from, unsigned switch_from,
        unsigned row_to, unsigned switch_to)
{
        return opt_dist_cache[point_index(row_from, switch_from) * num_points +
                point_index(row_to, switch_to)];
}

/* Slot holding the pessimistic (side-independent) distance between two points. */
inline unsigned short &pess_cache(
        unsigned row_from, unsigned switch_from,
        unsigned row_to, unsigned switch_to)
{
        return pess_dist_cache[point_index(row_from, switch_from) * num_points +
                point_index(row_to, switch_to)];
}
41
/*
 * One stop of a tour: the switch (row, num), which side of the row it is
 * visited on (0 = left, 1 = right, as printed by print_tour), and the cost of
 * walking there from the previous stop.
 */
struct order {
        unsigned row;
        unsigned num;
        int side;
        int cost;

        /* Stops compare by cost alone, so sorting candidates puts the cheapest first. */
        bool operator< (const order &other) const
        {
                return other.cost > cost;
        }
};

/* Cost of the best complete tour found so far; UINT_MAX until one is found. */
static unsigned best_so_far = UINT_MAX;

/* The best complete tour found so far (one entry per point). */
order *best_tour;
55
/*
 * Walking distance, in units of 0.1m, between two switch positions.
 * Switches 1-3 and 4-6 lie on opposite sides of the middle; going from one
 * group to the other adds an extra gap.
 */
int distance_switch(unsigned from, unsigned to)
{
        /*
         * Number of switches between the two. Computed in unsigned arithmetic
         * without wrapping, instead of abs() on an unsigned difference (which
         * is ambiguous in C++ and relied on unsigned->int conversion).
         */
        const unsigned steps = (from > to) ? from - to : to - from;

        /* on the same side of the middle? 9.6m per switch. */
        if ((from > 3) == (to > 3))
                return steps * 96;

        /* have to cross the border? 25.8m from sw3->sw4 => 16.2m extra gap. */
        /* that's _got_ to be wrong. say it's 3m. */
        return steps * 96 + 30;
}
67
/*
 * Walking distance (0.1m units) from switch sw to one of the three middles:
 * 2 is the middle between switches 3 and 4, 1 is the bottom end (next to
 * switch 1) and 3 is the top end (next to switch 6).
 */
int distance_middle(unsigned sw, unsigned middle)
{
        if (middle == 2) {
                /* symmetry: 4-5-6 are just mirrored 3-2-1 */
                const unsigned mirrored = (sw > 3) ? 7 - sw : sw;

                /* estimate 25.8m/2 = 12.9m from sw3 to the middle, plus 9.6m per switch beyond */
                return 129 + (3 - mirrored) * 96;
        }

        /* getting from 1-6 to the top is like getting from 6-1 to the bottom */
        const unsigned pos = (middle == 3) ? 7 - sw : sw;

        /* guesstimate 4.8m extra to get to the bottom, plus 16.2m when crossing the sw3/sw4 gap */
        const int gap = (pos > 3) ? 162 : 0;
        return 48 + gap + (pos - 1) * 96;
}
91
/*
 * Distance (0.1m units) between two rows: 4.1m per double row, plus 2.5m for
 * each gap between row blocks that lies between the two rows.
 */
int distance_row(unsigned from, unsigned to)
{
        /* last row before each gap */
        static const unsigned gap_after[] = { 9, 17, 25, 34 };

        /* row difference without unsigned wrap-around (abs() on unsigned is ambiguous) */
        const unsigned rows = (from > to) ? from - to : to - from;
        unsigned cost = 41 * rows;

        /* a gap lies between the rows when they are on different sides of it */
        for (unsigned i = 0; i < sizeof(gap_after) / sizeof(gap_after[0]); ++i) {
                if ((from <= gap_after[i]) != (to <= gap_after[i]))
                        cost += 25;
        }

        return cost;
}
109
110 int pessimistic_distance(int row_from, int switch_from, int row_to, int switch_to)
111 {
112         /* we'll need to go to one of the three middles */
113         int best2 = distance_middle(switch_from, 2) + distance_middle(switch_to, 2);
114         int distrow = distance_row(row_from, row_to);
115         if ((switch_from > 3) != (switch_to > 3))
116                 return best2 + distrow;
117         if (switch_from > 3) {
118                 int best3 = distance_middle(switch_from, 3) + distance_middle(switch_to, 3);
119                 return std::min(best2, best3) + distrow;
120         } else {
121                 int best1 = distance_middle(switch_from, 1) + distance_middle(switch_to, 1);
122                 return std::min(best2, best1) + distrow;
123         }
124 }
125
126 int distance(int row_from, int switch_from, int side_from, int row_to, int switch_to, int side_to)
127 {
128         /* can we just walk directly? */
129         if (row_from == row_to && side_from == side_to) {
130                 return distance_switch(switch_from, switch_to);
131         }
132         
133         /* can we just switch sides? */
134         if (row_from + 1 == row_to && side_from == 1 && side_to == 0) {
135                 return distance_switch(switch_from, switch_to) + distance_row(row_from, row_to) - 31;
136         }
137         if (row_from == row_to + 1 && side_from == 0 && side_to == 1) {
138                 return distance_switch(switch_from, switch_to) + distance_row(row_from, row_to) - 31;
139         }
140
141         return pessimistic_distance(row_from, switch_from, row_to, switch_to);
142 }       
143
144 int optimistic_distance(int row_from, int switch_from, int row_to, int switch_to)
145 {
146         if (row_from == row_to)
147                 return distance_switch(switch_from, switch_to);
148         else if (abs(row_from - row_to) == 1)
149                 return distance_switch(switch_from, switch_to) + distance_row(row_from, row_to) - 31;
150         else
151                 return pessimistic_distance(row_from, switch_from, row_to, switch_to);
152 }
153
154 #if HEAP_MST
155 // this is, surprisingly enough, _slower_ than the naive variant below, so it's not enabled
/* A candidate MST edge: the point it reaches and what it costs. */
struct prim_queue_val {
        std::pair<unsigned, unsigned> dst;
        int cost;

        /*
         * Deliberately inverted: std::priority_queue pops its largest element,
         * so this makes it pop the cheapest edge first.
         */
        bool operator< (const prim_queue_val &other) const
        {
                return other.cost < cost;
        }
};
165
166 // standard O(V^2 log v) prim
167 int prim_mst(std::set<std::pair<unsigned, unsigned> > &in)
168 {
169         std::set<std::pair<unsigned, unsigned> > set2;
170         std::priority_queue<prim_queue_val> queue;
171
172         // pick the first one
173         std::set<std::pair<unsigned, unsigned> >::iterator start = in.begin();
174         
175         unsigned row = start->first;
176         unsigned num = start->second;
177
178         set2.insert(*start);
179         
180         // find all the edges out from it
181         for (std::set<std::pair<unsigned, unsigned> >::iterator j = in.begin(); j != in.end(); ++j) {
182                 if (set2.count(*j))
183                         continue;
184                 
185                 unsigned d = opt_cache(row, num, j->first, j->second);
186                 prim_queue_val val = { *j, d };
187                 queue.push(val);
188         }
189
190         unsigned total_cost = 0;
191         while (set2.size() != in.size()) {
192 invalid:
193                 prim_queue_val val = queue.top();
194                 queue.pop();
195                 
196                 // check if dst is already moved
197                 if (set2.count(val.dst))
198                         goto invalid;
199         
200                 unsigned row = val.dst.first;
201                 unsigned num = val.dst.second;
202                 set2.insert(val.dst);
203
204                 total_cost += val.cost;
205
206                 // find all the edges from this new node
207                 for (std::set<std::pair<unsigned, unsigned> >::iterator j = in.begin(); j != in.end(); ++j) {
208                         if (set2.count(*j))
209                                 continue;
210                         
211                         unsigned d = opt_cache(row, num, j->first, j->second);
212                         prim_queue_val val = { *j, d };
213                         queue.push(val);
214                 }
215         }
216
217         return total_cost;
218 }
219 #else
220 // extremely primitive O(V^3) prim
221 int prim_mst(std::set<std::pair<unsigned, unsigned> > &set1)
222 {
223         std::set<std::pair<unsigned, unsigned> > set2;
224
225         // pick the first one
226         std::set<std::pair<unsigned, unsigned> >::iterator start = set1.begin();
227         set2.insert(*start);
228         set1.erase(start);
229
230         unsigned total_cost = 0;
231         while (set1.size() > 0) {
232                 unsigned best_this_cost = UINT_MAX;
233                 std::set<std::pair<unsigned, unsigned> >::iterator best_set1;
234                 
235                 for (std::set<std::pair<unsigned, unsigned> >::iterator i = set1.begin(); i != set1.end(); ++i) {
236                         for (std::set<std::pair<unsigned, unsigned> >::iterator j = set2.begin(); j != set2.end(); ++j) {
237                                 unsigned d = opt_cache(i->first, i->second, j->first, j->second);
238                                 if (d < best_this_cost) {
239                                         best_this_cost = d;
240                                         best_set1 = i;
241                                 }
242                         }
243                 }
244
245                 set2.insert(*best_set1);
246                 set1.erase(best_set1);
247                 total_cost += best_this_cost;
248         }
249
250         return total_cost;
251 }
252 #endif
253
254 void print_tour(std::vector<std::pair<unsigned, unsigned> > &points)
255 {
256         std::set<std::pair<unsigned, unsigned> > points_left;
257         for (unsigned i = 0; i < points.size(); ++i) {
258                 points_left.insert(points[i]);
259         }
260         
261         for (unsigned i = 0; i < points.size(); ++i) {
262                 if (best_tour[i].side == 0)
263                         printf("%2u-%u (left side)  ", best_tour[i].row, best_tour[i].num);
264                 else
265                         printf("%2u-%u (right side) ", best_tour[i].row, best_tour[i].num);
266                 if (i == 0) {
267                         printf("           ");
268                 } else {
269                         printf("cost=%4u  ", best_tour[i].cost);
270                 }
271
272                 // let's see how good the MST heuristics are
273                 if (i != points.size() - 1) {
274                         std::set<std::pair<unsigned, unsigned> > mst_tree = points_left;
275                         printf("mst_bound=%5u, ", prim_mst(mst_tree));
276
277                         unsigned rest_cost = 0;
278                         for (unsigned j = i + 1; j < points.size(); ++j) {
279                                 rest_cost += best_tour[j].cost;
280                         }
281                         
282                         printf("rest_cost=%5u", rest_cost);
283                 }
284
285                 printf("\n");
286                 
287                 std::set<std::pair<unsigned, unsigned> >::iterator j = points_left.find(std::make_pair(best_tour[i].row, best_tour[i].num));
288                 points_left.erase(j);
289         }
290 }
291
292 unsigned do_tsp(std::vector<std::pair<unsigned, unsigned> > &points, std::set<std::pair<unsigned, unsigned> > &points_left, order *ord, order *temp, unsigned ind, unsigned cost_so_far)
293 {
294         if (cost_so_far >= best_so_far)
295                 return UINT_MAX;
296         if (ind == points.size()) {
297                 memcpy(best_tour, ord, sizeof(order) * points.size());
298                 printf("\nNew best tour found! cost=%u\n", cost_so_far);
299                 print_tour(points);
300                 best_so_far = cost_so_far;
301                 return 0;
302         }
303         
304         unsigned last_row = ord[ind-1].row;
305         unsigned last_switch = ord[ind-1].num;
306         unsigned last_side = ord[ind-1].side;
307
308         /* 
309          * The minimum spanning tree gives us a reliable lower bound for what we can
310          * achieve for the rest of the graph, so check it before doing anything else.
311          */
312         std::set<std::pair<unsigned, unsigned> > mst_set = points_left;
313         mst_set.insert(std::make_pair(last_row, last_switch));
314         
315         unsigned min_rest_cost = prim_mst(mst_set);
316         if (cost_so_far + min_rest_cost >= best_so_far) {
317                 return UINT_MAX;
318         }
319
320         /* 
321          * Simple heuristic: always search for the closest points from this one first; that
322          * will give us a sizable cutoff.
323          */
324         unsigned toi = 0;
325         
326         for (std::set<std::pair<unsigned, unsigned> >::iterator i = points_left.begin(); i != points_left.end(); ++i) {
327                 /* try both sides */
328                 temp[toi].row = i->first;
329                 temp[toi].num = i->second;
330                 temp[toi].side = 0;
331                 temp[toi].cost = cache(last_row, last_switch, last_side, i->first, i->second, 0);
332                 ++toi;
333
334                 temp[toi].row = i->first;
335                 temp[toi].num = i->second;
336                 temp[toi].side = 1;
337                 temp[toi].cost = cache(last_row, last_switch, last_side, i->first, i->second, 1);
338                 ++toi;
339         }
340         std::sort(temp, temp + toi);
341
342         unsigned best_this_cost = UINT_MAX;
343         for (unsigned i = 0; i < toi; ++i)
344         {
345                 ord[ind] = temp[i];
346                 
347                 std::set<std::pair<unsigned, unsigned> >::iterator j = points_left.find(std::make_pair(temp[i].row, temp[i].num));
348                 points_left.erase(j);
349                 unsigned cost_rest = do_tsp(points, points_left, ord, temp + points.size() * 2, ind + 1, cost_so_far + temp[i].cost);
350                 points_left.insert(std::make_pair(temp[i].row, temp[i].num));
351                 
352                 best_this_cost = std::min(best_this_cost, cost_rest);
353         }
354
355         return best_this_cost;
356 }
357
358 int main()
359 {
360         std::vector<std::pair<unsigned, unsigned> > points;
361         std::set<std::pair<unsigned, unsigned> > points_left;
362
363         for ( ;; ) {
364                 unsigned row, sw;
365                 if (scanf("%u-%u", &row, &sw) != 2)
366                         break;
367
368                 if (row < MIN_ROW || row > MAX_ROW || sw < MIN_SWITCH || sw > MAX_SWITCH) {
369                         fprintf(stderr, "%u-%u is out of bounds!\n", row, sw);
370                         exit(1);
371                 }
372
373                 points.push_back(std::make_pair(row, sw));
374                 if (points.size() != 1)
375                         points_left.insert(std::make_pair(row, sw));
376         }
377
378         // precalculate all distances
379         for (unsigned i = 0; i < points.size(); ++i) {
380                 for (unsigned j = 0; j < points.size(); ++j) {
381                         cache(points[i].first, points[i].second, 0, points[j].first, points[j].second, 0) =
382                                 distance(points[i].first, points[i].second, 0, points[j].first, points[j].second, 0);
383                         
384                         cache(points[i].first, points[i].second, 0, points[j].first, points[j].second, 1) =
385                                 distance(points[i].first, points[i].second, 0, points[j].first, points[j].second, 1);
386                         
387                         cache(points[i].first, points[i].second, 1, points[j].first, points[j].second, 0) =
388                                 distance(points[i].first, points[i].second, 1, points[j].first, points[j].second, 0);
389                         
390                         cache(points[i].first, points[i].second, 1, points[j].first, points[j].second, 1) =
391                                 distance(points[i].first, points[i].second, 1, points[j].first, points[j].second, 1);
392                         
393                         opt_cache(points[i].first, points[i].second, points[j].first, points[j].second) =
394                                 optimistic_distance(points[i].first, points[i].second, points[j].first, points[j].second);
395                         
396                         pess_cache(points[i].first, points[i].second, points[j].first, points[j].second) =
397                                 pessimistic_distance(points[i].first, points[i].second, points[j].first, points[j].second);
398                 }
399         }
400         
401         order *ord = new order[points.size()];
402         best_tour = new order[points.size()];
403         order *temp = new order[points.size() * points.size() * 4];
404         
405         /* always start at the first one, left side (hack) */
406         ord[0].row = points[0].first;
407         ord[0].num = points[0].second;
408         ord[0].side = 0;
409         
410         do_tsp(points, points_left, ord, temp, 1, 0);
411         printf("All done.\n");
412 }
413
414