mlpack
fastmks_model_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_FASTMKS_FASTMKS_MODEL_IMPL_HPP
13 #define MLPACK_METHODS_FASTMKS_FASTMKS_MODEL_IMPL_HPP
14 
15 #include "fastmks_model.hpp"
16 
17 namespace mlpack {
18 namespace fastmks {
19 
21 template<typename KernelType>
23  KernelType& k,
24  arma::mat&& referenceData,
25  const double base)
26 {
27  // Do we need to build the tree?
28  if (base <= 1.0)
29  {
30  throw std::invalid_argument("base must be greater than 1");
31  }
32 
33  if (f.Naive())
34  {
35  f.Train(std::move(referenceData), k);
36  }
37  else
38  {
39  // Create the tree with the specified base.
40  Timer::Start("tree_building");
42  typename FastMKS<KernelType>::Tree* tree =
43  new typename FastMKS<KernelType>::Tree(std::move(referenceData),
44  metric, base);
45  Timer::Stop("tree_building");
46 
47  f.Train(tree);
48  }
49 }
50 
52 template<typename KernelType,
53  typename FastMKSType>
54 void BuildFastMKSModel(FastMKSType& /* f */,
55  KernelType& /* k */,
56  arma::mat&& /* referenceData */,
57  const double /* base */)
58 {
59  throw std::invalid_argument("FastMKSModel::BuildModel(): given kernel type is"
60  " not equal to kernel type of the model!");
61 }
62 
63 template<typename TKernelType>
64 void FastMKSModel::BuildModel(arma::mat&& referenceData,
65  TKernelType& kernel,
66  const bool singleMode,
67  const bool naive,
68  const double base)
69 {
70  // Clean memory if necessary.
71  if (linear)
72  delete linear;
73  if (polynomial)
74  delete polynomial;
75  if (cosine)
76  delete cosine;
77  if (gaussian)
78  delete gaussian;
79  if (epan)
80  delete epan;
81  if (triangular)
82  delete triangular;
83  if (hyptan)
84  delete hyptan;
85 
86  linear = NULL;
87  polynomial = NULL;
88  cosine = NULL;
89  gaussian = NULL;
90  epan = NULL;
91  triangular = NULL;
92  hyptan = NULL;
93 
94  // Instantiate the right model.
95  switch (kernelType)
96  {
97  case LINEAR_KERNEL:
98  linear = new FastMKS<kernel::LinearKernel>(singleMode, naive);
99  BuildFastMKSModel(*linear, kernel, std::move(referenceData), base);
100  break;
101 
102  case POLYNOMIAL_KERNEL:
103  polynomial = new FastMKS<kernel::PolynomialKernel>(singleMode, naive);
104  BuildFastMKSModel(*polynomial, kernel, std::move(referenceData), base);
105  break;
106 
107  case COSINE_DISTANCE:
108  cosine = new FastMKS<kernel::CosineDistance>(singleMode, naive);
109  BuildFastMKSModel(*cosine, kernel, std::move(referenceData), base);
110  break;
111 
112  case GAUSSIAN_KERNEL:
113  gaussian = new FastMKS<kernel::GaussianKernel>(singleMode, naive);
114  BuildFastMKSModel(*gaussian, kernel, std::move(referenceData), base);
115  break;
116 
117  case EPANECHNIKOV_KERNEL:
118  epan = new FastMKS<kernel::EpanechnikovKernel>(singleMode, naive);
119  BuildFastMKSModel(*epan, kernel, std::move(referenceData), base);
120  break;
121 
122  case TRIANGULAR_KERNEL:
123  triangular = new FastMKS<kernel::TriangularKernel>(singleMode, naive);
124  BuildFastMKSModel(*triangular, kernel, std::move(referenceData), base);
125  break;
126 
127  case HYPTAN_KERNEL:
128  hyptan = new FastMKS<kernel::HyperbolicTangentKernel>(singleMode, naive);
129  BuildFastMKSModel(*hyptan, kernel, std::move(referenceData), base);
130  break;
131  }
132 }
133 
134 template<typename Archive>
135 void FastMKSModel::serialize(Archive& ar, const uint32_t /* version */)
136 {
137  ar(CEREAL_NVP(kernelType));
138 
139  if (cereal::is_loading<Archive>())
140  {
141  // Clean memory.
142  if (linear)
143  delete linear;
144  if (polynomial)
145  delete polynomial;
146  if (cosine)
147  delete cosine;
148  if (gaussian)
149  delete gaussian;
150  if (epan)
151  delete epan;
152  if (triangular)
153  delete triangular;
154  if (hyptan)
155  delete hyptan;
156 
157  linear = NULL;
158  polynomial = NULL;
159  cosine = NULL;
160  gaussian = NULL;
161  epan = NULL;
162  triangular = NULL;
163  hyptan = NULL;
164  }
165 
166  // Serialize the correct model.
167  switch (kernelType)
168  {
169  case LINEAR_KERNEL:
170  ar(CEREAL_POINTER(linear));
171  break;
172 
173  case POLYNOMIAL_KERNEL:
174  ar(CEREAL_POINTER(polynomial));
175  break;
176 
177  case COSINE_DISTANCE:
178  ar(CEREAL_POINTER(cosine));
179  break;
180 
181  case GAUSSIAN_KERNEL:
182  ar(CEREAL_POINTER(gaussian));
183  break;
184 
185  case EPANECHNIKOV_KERNEL:
186  ar(CEREAL_POINTER(epan));
187  break;
188 
189  case TRIANGULAR_KERNEL:
190  ar(CEREAL_POINTER(triangular));
191  break;
192 
193  case HYPTAN_KERNEL:
194  ar(CEREAL_POINTER(hyptan));
195  break;
196  }
197 }
198 
199 template<typename FastMKSType>
200 void FastMKSModel::Search(FastMKSType& f,
201  const arma::mat& querySet,
202  const size_t k,
203  arma::Mat<size_t>& indices,
204  arma::mat& kernels,
205  const double base)
206 {
207  if (f.Naive() || f.SingleMode())
208  {
209  f.Search(querySet, k, indices, kernels);
210  }
211  else
212  {
213  Timer::Start("tree_building");
214  typename FastMKSType::Tree queryTree(querySet, base);
215  Timer::Stop("tree_building");
216 
217  f.Search(&queryTree, k, indices, kernels);
218  }
219 }
220 
221 } // namespace fastmks
222 } // namespace mlpack
223 
224 #endif
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
void Train(const MatType &referenceSet)
"Train" the FastMKS model on the given reference set (this will just build a tree, if the current search mode is not naive mode).
Definition: fastmks_impl.hpp:301
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Search(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &indices, arma::mat &kernels, const double base)
Search with a different query set.
Definition: fastmks_model.cpp:250
void BuildModel(arma::mat &&referenceData, TKernelType &kernel, const bool singleMode, const bool naive, const double base)
Build the model on the given reference set.
Definition: fastmks_model_impl.hpp:64
The inner product metric, IPMetric, takes a given Mercer kernel (KernelType), and when Evaluate() is ...
Definition: ip_metric.hpp:32
void BuildFastMKSModel(FastMKS< KernelType > &f, KernelType &k, arma::mat &&referenceData, const double base)
This is called when the KernelType is the same as the model's kernel type.
Definition: fastmks_model_impl.hpp:22
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
TreeType< metric::IPMetric< KernelType >, FastMKSStat, MatType > Tree
Convenience typedef.
Definition: fastmks.hpp:67
bool Naive() const
Get whether or not brute-force (naive) search is used.
Definition: fastmks.hpp:301
An implementation of fast exact max-kernel search.
Definition: fastmks.hpp:63
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: fastmks_model_impl.hpp:135
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensi...
Definition: cover_tree.hpp:99