12 #ifndef MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_IMPL_HPP 13 #define MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_IMPL_HPP 20 template<
typename MetricType,
typename TreeType>
21 DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
22 const arma::mat& centroids,
23 const arma::mat& dataset,
24 arma::Row<size_t>& assignments,
25 arma::vec& upperBounds,
26 arma::vec& lowerBounds,
28 const std::vector<bool>& prunedPoints,
29 const std::vector<size_t>& oldFromNewCentroids,
30 std::vector<bool>& visited) :
33 assignments(assignments),
34 upperBounds(upperBounds),
35 lowerBounds(lowerBounds),
37 prunedPoints(prunedPoints),
38 oldFromNewCentroids(oldFromNewCentroids),
42 lastQueryIndex(dataset.n_cols),
43 lastReferenceIndex(centroids.n_cols),
53 template<
typename MetricType,
typename TreeType>
54 inline force_inline
double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
55 const size_t queryIndex,
56 const size_t referenceIndex)
58 if (prunedPoints[queryIndex])
62 if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
66 visited[queryIndex] =
true;
70 const double distance = metric.Evaluate(dataset.col(queryIndex),
71 centroids.col(referenceIndex));
73 if (distance < upperBounds[queryIndex])
75 lowerBounds[queryIndex] = upperBounds[queryIndex];
76 upperBounds[queryIndex] = distance;
78 oldFromNewCentroids[referenceIndex] : referenceIndex;
80 else if (distance < lowerBounds[queryIndex])
82 lowerBounds[queryIndex] = distance;
86 lastQueryIndex = queryIndex;
87 lastReferenceIndex = referenceIndex;
88 lastBaseCase = distance;
93 template<
typename MetricType,
typename TreeType>
94 inline double DualTreeKMeansRules<MetricType, TreeType>::Score(
95 const size_t queryIndex,
99 if (prunedPoints[queryIndex])
107 template<
typename MetricType,
typename TreeType>
108 inline double DualTreeKMeansRules<MetricType, TreeType>::Score(
110 TreeType& referenceNode)
112 if (queryNode.Stat().StaticPruned() ==
true)
116 if (queryNode.Stat().Pruned() == size_t(-1))
118 queryNode.Stat().Pruned() = queryNode.Parent()->Stat().Pruned();
119 queryNode.Stat().LowerBound() = queryNode.Parent()->Stat().LowerBound();
120 queryNode.Stat().Owner() = queryNode.Parent()->Stat().Owner();
123 if (queryNode.Stat().Pruned() == centroids.n_cols)
129 const double queryParentDist = queryNode.ParentDistance();
130 const double queryDescDist = queryNode.FurthestDescendantDistance();
131 const double refParentDist = referenceNode.ParentDistance();
132 const double refDescDist = referenceNode.FurthestDescendantDistance();
133 const double lastScore = traversalInfo.
LastScore();
134 double adjustedScore;
144 else if (lastScore == 0.0)
156 const double lastQueryDescDist =
158 const double lastRefDescDist =
160 adjustedScore = lastScore + lastQueryDescDist + lastRefDescDist;
171 const double queryAdjust = queryParentDist + queryDescDist;
172 adjustedScore -= queryAdjust;
176 adjustedScore -= queryDescDist;
192 const double refAdjust = refParentDist + refDescDist;
193 adjustedScore -= refAdjust;
197 adjustedScore -= refDescDist;
213 if (adjustedScore > queryNode.Stat().UpperBound())
220 if (adjustedScore < queryNode.Stat().LowerBound())
223 queryNode.Stat().LowerBound() = std::min(queryNode.Stat().LowerBound(),
224 queryNode.MinDistance(referenceNode));
228 queryNode.Stat().Pruned() += referenceNode.NumDescendants();
233 if (score != DBL_MAX)
236 const math::Range distances = queryNode.RangeDistance(referenceNode);
238 score = distances.
Lo();
240 if (distances.Lo() > queryNode.Stat().UpperBound())
244 if (distances.Lo() < queryNode.Stat().LowerBound())
245 queryNode.Stat().LowerBound() = distances.Lo();
249 queryNode.Stat().Pruned() += referenceNode.NumDescendants();
252 else if (distances.Hi() < queryNode.Stat().UpperBound())
255 const double tighterBound =
256 queryNode.MaxDistance(centroids.col(referenceNode.Descendant(0)));
259 if (tighterBound <= queryNode.Stat().UpperBound())
262 queryNode.Stat().UpperBound() = tighterBound;
268 queryNode.Stat().Owner() =
270 oldFromNewCentroids[referenceNode.Descendant(0)] :
271 referenceNode.Descendant(0);
278 if (queryNode.Stat().Pruned() == centroids.n_cols - 1)
280 queryNode.Stat().Pruned() = centroids.n_cols;
293 template<
typename MetricType,
typename TreeType>
294 inline double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
297 const double oldScore)
303 template<
typename MetricType,
typename TreeType>
304 inline double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
306 TreeType& referenceNode,
307 const double oldScore)
309 if (oldScore == DBL_MAX)
315 if (oldScore > queryNode.Stat().UpperBound())
318 if (oldScore < queryNode.Stat().LowerBound())
319 queryNode.Stat().LowerBound() = oldScore;
322 queryNode.Stat().Pruned() += referenceNode.NumDescendants();
327 if (queryNode.Stat().Pruned() == centroids.n_cols - 1)
329 queryNode.Stat().Pruned() = centroids.n_cols;
T Lo() const
Get the lower bound.
Definition: range.hpp:61
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static const bool RearrangesDataset
This is true if the tree rearranges points in the dataset when it is built.
Definition: tree_traits.hpp:105
RangeType< double > Range
3.0.0 TODO: break reverse-compatibility by changing RangeType to Range.
Definition: range.hpp:19
static const bool FirstPointIsCentroid
This is true if the first point of each node is the centroid of its bound.
Definition: tree_traits.hpp:94
double LastBaseCase() const
Get the base case associated with the last node combination.
Definition: traversal_info.hpp:78
TreeType * LastQueryNode() const
Get the last query node.
Definition: traversal_info.hpp:63
double LastScore() const
Get the score associated with the last query and reference nodes.
Definition: traversal_info.hpp:73
TreeType * LastReferenceNode() const
Get the last reference node.
Definition: traversal_info.hpp:68