// NOTE(review): this text is a Doxygen source-listing extract of
// cf_model_impl.hpp, not compilable C++. The leading integer on each line is
// the original file's line number; the gaps between them (e.g. 13 -> 28 -> 43
// on the next line, 44 -> 46 and 47 -> 49 -> 51 below) show that many lines --
// further #includes, parameters, braces, `break` statements -- were dropped.
// Restore it from the upstream source rather than editing this listing.
//
// PredictHelper: maps the runtime value of interpolationType onto the
// InterpolationPolicy template argument of
// cf.Predict<NeighborSearchPolicy, InterpolationPolicy>(combinations,
// predictions):
//   AVERAGE_INTERPOLATION    -> AverageInterpolation
//   REGRESSION_INTERPOLATION -> RegressionInterpolation
//   SIMILARITY_INTERPOLATION -> SimilarityInterpolation
// `combinations` is taken by const reference; `predictions` is a non-const
// reference passed straight through to cf.Predict (presumably the output
// vector -- confirm against CF::Predict).
// NOTE(review): original line 45 (presumably the interpolationType parameter)
// and lines 64-68 are missing, so this extract cannot show whether an
// unrecognised interpolationType is handled.
12 #ifndef MLPACK_METHODS_CF_CF_MODEL_IMPL_HPP 13 #define MLPACK_METHODS_CF_CF_MODEL_IMPL_HPP 28 #include "decomposition_policies/randomized_svd_method.hpp" 43 template<
typename NeighborSearchPolicy,
typename CFType>
44 void PredictHelper(CFType& cf,
46 const arma::Mat<size_t>& combinations,
47 arma::vec& predictions)
49 switch (interpolationType)
51 case AVERAGE_INTERPOLATION:
52 cf.template Predict<NeighborSearchPolicy,
53 AverageInterpolation>(combinations, predictions);
56 case REGRESSION_INTERPOLATION:
57 cf.template Predict<NeighborSearchPolicy,
58 RegressionInterpolation>(combinations, predictions);
61 case SIMILARITY_INTERPOLATION:
62 cf.template Predict<NeighborSearchPolicy,
63 SimilarityInterpolation>(combinations, predictions);
// Predict (neighbor-search dispatch). Per this page's Doxygen tooltip it is
// `virtual void Predict(const NeighborSearchTypes nsType,
//     const InterpolationTypes interpolationType,
//     const arma::Mat<size_t>& combinations, arma::vec& predictions)`
// ("Compute predictions for users"), defined at original line 70. The
// declarator lines (70-72) and the `switch` head are missing from this extract.
// It forwards to PredictHelper with the NeighborSearchPolicy chosen from the
// search type: CosineSearch, EuclideanSearch (case EUCLIDEAN_SEARCH) or
// PearsonSearch. Only the EUCLIDEAN_SEARCH label survives; the other labels
// (lines 77-78, 87-88) are presumably COSINE_SEARCH / PEARSON_SEARCH -- confirm.
// `cf` is not declared in the visible lines; presumably it is the wrapped CF
// object held by the enclosing class (likely CFWrapper<DecompositionPolicy,
// NormalizationPolicy>) -- TODO confirm upstream.
69 template<
typename DecompositionPolicy,
typename NormalizationPolicy>
73 const arma::Mat<size_t>& combinations,
74 arma::vec& predictions)
79 PredictHelper<CosineSearch>(cf, interpolationType, combinations,
83 case EUCLIDEAN_SEARCH:
84 PredictHelper<EuclideanSearch>(cf, interpolationType, combinations,
89 PredictHelper<PearsonSearch>(cf, interpolationType, combinations,
// GetRecommendationsHelper (overload that takes an explicit `users` list):
// maps the runtime interpolationType (AVERAGE_ / REGRESSION_ /
// SIMILARITY_INTERPOLATION) onto the InterpolationPolicy template argument of
// cf.GetRecommendations<NeighborSearchPolicy, InterpolationPolicy>(numRecs,
// recommendations, users).
// The policy class names sit on original lines 107, 113 and 119, which are
// missing here. By analogy with PredictHelper they are presumably
// AverageInterpolation, RegressionInterpolation and SimilarityInterpolation --
// confirm upstream.
// `users` is taken by const reference; `recommendations` is a non-const
// reference forwarded to cf.GetRecommendations (presumably the output matrix).
// NOTE(review): the declarations of cf, interpolationType and numRecs
// (original lines 97-99) are missing from this extract.
95 template<
typename NeighborSearchPolicy,
typename CFType>
96 void GetRecommendationsHelper(
100 arma::Mat<size_t>& recommendations,
101 const arma::Col<size_t>& users)
103 switch (interpolationType)
105 case AVERAGE_INTERPOLATION:
106 cf.template GetRecommendations<NeighborSearchPolicy,
108 numRecs, recommendations, users);
111 case REGRESSION_INTERPOLATION:
112 cf.template GetRecommendations<NeighborSearchPolicy,
114 numRecs, recommendations, users);
117 case SIMILARITY_INTERPOLATION:
118 cf.template GetRecommendations<NeighborSearchPolicy,
120 numRecs, recommendations, users);
// GetRecommendations for an explicit list of users (neighbor-search dispatch).
// The declarator lines (127-129) and the `switch` head are missing from this
// extract. It forwards numRecs, recommendations and users to
// GetRecommendationsHelper, choosing the NeighborSearchPolicy template argument
// from the search type: CosineSearch, EuclideanSearch (case EUCLIDEAN_SEARCH)
// or PearsonSearch. Only the EUCLIDEAN_SEARCH label is visible; the others are
// presumably COSINE_SEARCH / PEARSON_SEARCH -- confirm upstream.
// `cf` and interpolationType are not declared in the visible lines; presumably
// they are the wrapped CF member and a parameter, respectively -- TODO confirm.
126 template<
typename DecompositionPolicy,
typename NormalizationPolicy>
130 const size_t numRecs,
131 arma::Mat<size_t>& recommendations,
132 const arma::Col<size_t>& users)
137 GetRecommendationsHelper<CosineSearch>(cf, interpolationType, numRecs,
138 recommendations, users);
141 case EUCLIDEAN_SEARCH:
142 GetRecommendationsHelper<EuclideanSearch>(cf, interpolationType, numRecs,
143 recommendations, users);
147 GetRecommendationsHelper<PearsonSearch>(cf, interpolationType, numRecs,
148 recommendations, users);
// GetRecommendationsHelper (all-users overload, no `users` argument): maps the
// runtime interpolationType (AVERAGE_ / REGRESSION_ / SIMILARITY_INTERPOLATION)
// onto the InterpolationPolicy template argument of
// cf.GetRecommendations<NeighborSearchPolicy, InterpolationPolicy>(numRecs,
// recommendations).
// The policy class names (original lines 164, 170, 176) and the declarations
// of cf and interpolationType (lines 155-156) are missing from this extract.
// The policies are presumably AverageInterpolation, RegressionInterpolation and
// SimilarityInterpolation, as in PredictHelper -- confirm upstream.
153 template<
typename NeighborSearchPolicy,
typename CFType>
154 void GetRecommendationsHelper(
157 const size_t numRecs,
158 arma::Mat<size_t>& recommendations)
160 switch (interpolationType)
162 case AVERAGE_INTERPOLATION:
163 cf.template GetRecommendations<NeighborSearchPolicy,
165 numRecs, recommendations);
168 case REGRESSION_INTERPOLATION:
169 cf.template GetRecommendations<NeighborSearchPolicy,
171 numRecs, recommendations);
174 case SIMILARITY_INTERPOLATION:
175 cf.template GetRecommendations<NeighborSearchPolicy,
177 numRecs, recommendations);
// GetRecommendations for all users (neighbor-search dispatch). Per this page's
// Doxygen tooltip it is
// `virtual void GetRecommendations(const NeighborSearchTypes nsType,
//     const InterpolationTypes interpolationType, const size_t numRecs,
//     arma::Mat<size_t>& recommendations)`
// ("Compute recommendations for all users"), defined at original line 184.
// It forwards to the all-users GetRecommendationsHelper with the
// NeighborSearchPolicy chosen from nsType: CosineSearch, EuclideanSearch
// (case EUCLIDEAN_SEARCH) or PearsonSearch. The remaining arguments and case
// labels are on lines missing from this extract.
183 template<
typename DecompositionPolicy,
typename NormalizationPolicy>
187 const size_t numRecs,
188 arma::Mat<size_t>& recommendations)
193 GetRecommendationsHelper<CosineSearch>(cf, interpolationType, numRecs,
197 case EUCLIDEAN_SEARCH:
198 GetRecommendationsHelper<EuclideanSearch>(cf, interpolationType, numRecs,
203 GetRecommendationsHelper<PearsonSearch>(cf, interpolationType, numRecs,
// InitializeModelHelper: for a fixed DecompositionPolicy, branches on the
// runtime normalizationType. The handled values are NO_NORMALIZATION,
// ITEM_MEAN_NORMALIZATION, USER_MEAN_NORMALIZATION, OVERALL_MEAN_NORMALIZATION
// and Z_SCORE_NORMALIZATION. Callers invoke it as
// InitializeModelHelper<SomePolicy>(normalizationType) and return its result.
// NOTE(review): the return type (line 210) and every case body (lines 216-217,
// 219-220, ...) are missing from this extract. Presumably each case returns a
// newly created CFWrapper<DecompositionPolicy, <matching normalization class>>
// -- confirm upstream, including who owns the returned object.
209 template<
typename DecompositionPolicy>
211 CFModel::NormalizationTypes normalizationType)
213 switch (normalizationType)
215 case CFModel::NO_NORMALIZATION:
218 case CFModel::ITEM_MEAN_NORMALIZATION:
221 case CFModel::USER_MEAN_NORMALIZATION:
224 case CFModel::OVERALL_MEAN_NORMALIZATION:
227 case CFModel::Z_SCORE_NORMALIZATION:
// InitializeModel(decompositionType, normalizationType): selects the
// decomposition policy from decompositionType and returns
// InitializeModelHelper<Policy>(normalizationType):
//   (label on missing line 241, presumably NMF) -> NMFPolicy
//   BATCH_SVD       -> BatchSVDPolicy
//   RANDOMIZED_SVD  -> RandomizedSVDPolicy
//   REG_SVD         -> RegSVDPolicy
//   SVD_COMPLETE    -> SVDCompletePolicy
//   SVD_INCOMPLETE  -> SVDIncompletePolicy
//   BIAS_SVD        -> BiasSVDPolicy
//   SVD_PLUS_PLUS   -> SVDPlusPlusPolicy
// serialize() assigns its result to `cf` when loading.
// NOTE(review): the declarator/return type (lines 234-235) and lines 264-268
// are missing, so handling of an unrecognised decompositionType is not visible.
236 CFModel::DecompositionTypes decompositionType,
237 CFModel::NormalizationTypes normalizationType)
239 switch (decompositionType)
242 return InitializeModelHelper<NMFPolicy>(normalizationType);
244 case CFModel::BATCH_SVD:
245 return InitializeModelHelper<BatchSVDPolicy>(normalizationType);
247 case CFModel::RANDOMIZED_SVD:
248 return InitializeModelHelper<RandomizedSVDPolicy>(normalizationType);
250 case CFModel::REG_SVD:
251 return InitializeModelHelper<RegSVDPolicy>(normalizationType);
253 case CFModel::SVD_COMPLETE:
254 return InitializeModelHelper<SVDCompletePolicy>(normalizationType);
256 case CFModel::SVD_INCOMPLETE:
257 return InitializeModelHelper<SVDIncompletePolicy>(normalizationType);
259 case CFModel::BIAS_SVD:
260 return InitializeModelHelper<BiasSVDPolicy>(normalizationType);
262 case CFModel::SVD_PLUS_PLUS:
263 return InitializeModelHelper<SVDPlusPlusPolicy>(normalizationType);
// SerializeHelper: for a fixed DecompositionPolicy, uses dynamic_cast to turn
// the model into CFWrapper<DecompositionPolicy, <normalization class>>,
// choosing the normalization class from normalizationType, and then archives it
// with ar(CEREAL_NVP(typedModel)). With cereal the same call is used for both
// saving and loading, depending on the Archive type.
// The normalization template arguments and the declaration of typedModel
// (lines 278-281, 287-290, ...) are missing here. Presumably they are
// NoNormalization, ItemMeanNormalization, UserMeanNormalization,
// OverallMeanNormalization and ZScoreNormalization (the classes listed in this
// page's tooltips) -- confirm upstream.
// NOTE(review): the model parameter (line 272) is not shown; serialize() passes
// `cf`. Whether the cast targets a reference (throws std::bad_cast on a type
// mismatch) or a pointer (yields nullptr, which would then need checking)
// cannot be seen here -- confirm the failure mode upstream.
270 template<
typename DecompositionPolicy,
typename Archive>
271 void SerializeHelper(Archive& ar,
273 CFModel::NormalizationTypes normalizationType)
275 switch (normalizationType)
277 case CFModel::NO_NORMALIZATION:
280 dynamic_cast<CFWrapper<DecompositionPolicy,
282 ar(CEREAL_NVP(typedModel));
286 case CFModel::ITEM_MEAN_NORMALIZATION:
289 dynamic_cast<CFWrapper<DecompositionPolicy,
291 ar(CEREAL_NVP(typedModel));
295 case CFModel::USER_MEAN_NORMALIZATION:
298 dynamic_cast<CFWrapper<DecompositionPolicy,
300 ar(CEREAL_NVP(typedModel));
304 case CFModel::OVERALL_MEAN_NORMALIZATION:
307 dynamic_cast<CFWrapper<DecompositionPolicy,
309 ar(CEREAL_NVP(typedModel));
313 case CFModel::Z_SCORE_NORMALIZATION:
316 dynamic_cast<CFWrapper<DecompositionPolicy,
318 ar(CEREAL_NVP(typedModel));
// serialize: per this page's Doxygen tooltip, `void serialize(Archive& ar,
// const uint32_t)` -- "Serialize the model." -- defined at original line 325.
// The declarator lines (325-326) are missing here.
// Order matters: decompositionType and normalizationType are archived first.
// When loading (cereal::is_loading<Archive>()), this lets the matching wrapper
// be constructed via InitializeModel(decompositionType, normalizationType)
// before its contents are read. The switch on decompositionType then hands
// `cf` to SerializeHelper with the corresponding decomposition policy
// (NMFPolicy, BatchSVDPolicy, RandomizedSVDPolicy, RegSVDPolicy,
// SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, SVDPlusPlusPolicy).
// The case labels are missing from this extract.
// NOTE(review): lines 332-333 are not shown. If `cf` is an owning raw pointer,
// confirm that any previously held model is released before it is overwritten
// on load; otherwise that model leaks.
324 template<
typename Archive>
327 ar(CEREAL_NVP(decompositionType));
328 ar(CEREAL_NVP(normalizationType));
331 if (cereal::is_loading<Archive>())
334 cf = InitializeModel(decompositionType, normalizationType);
338 switch (decompositionType)
341 SerializeHelper<NMFPolicy>(ar, cf, normalizationType);
345 SerializeHelper<BatchSVDPolicy>(ar, cf, normalizationType);
349 SerializeHelper<RandomizedSVDPolicy>(ar, cf, normalizationType);
353 SerializeHelper<RegSVDPolicy>(ar, cf, normalizationType);
357 SerializeHelper<SVDCompletePolicy>(ar, cf, normalizationType);
361 SerializeHelper<SVDIncompletePolicy>(ar, cf, normalizationType);
365 SerializeHelper<BiasSVDPolicy>(ar, cf, normalizationType);
369 SerializeHelper<SVDPlusPlusPolicy>(ar, cf, normalizationType);
This normalization class performs item mean normalization on raw ratings.
Definition: item_mean_normalization.hpp:39
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The CFWrapperBase class provides a unified interface that can be used by the CFModel class to interac...
Definition: cf_model.hpp:49
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: cf_model_impl.hpp:325
This class performs average interpolation to generate interpolation weights for neighborhood-based co...
Definition: average_interpolation.hpp:39
This normalization class performs user mean normalization on raw ratings.
Definition: user_mean_normalization.hpp:39
InterpolationTypes
InterpolationTypes contains the set of InterpolationPolicy classes that are usable by CFModel at pred...
Definition: cf_model.hpp:37
Implementation of regression-based interpolation method.
Definition: regression_interpolation.hpp:56
virtual void Predict(const NeighborSearchTypes nsType, const InterpolationTypes interpolationType, const arma::Mat< size_t > &combinations, arma::vec &predictions)
Compute predictions for users.
Definition: cf_model_impl.hpp:70
NeighborSearchTypes
NeighborSearchTypes contains the set of NeighborSearchPolicy classes that are usable by CFModel at pr...
Definition: cf_model.hpp:26
This normalization class doesn't perform any normalization.
Definition: no_normalization.hpp:25
virtual void GetRecommendations(const NeighborSearchTypes nsType, const InterpolationTypes interpolationType, const size_t numRecs, arma::Mat< size_t > &recommendations)
Compute recommendations for all users.
Definition: cf_model_impl.hpp:184
With SimilarityInterpolation, interpolation weights are based on similarities between query user and ...
Definition: similarity_interpolation.hpp:41
This normalization class performs overall mean normalization on raw ratings.
Definition: overall_mean_normalization.hpp:39
The CFWrapper class wraps the functionality of all CF types.
Definition: cf_model.hpp:88
This class implements Collaborative Filtering (CF).
Definition: cf.hpp:70
This normalization class performs z-score normalization on raw ratings.
Definition: z_score_normalization.hpp:38