From: Steinar H. Gunderson Date: Fri, 7 Apr 2006 19:56:03 +0000 (+0000) Subject: Implement a new version of the MST finder (disabled by default). X-Git-Url: https://git.sesse.net/?a=commitdiff_plain;h=089c57286db1cc0f8f344ca583146f5550fb11d4;hp=8009a5bab6f840576b44a6d70e833c3c2be2a88b;p=nms Implement a new version of the MST finder (disabled by default). --- diff --git a/tsp/tsp.cpp b/tsp/tsp.cpp index 760646c..e7fbab0 100644 --- a/tsp/tsp.cpp +++ b/tsp/tsp.cpp @@ -8,6 +8,7 @@ #define MAX_ROW 75 #define MIN_SWITCH 1 #define MAX_SWITCH 6 +#define HEAP_MST 0 static const unsigned num_cache_elem = (MAX_ROW * MAX_SWITCH * 2) * (MAX_ROW * MAX_SWITCH * 2); static unsigned short dist_cache[(MAX_ROW * MAX_SWITCH * 2) * (MAX_ROW * MAX_SWITCH * 2)], opt_dist_cache[MAX_ROW * MAX_SWITCH * MAX_ROW * MAX_SWITCH]; @@ -121,6 +122,72 @@ int optimistic_distance(int row_from, int switch_from, int row_to, int switch_to return distance(row_from, switch_from, 0, row_to, switch_to, 0); } +#if HEAP_MST +// this is, surprisingly enough, _slower_ than the naive variant below, so it's not enabled +struct prim_queue_val { + std::pair dst; + int cost; + + bool operator< (const prim_queue_val &other) const + { + return (cost > other.cost); + } +}; + +// standard O(V^2 log v) prim +int prim_mst(std::set > &in) +{ + std::set > set2; + std::priority_queue queue; + + // pick the first one + std::set >::iterator start = in.begin(); + + unsigned row = start->first; + unsigned num = start->second; + + set2.insert(*start); + + // find all the edges out from it + for (std::set >::iterator j = in.begin(); j != in.end(); ++j) { + if (set2.count(*j)) + continue; + + unsigned d = opt_cache(row, num, j->first, j->second); + prim_queue_val val = { *j, d }; + queue.push(val); + } + + unsigned total_cost = 0; + while (set2.size() != in.size()) { +invalid: + prim_queue_val val = queue.top(); + queue.pop(); + + // check if dst is already moved + if (set2.count(val.dst)) + goto invalid; + + unsigned row = val.dst.first; + unsigned num = val.dst.second; + set2.insert(val.dst); + + total_cost += val.cost; + + // find all the edges from this new node + for (std::set >::iterator j = in.begin(); j != in.end(); ++j) { + if (set2.count(*j)) + continue; + + unsigned d = opt_cache(row, num, j->first, j->second); + prim_queue_val val = { *j, d }; + queue.push(val); + } + } + + return total_cost; +} +#else // extremely primitive O(V^3) prim int prim_mst(std::set > &set1) { @@ -153,7 +220,7 @@ int prim_mst(std::set > &set1) return total_cost; } - +#endif void print_tour(std::vector > &points) {