// NOTE(review): this listing was extracted with the original source line
// numbers fused into the text, and several source lines (braces, some
// signature lines and statements) are missing. Comments below describe only
// what the visible lines show.
12 #ifndef MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP 13 #define MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP 18 template<
typename MetricType,
typename TreeType>
// DTBRules constructor.
// Receives the dataset, the union-find structure tracking point components,
// and three per-component arrays indexed by component id:
//   neighborsDistances    - best distance found so far for that component,
//   neighborsInComponent  - the point inside the component for that pair,
//   neighborsOutComponent - the point outside the component for that pair.
// NOTE(review): source lines 25-27 and 32 onward (any remaining parameters,
// the dataSet initializer, and the body) are not in this excerpt.
19 DTBRules<MetricType, TreeType>::
20 DTBRules(
const arma::mat& dataSet,
21 UnionFind& connections,
22 arma::vec& neighborsDistances,
23 arma::Col<size_t>& neighborsInComponent,
24 arma::Col<size_t>& neighborsOutComponent,
// Bind the members to the caller's objects. The parameters are non-const
// references; presumably the members are references too, so the rules write
// results directly into the caller's arrays. Confirm against dtb_rules.hpp.
28 connections(connections),
29 neighborsDistances(neighborsDistances),
30 neighborsInComponent(neighborsInComponent),
31 neighborsOutComponent(neighborsOutComponent),
template<
typename MetricType,
typename TreeType>
// BaseCase: evaluate one (query point, reference point) pair.
// If the two points lie in different union-find components and are closer
// than the best pair recorded so far for the query's component, the pair
// becomes that component's new candidate (distance plus both endpoints).
// NOTE(review): the closing braces and the return statement are not in this
// excerpt; the numbering jumps from 71 to 78.
41 double DTBRules<MetricType, TreeType>::BaseCase(
const size_t queryIndex,
42 const size_t referenceIndex)
// Sentinel value; it is raised further below to the query component's
// current best distance.
47 double newUpperBound = -1.0;
// Look up the union-find component of each point.
50 size_t queryComponentIndex = connections.
Find(queryIndex);
52 size_t referenceComponentIndex = connections.
Find(referenceIndex);
// Pairs whose points already share a component are not evaluated.
54 if (queryComponentIndex != referenceComponentIndex)
57 double distance = metric.Evaluate(dataSet.col(queryIndex),
58 dataSet.col(referenceIndex));
// A strictly closer pair replaces the query component's recorded candidate.
// The query point is the in-component endpoint; the reference point is the
// out-of-component endpoint.
60 if (distance < neighborsDistances[queryComponentIndex])
64 neighborsDistances[queryComponentIndex] = distance;
65 neighborsInComponent[queryComponentIndex] = queryIndex;
66 neighborsOutComponent[queryComponentIndex] = referenceIndex;
// Raise the -1.0 sentinel to the query component's current best distance
// (this happens for any stored distance >= -1.0).
// NOTE(review): braces are missing here, so the excerpt does not show whether
// this test sits inside the component check above.
70 if (newUpperBound < neighborsDistances[queryComponentIndex])
71 newUpperBound = neighborsDistances[queryComponentIndex];
template<
typename MetricType,
typename TreeType>
// Score(queryIndex, referenceNode): score a query point against a reference
// node to decide recursion order.
// NOTE(review): the leading signature line (source line 79) is missing from
// this excerpt; only the trailing referenceNode parameter is visible.
80 TreeType& referenceNode)
// Component of the query point.
82 size_t queryComponentIndex = connections.
Find(queryIndex);
// Compare against the node's cached component membership, cast to size_t.
// NOTE(review): the node-node Score overload tests membership >= 0, which
// suggests the value can be negative. A negative value wraps to a very large
// size_t here and so presumably never matches; confirm.
// What happens on a match (source lines 89-90) is not shown in this excerpt.
87 if (queryComponentIndex ==
88 (
size_t) referenceNode.Stat().ComponentMembership())
// unsafe_col yields a column that uses the dataset's own memory.
91 const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
// Presumably a lower bound on the distance from the query point to any point
// under referenceNode.
92 const double distance = referenceNode.MinDistance(queryPoint);
// Compare with the query component's current best distance. The rest of this
// return expression is not shown in this excerpt.
96 return neighborsDistances[queryComponentIndex] < distance
template<
typename MetricType,
typename TreeType>
// Rescore(queryIndex, referenceNode, oldScore): re-check a previously
// computed score against the best distance currently recorded for the query
// point's component. BaseCase can lower that distance after the score was
// computed.
// Returns DBL_MAX when oldScore exceeds that distance; otherwise returns
// oldScore unchanged. DBL_MAX presumably tells the traverser to prune.
// NOTE(review): the leading signature lines are missing from this excerpt.
103 const double oldScore)
107 return (oldScore > neighborsDistances[connections.
Find(queryIndex)])
108 ? DBL_MAX : oldScore;
template<
typename MetricType,
typename TreeType>
// Score(queryNode, referenceNode): node-node score for recursion order.
// NOTE(review): the leading signature line is missing from this excerpt.
113 TreeType& referenceNode)
// Both nodes report the same non-negative component membership. The branch
// taken in that case (source lines 120-122) is not shown in this excerpt.
117 if ((queryNode.Stat().ComponentMembership() >= 0) &&
118 (queryNode.Stat().ComponentMembership() ==
119 referenceNode.Stat().ComponentMembership()))
// Presumably a lower bound on the distance between points of the two nodes.
123 const double distance = queryNode.MinDistance(referenceNode);
// CalculateBound also refreshes the cached Max/MinNeighborDistance and Bound
// statistics on queryNode as a side effect.
124 const double bound = CalculateBound(queryNode);
// DBL_MAX when the nodes are farther apart than the bound; otherwise the
// minimum distance itself is the score.
128 return (bound < distance) ? DBL_MAX : distance;
template<
typename MetricType,
typename TreeType>
// Rescore for a query node: re-check oldScore against a freshly computed
// bound for queryNode.
// Returns DBL_MAX when oldScore exceeds the bound; otherwise returns oldScore.
// NOTE(review): the signature lines (132-133) are missing; only the trailing
// oldScore parameter is visible.
134 const double oldScore)
// Recomputes the bound, which also rewrites queryNode's cached statistics.
const 136 const double bound = CalculateBound(queryNode);
137 return (oldScore > bound) ? DBL_MAX : oldScore;
template<
typename MetricType,
typename TreeType>
// CalculateBound(queryNode): compute the pruning bound for queryNode and
// cache it on the node.
// Side effect: overwrites queryNode.Stat().MaxNeighborDistance(),
// MinNeighborDistance() and Bound().
// Returns the new queryNode.Stat().Bound().
// NOTE(review): the leading signature line (source line 143) is missing from
// this excerpt.
144 TreeType& queryNode)
// Running max/min accumulators start at the opposite extremes, so any real
// value replaces them.
const 146 double worstPointBound = -DBL_MAX;
147 double bestPointBound = DBL_MAX;
149 double worstChildBound = -DBL_MAX;
150 double bestChildBound = DBL_MAX;
// For each point the node reports via NumPoints()/Point(i), use the current
// best distance of that point's component. Track the largest and smallest.
153 for (
size_t i = 0; i < queryNode.NumPoints(); ++i)
155 const size_t pointComponent = connections.
Find(queryNode.Point(i));
156 const double bound = neighborsDistances[pointComponent];
158 if (bound > worstPointBound)
159 worstPointBound = bound;
160 if (bound < bestPointBound)
161 bestPointBound = bound;
// Children contribute through the max/min neighbor distances cached on their
// statistics. This presumably assumes those were refreshed earlier (e.g. by
// CalculateBound on the child); confirm.
165 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
167 const double maxBound = queryNode.Child(i).Stat().MaxNeighborDistance();
168 if (maxBound > worstChildBound)
169 worstChildBound = maxBound;
171 const double minBound = queryNode.Child(i).Stat().MinNeighborDistance();
172 if (minBound < bestChildBound)
173 bestChildBound = minBound;
// Combine the point and child contributions.
177 const double worstBound = std::max(worstPointBound, worstChildBound);
178 const double bestBound = std::min(bestPointBound, bestChildBound);
// Widen the best bound by twice the furthest-descendant distance. This is
// presumably triangle-inequality slack, since two descendants can be up to
// that far apart.
// The DBL_MAX check keeps the "no value" sentinel from being offset, which
// could overflow to infinity.
180 const double bestAdjustedBound = (bestBound == DBL_MAX) ? DBL_MAX :
181 bestBound + 2 * queryNode.FurthestDescendantDistance();
// Cache the results on the node. The stored bound is the smaller (tighter) of
// worstBound and bestAdjustedBound.
// NOTE(review): a node with no points and no children leaves worstBound at
// -DBL_MAX, so Bound() becomes -DBL_MAX.
184 queryNode.Stat().MaxNeighborDistance() = worstBound;
185 queryNode.Stat().MinNeighborDistance() = bestBound;
186 queryNode.Stat().Bound() = std::min(worstBound, bestAdjustedBound);
188 return queryNode.Stat().Bound();
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: dtb_rules.hpp:23
size_t Find(const size_t x)
Returns the component containing an element.
Definition: union_find.hpp:56
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: dtb_rules_impl.hpp:79
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
Definition: dtb_rules_impl.hpp:101
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38