12 #ifndef MLPACK_METHODS_EMST_DTB_IMPL_HPP 13 #define MLPACK_METHODS_EMST_DTB_IMPL_HPP 21 template<
typename TreeType,
typename MatType>
24 std::vector<size_t>& oldFromNew,
25 const typename std::enable_if<
28 return new TreeType(std::forward<MatType>(dataset), oldFromNew);
32 template<
typename TreeType,
typename MatType>
35 const std::vector<size_t>& ,
36 const typename std::enable_if<
39 return new TreeType(std::forward<MatType>(dataset));
49 template<
typename TreeMetricType,
50 typename TreeStatType,
51 typename TreeMatType>
class TreeType>
53 const MatType& dataset,
55 const MetricType metric) :
56 tree(naive ? NULL : BuildTree<Tree>(dataset, oldFromNew)),
57 data(naive ? dataset : tree->Dataset()),
60 connections(dataset.n_cols),
64 edges.reserve(
data.n_cols - 1);
66 neighborsInComponent.set_size(
data.n_cols);
67 neighborsOutComponent.set_size(
data.n_cols);
68 neighborsDistances.set_size(
data.n_cols);
69 neighborsDistances.fill(DBL_MAX);
75 template<
typename TreeMetricType,
76 typename TreeStatType,
77 typename TreeMatType>
class TreeType>
80 const MetricType metric) :
82 data(tree->Dataset()),
85 connections(data.n_cols),
89 edges.reserve(
data.n_cols - 1);
91 neighborsInComponent.set_size(
data.n_cols);
92 neighborsOutComponent.set_size(
data.n_cols);
93 neighborsDistances.set_size(
data.n_cols);
94 neighborsDistances.fill(DBL_MAX);
100 template<
typename TreeMetricType,
101 typename TreeStatType,
102 typename TreeMatType>
class TreeType>
116 template<
typename TreeMetricType,
117 typename TreeStatType,
118 typename TreeMatType>
class TreeType>
127 RuleType rules(
data, connections, neighborsDistances, neighborsInComponent,
128 neighborsOutComponent, metric);
129 while (edges.size() < (
data.n_cols - 1))
134 for (
size_t i = 0; i <
data.n_cols; ++i)
135 for (
size_t j = 0; j <
data.n_cols; ++j)
136 rules.BaseCase(i, j);
140 typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
141 traverser.Traverse(*tree, *tree);
148 Log::Info << edges.size() <<
" edges found so far." << std::endl;
151 Log::Info << rules.BaseCases() <<
" cumulative base cases." << std::endl;
152 Log::Info << rules.Scores() <<
" cumulative node combinations scored." 159 EmitResults(results);
161 Log::Info <<
"Total spanning tree length: " << totalDist << std::endl;
170 template<
typename TreeMetricType,
171 typename TreeStatType,
172 typename TreeMatType>
class TreeType>
173 void DualTreeBoruvka<MetricType, MatType, TreeType>::AddEdge(
176 const double distance)
179 "DualTreeBoruvka::AddEdge(): distance cannot be negative.");
182 edges.push_back(
EdgePair(e1, e2, distance));
184 edges.push_back(
EdgePair(e2, e1, distance));
193 template<
typename TreeMetricType,
194 typename TreeStatType,
195 typename TreeMatType>
class TreeType>
196 void DualTreeBoruvka<MetricType, MatType, TreeType>::AddAllEdges()
198 for (
size_t i = 0; i <
data.n_cols; ++i)
200 size_t component = connections.Find(i);
201 size_t inEdge = neighborsInComponent[component];
202 size_t outEdge = neighborsOutComponent[component];
203 if (connections.Find(inEdge) != connections.Find(outEdge))
207 totalDist += neighborsDistances[component];
208 AddEdge(inEdge, outEdge, neighborsDistances[component]);
209 connections.Union(inEdge, outEdge);
220 template<
typename TreeMetricType,
221 typename TreeStatType,
222 typename TreeMatType>
class TreeType>
223 void DualTreeBoruvka<MetricType, MatType, TreeType>::EmitResults(
227 std::sort(edges.begin(), edges.end(), SortFun);
230 results.set_size(3, edges.size());
235 for (
size_t i = 0; i < (
data.n_cols - 1); ++i)
239 size_t ind1 = oldFromNew[edges[i].Lesser()];
240 size_t ind2 = oldFromNew[edges[i].Greater()];
244 edges[i].Lesser() = ind1;
245 edges[i].Greater() = ind2;
249 edges[i].Lesser() = ind2;
250 edges[i].Greater() = ind1;
253 results(0, i) = edges[i].Lesser();
254 results(1, i) = edges[i].Greater();
255 results(2, i) = edges[i].Distance();
260 for (
size_t i = 0; i < edges.size(); ++i)
262 results(0, i) = edges[i].Lesser();
263 results(1, i) = edges[i].Greater();
264 results(2, i) = edges[i].Distance();
276 template<
typename TreeMetricType,
277 typename TreeStatType,
278 typename TreeMatType>
class TreeType>
279 void DualTreeBoruvka<MetricType, MatType, TreeType>::CleanupHelper(Tree* tree)
282 tree->Stat().MaxNeighborDistance() = DBL_MAX;
283 tree->Stat().MinNeighborDistance() = DBL_MAX;
284 tree->Stat().Bound() = DBL_MAX;
287 for (
size_t i = 0; i < tree->NumChildren(); ++i)
288 CleanupHelper(&tree->Child(i));
292 const int component = (tree->NumChildren() != 0) ?
293 tree->Child(0).Stat().ComponentMembership() :
294 connections.Find(tree->Point(0));
297 for (
size_t i = 0; i < tree->NumChildren(); ++i)
298 if (tree->Child(i).Stat().ComponentMembership() != component)
302 for (
size_t i = 0; i < tree->NumPoints(); ++i)
303 if (connections.Find(tree->Point(i)) !=
size_t(component))
307 tree->Stat().ComponentMembership() = component;
316 template<
typename TreeMetricType,
317 typename TreeStatType,
318 typename TreeMatType>
class TreeType>
319 void DualTreeBoruvka<MetricType, MatType, TreeType>::Cleanup()
321 for (
size_t i = 0; i <
data.n_cols; ++i)
322 neighborsDistances[i] = DBL_MAX;
An edge pair is simply two indices and a distance.
Definition: edge_pair.hpp:28
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: dtb_rules.hpp:23
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
TreeType * BuildTree(MatType &&dataset, std::vector< size_t > &oldFromNew, const typename std::enable_if< tree::TreeTraits< TreeType >::RearrangesDataset >::type *=0)
Call the tree constructor that does mapping.
Definition: dtb_impl.hpp:22
Performs the minimum spanning tree (MST) calculation using the Dual-Tree Boruvka algorithm, with any type of tree...
Definition: dtb.hpp:83
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38