mlpack
dual_tree_kmeans_rules_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_IMPL_HPP
13 #define MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_IMPL_HPP
14 
16 
17 namespace mlpack {
18 namespace kmeans {
19 
// Construct the rule set used by a dual-tree k-means traversal.
//
// Every member named in the initializer list is initialized from the
// constructor argument of the same name.  The baseCases and scores counters
// start at zero, and the cached base case value starts at 0.0.
//
// lastQueryIndex and lastReferenceIndex start at dataset.n_cols and
// centroids.n_cols respectively -- presumably out-of-range sentinels so that
// the cache check in BaseCase() cannot succeed before any distance has been
// computed (TODO confirm against the class declaration).
20 template<typename MetricType, typename TreeType>
21 DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
22  const arma::mat& centroids,
23  const arma::mat& dataset,
24  arma::Row<size_t>& assignments,
25  arma::vec& upperBounds,
26  arma::vec& lowerBounds,
27  MetricType& metric,
28  const std::vector<bool>& prunedPoints,
29  const std::vector<size_t>& oldFromNewCentroids,
30  std::vector<bool>& visited) :
31  centroids(centroids),
32  dataset(dataset),
33  assignments(assignments),
34  upperBounds(upperBounds),
35  lowerBounds(lowerBounds),
36  metric(metric),
37  prunedPoints(prunedPoints),
38  oldFromNewCentroids(oldFromNewCentroids),
39  visited(visited),
40  baseCases(0),
41  scores(0),
42  lastQueryIndex(dataset.n_cols),
43  lastReferenceIndex(centroids.n_cols),
44  lastBaseCase(0.0)
45 {
46  // We must set the traversal info last query and reference node pointers to
47  // something that is both invalid (i.e. not a tree node) and not NULL. We'll
48  // use the this pointer.  The casts only produce a sentinel address for the
49  // pointer comparisons in Score(); the rules object is not itself a TreeType.
49  traversalInfo.LastQueryNode() = (TreeType*) this;
50  traversalInfo.LastReferenceNode() = (TreeType*) this;
51 }
52 
// Point-to-centroid base case.
//
// queryIndex selects a column of dataset and referenceIndex selects a column
// of centroids.  Returns 0.0 without doing any work if the query point is
// marked in prunedPoints, returns the cached value if this exact pair was the
// most recently evaluated one, and otherwise returns the distance computed by
// metric.Evaluate().
//
// Side effects when a distance is computed: visited[queryIndex] is set,
// baseCases is incremented, upperBounds / lowerBounds / assignments for the
// query point are updated, and the (query, reference, distance) triple is
// cached for the next call.
53 template<typename MetricType, typename TreeType>
54 inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
55  const size_t queryIndex,
56  const size_t referenceIndex)
57 {
58  if (prunedPoints[queryIndex])
59  return 0.0; // Returning 0 shouldn't be a problem.
60 
61  // If we have already performed this base case, then do not perform it again.
62  if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
63  return lastBaseCase;
64 
65  // Any base cases imply that we will get a result.
66  visited[queryIndex] = true;
67 
68  // Calculate the distance.
69  ++baseCases;
70  const double distance = metric.Evaluate(dataset.col(queryIndex),
71  centroids.col(referenceIndex));
72 
73  if (distance < upperBounds[queryIndex])
74  {
    // New closest centroid: the previous best distance becomes the lower
    // bound, and the assignment is mapped through oldFromNewCentroids when
    // the tree type rearranges its dataset.
75  lowerBounds[queryIndex] = upperBounds[queryIndex];
76  upperBounds[queryIndex] = distance;
77  assignments[queryIndex] = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
78  oldFromNewCentroids[referenceIndex] : referenceIndex;
79  }
80  else if (distance < lowerBounds[queryIndex])
81  {
    // Not the closest, but closer than the current lower bound.
82  lowerBounds[queryIndex] = distance;
83  }
84 
85  // Cache this information for the next time BaseCase() is called.
86  lastQueryIndex = queryIndex;
87  lastReferenceIndex = referenceIndex;
88  lastBaseCase = distance;
89 
90  return distance;
91 }
92 
// Single-tree score for one query point against a reference node.
//
// Returns DBL_MAX (do not recurse) when prunedPoints[queryIndex] is set, and
// 0 otherwise.  The reference node is not inspected.
93 template<typename MetricType, typename TreeType>
94 inline double DualTreeKMeansRules<MetricType, TreeType>::Score(
95  const size_t queryIndex,
96  TreeType& /* referenceNode */)
97 {
98  // If the query point has already been pruned, then don't recurse further.
99  if (prunedPoints[queryIndex])
100  return DBL_MAX;
101 
102  // No pruning at this level; we're not likely to encounter a single query
103  // point with a reference node.
104  return 0;
105 }
106 
107 template<typename MetricType, typename TreeType>
108 inline double DualTreeKMeansRules<MetricType, TreeType>::Score(
109  TreeType& queryNode,
110  TreeType& referenceNode)
111 {
112  if (queryNode.Stat().StaticPruned() == true)
113  return DBL_MAX;
114 
115  // Pruned() for the root node must never be set to size_t(-1).
116  if (queryNode.Stat().Pruned() == size_t(-1))
117  {
118  queryNode.Stat().Pruned() = queryNode.Parent()->Stat().Pruned();
119  queryNode.Stat().LowerBound() = queryNode.Parent()->Stat().LowerBound();
120  queryNode.Stat().Owner() = queryNode.Parent()->Stat().Owner();
121  }
122 
123  if (queryNode.Stat().Pruned() == centroids.n_cols)
124  return DBL_MAX;
125 
126  // This looks a lot like the hackery used in NeighborSearchRules to avoid
127  // distance computations. We'll use the traversal info to see if a
128  // parent-child or parent-parent prune is possible.
129  const double queryParentDist = queryNode.ParentDistance();
130  const double queryDescDist = queryNode.FurthestDescendantDistance();
131  const double refParentDist = referenceNode.ParentDistance();
132  const double refDescDist = referenceNode.FurthestDescendantDistance();
133  const double lastScore = traversalInfo.LastScore();
134  double adjustedScore;
135  double score = 0.0;
136 
137  // We want to set adjustedScore to be the distance between the centroid of the
138  // last query node and last reference node. We will do this by adjusting the
139  // last score. In some cases, we can just use the last base case.
141  {
142  adjustedScore = traversalInfo.LastBaseCase();
143  }
144  else if (lastScore == 0.0) // Nothing we can do here.
145  {
146  adjustedScore = 0.0;
147  }
148  else
149  {
150  // The last score is equal to the distance between the centroids minus the
151  // radii of the query and reference bounds along the axis of the line
152  // between the two centroids. In the best case, these radii are the
153  // furthest descendant distances, but that is not always true. It would
154  // take too long to calculate the exact radii, so we are forced to use
155  // MinimumBoundDistance() as a lower-bound approximation.
156  const double lastQueryDescDist =
157  traversalInfo.LastQueryNode()->MinimumBoundDistance();
158  const double lastRefDescDist =
159  traversalInfo.LastReferenceNode()->MinimumBoundDistance();
160  adjustedScore = lastScore + lastQueryDescDist + lastRefDescDist;
161  }
162 
163  // Assemble an adjusted score. For nearest neighbor search, this adjusted
164  // score is a lower bound on MinDistance(queryNode, referenceNode) that is
165  // assembled without actually calculating MinDistance(). For furthest
166  // neighbor search, it is an upper bound on
167  // MaxDistance(queryNode, referenceNode). If the traversalInfo isn't usable
168  // then the node should not be pruned by this.
169  if (traversalInfo.LastQueryNode() == queryNode.Parent())
170  {
171  const double queryAdjust = queryParentDist + queryDescDist;
172  adjustedScore -= queryAdjust;
173  }
174  else if (traversalInfo.LastQueryNode() == &queryNode)
175  {
176  adjustedScore -= queryDescDist;
177  }
178  else
179  {
180  // The last query node wasn't this query node or its parent. So we force
181  // the adjustedScore to be such that this combination can't be pruned here,
182  // because we don't really know anything about it.
183 
184  // It would be possible to modify this section to try and make a prune based
185  // on the query descendant distance and the distance between the query node
186  // and last traversal query node, but this case doesn't actually happen for
187  // kd-trees or cover trees.
188  adjustedScore = 0.0;
189  }
190  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
191  {
192  const double refAdjust = refParentDist + refDescDist;
193  adjustedScore -= refAdjust;
194  }
195  else if (traversalInfo.LastReferenceNode() == &referenceNode)
196  {
197  adjustedScore -= refDescDist;
198  }
199  else
200  {
201  // The last reference node wasn't this reference node or its parent. So we
202  // force the adjustedScore to be such that this combination can't be pruned
203  // here, because we don't really know anything about it.
204 
205  // It would be possible to modify this section to try and make a prune based
206  // on the reference descendant distance and the distance between the
207  // reference node and last traversal reference node, but this case doesn't
208  // actually happen for kd-trees or cover trees.
209  adjustedScore = 0.0;
210  }
211 
212  // Now, check if we can prune.
213  if (adjustedScore > queryNode.Stat().UpperBound())
214  {
216  {
217  // There isn't any need to set the traversal information because no
218  // descendant combinations will be visited, and those are the only
219  // combinations that would depend on the traversal information.
220  if (adjustedScore < queryNode.Stat().LowerBound())
221  {
222  // If this might affect the lower bound, make it more exact.
223  queryNode.Stat().LowerBound() = std::min(queryNode.Stat().LowerBound(),
224  queryNode.MinDistance(referenceNode));
225  ++scores;
226  }
227 
228  queryNode.Stat().Pruned() += referenceNode.NumDescendants();
229  score = DBL_MAX;
230  }
231  }
232 
233  if (score != DBL_MAX)
234  {
235  // Get minimum and maximum distances.
236  const math::Range distances = queryNode.RangeDistance(referenceNode);
237 
238  score = distances.Lo();
239  ++scores;
240  if (distances.Lo() > queryNode.Stat().UpperBound())
241  {
242  // The reference node can own no points in this query node. We may
243  // improve the lower bound on pruned nodes, though.
244  if (distances.Lo() < queryNode.Stat().LowerBound())
245  queryNode.Stat().LowerBound() = distances.Lo();
246 
247  // This assumes that reference clusters don't appear elsewhere in the
248  // tree.
249  queryNode.Stat().Pruned() += referenceNode.NumDescendants();
250  score = DBL_MAX;
251  }
252  else if (distances.Hi() < queryNode.Stat().UpperBound())
253  {
254  // Tighten upper bound.
255  const double tighterBound =
256  queryNode.MaxDistance(centroids.col(referenceNode.Descendant(0)));
257  ++scores; // Count extra distance calculation.
258 
259  if (tighterBound <= queryNode.Stat().UpperBound())
260  {
261  // We can improve the best estimate.
262  queryNode.Stat().UpperBound() = tighterBound;
263 
264  // Remember that our upper bound does correspond to a cluster centroid,
265  // so it does correspond to a cluster. We'll mark the cluster as the
266  // owner, but note that the node is not truly owned unless
267  // Stat().Pruned() is centroids.n_cols.
268  queryNode.Stat().Owner() =
270  oldFromNewCentroids[referenceNode.Descendant(0)] :
271  referenceNode.Descendant(0);
272  }
273  }
274  }
275 
276  // Is everything pruned?
277 
278  if (queryNode.Stat().Pruned() == centroids.n_cols - 1)
279  {
280  queryNode.Stat().Pruned() = centroids.n_cols; // Owner() is already set.
281  return DBL_MAX;
282  }
283 
284 
285  // Set traversal information.
286  traversalInfo.LastQueryNode() = &queryNode;
287  traversalInfo.LastReferenceNode() = &referenceNode;
288  traversalInfo.LastScore() = score;
289 
290  return score;
291 }
292 
// Single-tree rescore for one query point against a reference node.
//
// Returns oldScore unchanged; neither the query index nor the reference node
// is inspected.
293 template<typename MetricType, typename TreeType>
294 inline double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
295  const size_t /* queryIndex */,
296  TreeType& /* referenceNode */,
297  const double oldScore)
298 {
299  // No rescoring (for now).
300  return oldScore;
301 }
302 
// Dual-tree rescore for a (queryNode, referenceNode) combination.
//
// oldScore is the value previously returned by Score() for this combination.
// Returns DBL_MAX if the combination was already pruned, if oldScore now
// exceeds the query node's upper bound, or if all but one centroid have been
// pruned for the query node; otherwise returns oldScore unchanged.
//
// Side effects on queryNode.Stat(): when pruning on the upper bound,
// LowerBound() may be lowered to oldScore and Pruned() grows by
// referenceNode.NumDescendants(); when only one candidate remains, Pruned()
// is set to centroids.n_cols.
303 template<typename MetricType, typename TreeType>
304 inline double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
305  TreeType& queryNode,
306  TreeType& referenceNode,
307  const double oldScore)
308 {
309  if (oldScore == DBL_MAX)
310  return DBL_MAX; // It's already pruned.
311 
312  // oldScore contains the minimum distance between queryNode and referenceNode.
313  // In the time since Score() has been called, the upper bound *may* have
314  // tightened. If it has tightened enough, we may prune this node now.
315  if (oldScore > queryNode.Stat().UpperBound())
316  {
317  // We may still be able to improve the lower bound on pruned nodes.
318  if (oldScore < queryNode.Stat().LowerBound())
319  queryNode.Stat().LowerBound() = oldScore;
320 
321  // This assumes that reference clusters don't appear elsewhere in the tree.
322  queryNode.Stat().Pruned() += referenceNode.NumDescendants();
323  return DBL_MAX;
324  }
325 
326  // Also, check if everything has been pruned.
327  if (queryNode.Stat().Pruned() == centroids.n_cols - 1)
328  {
329  queryNode.Stat().Pruned() = centroids.n_cols; // Owner() is already set.
330  return DBL_MAX;
331  }
332 
333  return oldScore;
334 }
335 
336 } // namespace kmeans
337 } // namespace mlpack
338 
339 #endif
T Lo() const
Get the lower bound.
Definition: range.hpp:61
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static const bool RearrangesDataset
This is true if the tree rearranges points in the dataset when it is built.
Definition: tree_traits.hpp:105
RangeType< double > Range
3.0.0 TODO: break reverse-compatibility by changing RangeType to Range.
Definition: range.hpp:19
static const bool FirstPointIsCentroid
This is true if the first point of each node is the centroid of its bound.
Definition: tree_traits.hpp:94
double LastBaseCase() const
Get the base case associated with the last node combination.
Definition: traversal_info.hpp:78
TreeType * LastQueryNode() const
Get the last query node.
Definition: traversal_info.hpp:63
double LastScore() const
Get the score associated with the last query and reference nodes.
Definition: traversal_info.hpp:73
TreeType * LastReferenceNode() const
Get the last reference node.
Definition: traversal_info.hpp:68