mlpack
dtb_rules_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
13 #define MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
14 
15 namespace mlpack {
16 namespace emst {
17 
18 template<typename MetricType, typename TreeType>
19 DTBRules<MetricType, TreeType>::
20 DTBRules(const arma::mat& dataSet,
21  UnionFind& connections,
22  arma::vec& neighborsDistances,
23  arma::Col<size_t>& neighborsInComponent,
24  arma::Col<size_t>& neighborsOutComponent,
25  MetricType& metric)
26 :
27  dataSet(dataSet),
28  connections(connections),
29  neighborsDistances(neighborsDistances),
30  neighborsInComponent(neighborsInComponent),
31  neighborsOutComponent(neighborsOutComponent),
32  metric(metric),
33  baseCases(0),
34  scores(0)
35 {
36  // Nothing else to do.
37 }
38 
39 template<typename MetricType, typename TreeType>
40 inline force_inline
41 double DTBRules<MetricType, TreeType>::BaseCase(const size_t queryIndex,
42  const size_t referenceIndex)
43 {
44  // Check if the points are in the same component at this iteration.
45  // If not, return the distance between them. Also, store a better result as
46  // the current neighbor, if necessary.
47  double newUpperBound = -1.0;
48 
49  // Find the index of the component the query is in.
50  size_t queryComponentIndex = connections.Find(queryIndex);
51 
52  size_t referenceComponentIndex = connections.Find(referenceIndex);
53 
54  if (queryComponentIndex != referenceComponentIndex)
55  {
56  ++baseCases;
57  double distance = metric.Evaluate(dataSet.col(queryIndex),
58  dataSet.col(referenceIndex));
59 
60  if (distance < neighborsDistances[queryComponentIndex])
61  {
62  Log::Assert(queryIndex != referenceIndex);
63 
64  neighborsDistances[queryComponentIndex] = distance;
65  neighborsInComponent[queryComponentIndex] = queryIndex;
66  neighborsOutComponent[queryComponentIndex] = referenceIndex;
67  }
68  }
69 
70  if (newUpperBound < neighborsDistances[queryComponentIndex])
71  newUpperBound = neighborsDistances[queryComponentIndex];
72 
73  Log::Assert(newUpperBound >= 0.0);
74 
75  return newUpperBound;
76 }
77 
78 template<typename MetricType, typename TreeType>
79 double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
80  TreeType& referenceNode)
81 {
82  size_t queryComponentIndex = connections.Find(queryIndex);
83 
84  // If the query belongs to the same component as all of the references,
85  // then prune. The cast is to stop a warning about comparing unsigned to
86  // signed values.
87  if (queryComponentIndex ==
88  (size_t) referenceNode.Stat().ComponentMembership())
89  return DBL_MAX;
90 
91  const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
92  const double distance = referenceNode.MinDistance(queryPoint);
93 
94  // If all the points in the reference node are farther than the candidate
95  // nearest neighbor for the query's component, we prune.
96  return neighborsDistances[queryComponentIndex] < distance
97  ? DBL_MAX : distance;
98 }
99 
100 template<typename MetricType, typename TreeType>
101 double DTBRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
102  TreeType& /* referenceNode */,
103  const double oldScore)
104 {
105  // We don't need to check component membership again, because it can't
106  // change inside a single iteration.
107  return (oldScore > neighborsDistances[connections.Find(queryIndex)])
108  ? DBL_MAX : oldScore;
109 }
110 
111 template<typename MetricType, typename TreeType>
112 double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
113  TreeType& referenceNode)
114 {
115  // If all the queries belong to the same component as all the references
116  // then we prune.
117  if ((queryNode.Stat().ComponentMembership() >= 0) &&
118  (queryNode.Stat().ComponentMembership() ==
119  referenceNode.Stat().ComponentMembership()))
120  return DBL_MAX;
121 
122  ++scores;
123  const double distance = queryNode.MinDistance(referenceNode);
124  const double bound = CalculateBound(queryNode);
125 
126  // If all the points in the reference node are farther than the candidate
127  // nearest neighbor for all queries in the node, we prune.
128  return (bound < distance) ? DBL_MAX : distance;
129 }
130 
131 template<typename MetricType, typename TreeType>
132 double DTBRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
133  TreeType& /* referenceNode */,
134  const double oldScore) const
135 {
136  const double bound = CalculateBound(queryNode);
137  return (oldScore > bound) ? DBL_MAX : oldScore;
138 }
139 
140 // Calculate the bound for a given query node in its current state and update
141 // it.
142 template<typename MetricType, typename TreeType>
144  TreeType& queryNode) const
145 {
146  double worstPointBound = -DBL_MAX;
147  double bestPointBound = DBL_MAX;
148 
149  double worstChildBound = -DBL_MAX;
150  double bestChildBound = DBL_MAX;
151 
152  // Now, find the best and worst point bounds.
153  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
154  {
155  const size_t pointComponent = connections.Find(queryNode.Point(i));
156  const double bound = neighborsDistances[pointComponent];
157 
158  if (bound > worstPointBound)
159  worstPointBound = bound;
160  if (bound < bestPointBound)
161  bestPointBound = bound;
162  }
163 
164  // Find the best and worst child bounds.
165  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
166  {
167  const double maxBound = queryNode.Child(i).Stat().MaxNeighborDistance();
168  if (maxBound > worstChildBound)
169  worstChildBound = maxBound;
170 
171  const double minBound = queryNode.Child(i).Stat().MinNeighborDistance();
172  if (minBound < bestChildBound)
173  bestChildBound = minBound;
174  }
175 
176  // Now calculate the actual bounds.
177  const double worstBound = std::max(worstPointBound, worstChildBound);
178  const double bestBound = std::min(bestPointBound, bestChildBound);
179  // We must check that bestBound != DBL_MAX; otherwise, we risk overflow.
180  const double bestAdjustedBound = (bestBound == DBL_MAX) ? DBL_MAX :
181  bestBound + 2 * queryNode.FurthestDescendantDistance();
182 
183  // Update the relevant quantities in the node.
184  queryNode.Stat().MaxNeighborDistance() = worstBound;
185  queryNode.Stat().MinNeighborDistance() = bestBound;
186  queryNode.Stat().Bound() = std::min(worstBound, bestAdjustedBound);
187 
188  return queryNode.Stat().Bound();
189 }
190 
191 } // namespace emst
192 } // namespace mlpack
193 
194 
195 
196 #endif
197 
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: dtb_rules.hpp:23
size_t Find(const size_t x)
Returns the component containing an element.
Definition: union_find.hpp:56
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: dtb_rules_impl.hpp:79
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
Definition: dtb_rules_impl.hpp:101
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38