mlpack
cf_model_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_CF_CF_MODEL_IMPL_HPP
13 #define MLPACK_METHODS_CF_CF_MODEL_IMPL_HPP
14 
15 #include "cf_model.hpp"
16 
20 
24 
28 #include "decomposition_policies/randomized_svd_method.hpp"
33 
39 
40 namespace mlpack {
41 namespace cf {
42 
43 template<typename NeighborSearchPolicy, typename CFType>
44 void PredictHelper(CFType& cf,
45  const InterpolationTypes interpolationType,
46  const arma::Mat<size_t>& combinations,
47  arma::vec& predictions)
48 {
49  switch (interpolationType)
50  {
51  case AVERAGE_INTERPOLATION:
52  cf.template Predict<NeighborSearchPolicy,
53  AverageInterpolation>(combinations, predictions);
54  break;
55 
56  case REGRESSION_INTERPOLATION:
57  cf.template Predict<NeighborSearchPolicy,
58  RegressionInterpolation>(combinations, predictions);
59  break;
60 
61  case SIMILARITY_INTERPOLATION:
62  cf.template Predict<NeighborSearchPolicy,
63  SimilarityInterpolation>(combinations, predictions);
64  break;
65  }
66 }
67 
// Runtime dispatch for Predict(): maps `nsType` onto a compile-time
// NeighborSearchPolicy (CosineSearch, EuclideanSearch or PearsonSearch) and
// forwards to PredictHelper(), which then selects the interpolation policy.
// An `nsType` that matches none of the cases below leaves `predictions`
// unchanged.
// NOTE(review): the declarator line (original line 70) is missing from this
// rendering; the Doxygen cross-reference places the virtual
// Predict(nsType, interpolationType, combinations, predictions) definition at
// cf_model_impl.hpp:70. `cf` is not declared in the visible lines —
// presumably it is the wrapped CF object member; confirm against the header.
69 template<typename DecompositionPolicy, typename NormalizationPolicy>
71  const NeighborSearchTypes nsType,
72  const InterpolationTypes interpolationType,
73  const arma::Mat<size_t>& combinations,
74  arma::vec& predictions)
75 {
76  switch (nsType)
77  {
78  case COSINE_SEARCH:
79  PredictHelper<CosineSearch>(cf, interpolationType, combinations,
80  predictions);
81  break;
82 
83  case EUCLIDEAN_SEARCH:
84  PredictHelper<EuclideanSearch>(cf, interpolationType, combinations,
85  predictions);
86  break;
87 
88  case PEARSON_SEARCH:
89  PredictHelper<PearsonSearch>(cf, interpolationType, combinations,
90  predictions);
91  break;
92  }
93 }
94 
// Compile-time helper for the GetRecommendations() overload that takes an
// explicit list of `users`. With the NeighborSearchPolicy fixed by the
// template argument, it selects the interpolation policy from
// `interpolationType` and calls
// cf.GetRecommendations(numRecs, recommendations, users). An
// `interpolationType` matching none of the cases leaves `recommendations`
// unchanged.
// NOTE(review): original lines 107, 113 and 119 (the second template argument
// of each call) are missing from this rendering; by analogy with
// PredictHelper() they presumably name AverageInterpolation,
// RegressionInterpolation and SimilarityInterpolation — confirm.
95 template<typename NeighborSearchPolicy, typename CFType>
96 void GetRecommendationsHelper(
97  CFType& cf,
98  const InterpolationTypes interpolationType,
99  const size_t numRecs,
100  arma::Mat<size_t>& recommendations,
101  const arma::Col<size_t>& users)
102 {
103  switch (interpolationType)
104  {
105  case AVERAGE_INTERPOLATION:
106  cf.template GetRecommendations<NeighborSearchPolicy,
108  numRecs, recommendations, users);
109  break;
110 
111  case REGRESSION_INTERPOLATION:
112  cf.template GetRecommendations<NeighborSearchPolicy,
114  numRecs, recommendations, users);
115  break;
116 
117  case SIMILARITY_INTERPOLATION:
118  cf.template GetRecommendations<NeighborSearchPolicy,
120  numRecs, recommendations, users);
121  break;
122  }
123 }
124 
// Runtime dispatch for the GetRecommendations() overload that takes an
// explicit list of `users`: maps `nsType` onto CosineSearch, EuclideanSearch
// or PearsonSearch and forwards to the matching GetRecommendationsHelper().
// An `nsType` matching none of the cases leaves `recommendations` unchanged.
// NOTE(review): the declarator line (original line 127) is missing from this
// rendering; presumably it declares the member GetRecommendations(nsType,
// interpolationType, numRecs, recommendations, users) — confirm against
// cf_model.hpp.
126 template<typename DecompositionPolicy, typename NormalizationPolicy>
128  const NeighborSearchTypes nsType,
129  const InterpolationTypes interpolationType,
130  const size_t numRecs,
131  arma::Mat<size_t>& recommendations,
132  const arma::Col<size_t>& users)
133 {
134  switch (nsType)
135  {
136  case COSINE_SEARCH:
137  GetRecommendationsHelper<CosineSearch>(cf, interpolationType, numRecs,
138  recommendations, users);
139  break;
140 
141  case EUCLIDEAN_SEARCH:
142  GetRecommendationsHelper<EuclideanSearch>(cf, interpolationType, numRecs,
143  recommendations, users);
144  break;
145 
146  case PEARSON_SEARCH:
147  GetRecommendationsHelper<PearsonSearch>(cf, interpolationType, numRecs,
148  recommendations, users);
149  break;
150  }
151 }
152 
// Compile-time helper for the GetRecommendations() overload that has no
// `users` argument. With the NeighborSearchPolicy fixed by the template
// argument, it selects the interpolation policy from `interpolationType` and
// calls cf.GetRecommendations(numRecs, recommendations). An
// `interpolationType` matching none of the cases leaves `recommendations`
// unchanged.
// NOTE(review): original lines 164, 170 and 176 (the second template argument
// of each call) are missing from this rendering; presumably they name
// AverageInterpolation, RegressionInterpolation and SimilarityInterpolation —
// confirm.
153 template<typename NeighborSearchPolicy, typename CFType>
154 void GetRecommendationsHelper(
155  CFType& cf,
156  const InterpolationTypes interpolationType,
157  const size_t numRecs,
158  arma::Mat<size_t>& recommendations)
159 {
160  switch (interpolationType)
161  {
162  case AVERAGE_INTERPOLATION:
163  cf.template GetRecommendations<NeighborSearchPolicy,
165  numRecs, recommendations);
166  break;
167 
168  case REGRESSION_INTERPOLATION:
169  cf.template GetRecommendations<NeighborSearchPolicy,
171  numRecs, recommendations);
172  break;
173 
174  case SIMILARITY_INTERPOLATION:
175  cf.template GetRecommendations<NeighborSearchPolicy,
177  numRecs, recommendations);
178  break;
179  }
180 }
181 
// Runtime dispatch for GetRecommendations() over all users (the Doxygen
// cross-reference places this definition at cf_model_impl.hpp:184): maps
// `nsType` onto CosineSearch, EuclideanSearch or PearsonSearch and forwards to
// the matching GetRecommendationsHelper(). An `nsType` matching none of the
// cases leaves `recommendations` unchanged.
// NOTE(review): the declarator line (original line 184) is missing from this
// rendering.
183 template<typename DecompositionPolicy, typename NormalizationPolicy>
185  const NeighborSearchTypes nsType,
186  const InterpolationTypes interpolationType,
187  const size_t numRecs,
188  arma::Mat<size_t>& recommendations)
189 {
190  switch (nsType)
191  {
192  case COSINE_SEARCH:
193  GetRecommendationsHelper<CosineSearch>(cf, interpolationType, numRecs,
194  recommendations);
195  break;
196 
197  case EUCLIDEAN_SEARCH:
198  GetRecommendationsHelper<EuclideanSearch>(cf, interpolationType, numRecs,
199  recommendations);
200  break;
201 
202  case PEARSON_SEARCH:
203  GetRecommendationsHelper<PearsonSearch>(cf, interpolationType, numRecs,
204  recommendations);
205  break;
206  }
207 }
208 
// Returns a CFWrapperBase* for the given normalization type, with the
// decomposition policy fixed by the template argument. Called from
// InitializeModel(); CFModel::serialize() later deletes the pointer it stores,
// so ownership of the result passes to the caller.
// NOTE(review): original lines 216, 219, 222, 225 and 228 — presumably the
// per-case `return new CFWrapper<DecompositionPolicy, ...>` statements — are
// missing from this rendering. As shown, every case label falls through to the
// end of the switch and NULL is returned; confirm against the real source.
209 template<typename DecompositionPolicy>
210 CFWrapperBase* InitializeModelHelper(
211  CFModel::NormalizationTypes normalizationType)
212 {
213  switch (normalizationType)
214  {
215  case CFModel::NO_NORMALIZATION:
217 
218  case CFModel::ITEM_MEAN_NORMALIZATION:
220 
221  case CFModel::USER_MEAN_NORMALIZATION:
223 
224  case CFModel::OVERALL_MEAN_NORMALIZATION:
226 
227  case CFModel::Z_SCORE_NORMALIZATION:
229  }
230 
231  // This shouldn't ever happen.
232  return NULL;
233 }
234 
235 inline CFWrapperBase* InitializeModel(
236  CFModel::DecompositionTypes decompositionType,
237  CFModel::NormalizationTypes normalizationType)
238 {
239  switch (decompositionType)
240  {
241  case CFModel::NMF:
242  return InitializeModelHelper<NMFPolicy>(normalizationType);
243 
244  case CFModel::BATCH_SVD:
245  return InitializeModelHelper<BatchSVDPolicy>(normalizationType);
246 
247  case CFModel::RANDOMIZED_SVD:
248  return InitializeModelHelper<RandomizedSVDPolicy>(normalizationType);
249 
250  case CFModel::REG_SVD:
251  return InitializeModelHelper<RegSVDPolicy>(normalizationType);
252 
253  case CFModel::SVD_COMPLETE:
254  return InitializeModelHelper<SVDCompletePolicy>(normalizationType);
255 
256  case CFModel::SVD_INCOMPLETE:
257  return InitializeModelHelper<SVDIncompletePolicy>(normalizationType);
258 
259  case CFModel::BIAS_SVD:
260  return InitializeModelHelper<BiasSVDPolicy>(normalizationType);
261 
262  case CFModel::SVD_PLUS_PLUS:
263  return InitializeModelHelper<SVDPlusPlusPolicy>(normalizationType);
264  }
265 
266  // This shouldn't ever happen.
267  return NULL;
268 };
269 
// Serializes `cf` through its concrete CFWrapper type, avoiding polymorphic
// serialization. The normalization type selects the dynamic_cast target, and
// the decomposition policy comes from the template argument.
// `cf` is dereferenced in every handled case, so it must be non-null. A
// dynamic_cast to a reference type throws std::bad_cast if `*cf` is not of the
// selected wrapper type. A `normalizationType` that matches no case serializes
// nothing.
// NOTE(review): original lines 279, 288, 297, 306 and 315 — presumably the
// `CFWrapper<DecompositionPolicy, ...>& typedModel =` declarations — are
// missing from this rendering.
270 template<typename DecompositionPolicy, typename Archive>
271 void SerializeHelper(Archive& ar,
272  CFWrapperBase* cf,
273  CFModel::NormalizationTypes normalizationType)
274 {
275  switch (normalizationType)
276  {
277  case CFModel::NO_NORMALIZATION:
278  {
280  dynamic_cast<CFWrapper<DecompositionPolicy,
281  NoNormalization>&>(*cf);
282  ar(CEREAL_NVP(typedModel));
283  break;
284  }
285 
286  case CFModel::ITEM_MEAN_NORMALIZATION:
287  {
289  dynamic_cast<CFWrapper<DecompositionPolicy,
290  ItemMeanNormalization>&>(*cf);
291  ar(CEREAL_NVP(typedModel));
292  break;
293  }
294 
295  case CFModel::USER_MEAN_NORMALIZATION:
296  {
298  dynamic_cast<CFWrapper<DecompositionPolicy,
299  UserMeanNormalization>&>(*cf);
300  ar(CEREAL_NVP(typedModel));
301  break;
302  }
303 
304  case CFModel::OVERALL_MEAN_NORMALIZATION:
305  {
307  dynamic_cast<CFWrapper<DecompositionPolicy,
308  OverallMeanNormalization>&>(*cf);
309  ar(CEREAL_NVP(typedModel));
310  break;
311  }
312 
313  case CFModel::Z_SCORE_NORMALIZATION:
314  {
316  dynamic_cast<CFWrapper<DecompositionPolicy,
317  ZScoreNormalization>&>(*cf);
318  ar(CEREAL_NVP(typedModel));
319  break;
320  }
321  }
322 }
323 
324 template<typename Archive>
325 void CFModel::serialize(Archive& ar, const uint32_t /* version */)
326 {
327  ar(CEREAL_NVP(decompositionType));
328  ar(CEREAL_NVP(normalizationType));
329 
330  // This should never happen, but just in case, be clean with memory.
331  if (cereal::is_loading<Archive>())
332  {
333  delete cf;
334  cf = InitializeModel(decompositionType, normalizationType);
335  }
336 
337  // Avoid polymorphic serialization by determining the type directly.
338  switch (decompositionType)
339  {
340  case NMF:
341  SerializeHelper<NMFPolicy>(ar, cf, normalizationType);
342  break;
343 
344  case BATCH_SVD:
345  SerializeHelper<BatchSVDPolicy>(ar, cf, normalizationType);
346  break;
347 
348  case RANDOMIZED_SVD:
349  SerializeHelper<RandomizedSVDPolicy>(ar, cf, normalizationType);
350  break;
351 
352  case REG_SVD:
353  SerializeHelper<RegSVDPolicy>(ar, cf, normalizationType);
354  break;
355 
356  case SVD_COMPLETE:
357  SerializeHelper<SVDCompletePolicy>(ar, cf, normalizationType);
358  break;
359 
360  case SVD_INCOMPLETE:
361  SerializeHelper<SVDIncompletePolicy>(ar, cf, normalizationType);
362  break;
363 
364  case BIAS_SVD:
365  SerializeHelper<BiasSVDPolicy>(ar, cf, normalizationType);
366  break;
367 
368  case SVD_PLUS_PLUS:
369  SerializeHelper<SVDPlusPlusPolicy>(ar, cf, normalizationType);
370  break;
371  }
372 }
373 
374 } // namespace cf
375 } // namespace mlpack
376 
377 #endif
This normalization class performs item mean normalization on raw ratings.
Definition: item_mean_normalization.hpp:39
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The CFWrapperBase class provides a unified interface that can be used by the CFModel class to interac...
Definition: cf_model.hpp:49
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: cf_model_impl.hpp:325
This class performs average interpolation to generate interpolation weights for neighborhood-based co...
Definition: average_interpolation.hpp:39
This normalization class performs user mean normalization on raw ratings.
Definition: user_mean_normalization.hpp:39
InterpolationTypes
InterpolationTypes contains the set of InterpolationPolicy classes that are usable by CFModel at pred...
Definition: cf_model.hpp:37
Implementation of regression-based interpolation method.
Definition: regression_interpolation.hpp:56
virtual void Predict(const NeighborSearchTypes nsType, const InterpolationTypes interpolationType, const arma::Mat< size_t > &combinations, arma::vec &predictions)
Compute predictions for users.
Definition: cf_model_impl.hpp:70
NeighborSearchTypes
NeighborSearchTypes contains the set of NeighborSearchPolicy classes that are usable by CFModel at pr...
Definition: cf_model.hpp:26
This normalization class doesn't perform any normalization.
Definition: no_normalization.hpp:25
virtual void GetRecommendations(const NeighborSearchTypes nsType, const InterpolationTypes interpolationType, const size_t numRecs, arma::Mat< size_t > &recommendations)
Compute recommendations for all users.
Definition: cf_model_impl.hpp:184
With SimilarityInterpolation, interpolation weights are based on similarities between query user and ...
Definition: similarity_interpolation.hpp:41
This normalization class performs overall mean normalization on raw ratings.
Definition: overall_mean_normalization.hpp:39
The CFWrapper class wraps the functionality of all CF types.
Definition: cf_model.hpp:88
This class implements Collaborative Filtering (CF).
Definition: cf.hpp:70
This normalization class performs z-score normalization on raw ratings.
Definition: z_score_normalization.hpp:38