mlpack
dtb_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_EMST_DTB_IMPL_HPP
13 #define MLPACK_METHODS_EMST_DTB_IMPL_HPP
14 
15 #include "dtb_rules.hpp"
16 
17 namespace mlpack {
18 namespace emst {
19 
21 template<typename TreeType, typename MatType>
22 TreeType* BuildTree(
23  MatType&& dataset,
24  std::vector<size_t>& oldFromNew,
25  const typename std::enable_if<
27 {
28  return new TreeType(std::forward<MatType>(dataset), oldFromNew);
29 }
30 
// Overload of BuildTree that constructs the tree from the dataset alone: the
// index-mapping argument is accepted but ignored (its name is commented out).
// Returns a tree allocated with new; the caller is responsible for deleting it.
// NOTE(review): this listing skips its line 37, so the enable_if condition and
// the closing parenthesis of the parameter list are not visible here.
// Presumably this is the counterpart of the overload that passes oldFromNew to
// the tree constructor -- confirm against the original header.
32 template<typename TreeType, typename MatType>
33 TreeType* BuildTree(
34  MatType&& dataset,
35  const std::vector<size_t>& /* oldFromNew */,
36  const typename std::enable_if<
38 {
// Forward the dataset so that its value category reaches the tree constructor
// unchanged.
39  return new TreeType(std::forward<MatType>(dataset));
40 }
41 
46 template<
47  typename MetricType,
48  typename MatType,
49  template<typename TreeMetricType,
50  typename TreeStatType,
51  typename TreeMatType> class TreeType>
52 DualTreeBoruvka<MetricType, MatType, TreeType>::DualTreeBoruvka(
53  const MatType& dataset,
54  const bool naive,
55  const MetricType metric) :
56  tree(naive ? NULL : BuildTree<Tree>(dataset, oldFromNew)),
57  data(naive ? dataset : tree->Dataset()),
58  ownTree(!naive),
59  naive(naive),
60  connections(dataset.n_cols),
61  totalDist(0.0),
62  metric(metric)
63 {
64  edges.reserve(data.n_cols - 1); // Set size.
65 
66  neighborsInComponent.set_size(data.n_cols);
67  neighborsOutComponent.set_size(data.n_cols);
68  neighborsDistances.set_size(data.n_cols);
69  neighborsDistances.fill(DBL_MAX);
70 }
71 
72 template<
73  typename MetricType,
74  typename MatType,
75  template<typename TreeMetricType,
76  typename TreeStatType,
77  typename TreeMatType> class TreeType>
78 DualTreeBoruvka<MetricType, MatType, TreeType>::DualTreeBoruvka(
79  Tree* tree,
80  const MetricType metric) :
81  tree(tree),
82  data(tree->Dataset()),
83  ownTree(false),
84  naive(false),
85  connections(data.n_cols),
86  totalDist(0.0),
87  metric(metric)
88 {
89  edges.reserve(data.n_cols - 1); // Fill with EdgePairs.
90 
91  neighborsInComponent.set_size(data.n_cols);
92  neighborsOutComponent.set_size(data.n_cols);
93  neighborsDistances.set_size(data.n_cols);
94  neighborsDistances.fill(DBL_MAX);
95 }
96 
97 template<
98  typename MetricType,
99  typename MatType,
100  template<typename TreeMetricType,
101  typename TreeStatType,
102  typename TreeMatType> class TreeType>
103 DualTreeBoruvka<MetricType, MatType, TreeType>::~DualTreeBoruvka()
104 {
105  if (ownTree)
106  delete tree;
107 }
108 
113 template<
114  typename MetricType,
115  typename MatType,
116  template<typename TreeMetricType,
117  typename TreeStatType,
118  typename TreeMatType> class TreeType>
119 void DualTreeBoruvka<MetricType, MatType, TreeType>::ComputeMST(
120  arma::mat& results)
121 {
122  Timer::Start("emst/mst_computation");
123 
124  totalDist = 0; // Reset distance.
125 
126  typedef DTBRules<MetricType, Tree> RuleType;
127  RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
128  neighborsOutComponent, metric);
129  while (edges.size() < (data.n_cols - 1))
130  {
131  if (naive)
132  {
133  // Full O(N^2) traversal.
134  for (size_t i = 0; i < data.n_cols; ++i)
135  for (size_t j = 0; j < data.n_cols; ++j)
136  rules.BaseCase(i, j);
137  }
138  else
139  {
140  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
141  traverser.Traverse(*tree, *tree);
142  }
143 
144  AddAllEdges();
145 
146  Cleanup();
147 
148  Log::Info << edges.size() << " edges found so far." << std::endl;
149  if (!naive)
150  {
151  Log::Info << rules.BaseCases() << " cumulative base cases." << std::endl;
152  Log::Info << rules.Scores() << " cumulative node combinations scored."
153  << std::endl;
154  }
155  }
156 
157  Timer::Stop("emst/mst_computation");
158 
159  EmitResults(results);
160 
161  Log::Info << "Total spanning tree length: " << totalDist << std::endl;
162 }
163 
167 template<
168  typename MetricType,
169  typename MatType,
170  template<typename TreeMetricType,
171  typename TreeStatType,
172  typename TreeMatType> class TreeType>
173 void DualTreeBoruvka<MetricType, MatType, TreeType>::AddEdge(
174  const size_t e1,
175  const size_t e2,
176  const double distance)
177 {
178  Log::Assert((distance >= 0.0),
179  "DualTreeBoruvka::AddEdge(): distance cannot be negative.");
180 
181  if (e1 < e2)
182  edges.push_back(EdgePair(e1, e2, distance));
183  else
184  edges.push_back(EdgePair(e2, e1, distance));
185 }
186 
190 template<
191  typename MetricType,
192  typename MatType,
193  template<typename TreeMetricType,
194  typename TreeStatType,
195  typename TreeMatType> class TreeType>
196 void DualTreeBoruvka<MetricType, MatType, TreeType>::AddAllEdges()
197 {
198  for (size_t i = 0; i < data.n_cols; ++i)
199  {
200  size_t component = connections.Find(i);
201  size_t inEdge = neighborsInComponent[component];
202  size_t outEdge = neighborsOutComponent[component];
203  if (connections.Find(inEdge) != connections.Find(outEdge))
204  {
205  // totalDist = totalDist + dist;
206  // changed to make this agree with the cover tree code
207  totalDist += neighborsDistances[component];
208  AddEdge(inEdge, outEdge, neighborsDistances[component]);
209  connections.Union(inEdge, outEdge);
210  }
211  }
212 }
213 
217 template<
218  typename MetricType,
219  typename MatType,
220  template<typename TreeMetricType,
221  typename TreeStatType,
222  typename TreeMatType> class TreeType>
223 void DualTreeBoruvka<MetricType, MatType, TreeType>::EmitResults(
224  arma::mat& results)
225 {
226  // Sort the edges.
227  std::sort(edges.begin(), edges.end(), SortFun);
228 
229  Log::Assert(edges.size() == data.n_cols - 1);
230  results.set_size(3, edges.size());
231 
232  // Need to unpermute the point labels.
233  if (!naive && ownTree && tree::TreeTraits<Tree>::RearrangesDataset)
234  {
235  for (size_t i = 0; i < (data.n_cols - 1); ++i)
236  {
237  // Make sure the edge list stores the smaller index first to
238  // make checking correctness easier
239  size_t ind1 = oldFromNew[edges[i].Lesser()];
240  size_t ind2 = oldFromNew[edges[i].Greater()];
241 
242  if (ind1 < ind2)
243  {
244  edges[i].Lesser() = ind1;
245  edges[i].Greater() = ind2;
246  }
247  else
248  {
249  edges[i].Lesser() = ind2;
250  edges[i].Greater() = ind1;
251  }
252 
253  results(0, i) = edges[i].Lesser();
254  results(1, i) = edges[i].Greater();
255  results(2, i) = edges[i].Distance();
256  }
257  }
258  else
259  {
260  for (size_t i = 0; i < edges.size(); ++i)
261  {
262  results(0, i) = edges[i].Lesser();
263  results(1, i) = edges[i].Greater();
264  results(2, i) = edges[i].Distance();
265  }
266  }
267 }
268 
273 template<
274  typename MetricType,
275  typename MatType,
276  template<typename TreeMetricType,
277  typename TreeStatType,
278  typename TreeMatType> class TreeType>
279 void DualTreeBoruvka<MetricType, MatType, TreeType>::CleanupHelper(Tree* tree)
280 {
281  // Reset the statistic information.
282  tree->Stat().MaxNeighborDistance() = DBL_MAX;
283  tree->Stat().MinNeighborDistance() = DBL_MAX;
284  tree->Stat().Bound() = DBL_MAX;
285 
286  // Recurse into all children.
287  for (size_t i = 0; i < tree->NumChildren(); ++i)
288  CleanupHelper(&tree->Child(i));
289 
290  // Get the component of the first child or point. Then we will check to see
291  // if all other components of children and points are the same.
292  const int component = (tree->NumChildren() != 0) ?
293  tree->Child(0).Stat().ComponentMembership() :
294  connections.Find(tree->Point(0));
295 
296  // Check components of children.
297  for (size_t i = 0; i < tree->NumChildren(); ++i)
298  if (tree->Child(i).Stat().ComponentMembership() != component)
299  return;
300 
301  // Check components of points.
302  for (size_t i = 0; i < tree->NumPoints(); ++i)
303  if (connections.Find(tree->Point(i)) != size_t(component))
304  return;
305 
306  // If we made it this far, all components are the same.
307  tree->Stat().ComponentMembership() = component;
308 }
309 
313 template<
314  typename MetricType,
315  typename MatType,
316  template<typename TreeMetricType,
317  typename TreeStatType,
318  typename TreeMatType> class TreeType>
319 void DualTreeBoruvka<MetricType, MatType, TreeType>::Cleanup()
320 {
321  for (size_t i = 0; i < data.n_cols; ++i)
322  neighborsDistances[i] = DBL_MAX;
323 
324  if (!naive)
325  CleanupHelper(tree);
326 }
327 
328 } // namespace emst
329 } // namespace mlpack
330 
331 #endif
An edge pair is simply two indices and a distance.
Definition: edge_pair.hpp:28
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: dtb_rules.hpp:23
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
TreeType * BuildTree(MatType &&dataset, std::vector< size_t > &oldFromNew, const typename std::enable_if< tree::TreeTraits< TreeType >::RearrangesDataset >::type *=0)
Call the tree constructor that does mapping.
Definition: dtb_impl.hpp:22
Performs the MST calculation using the Dual-Tree Boruvka algorithm, using any type of tree...
Definition: dtb.hpp:83
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38