mlpack
cf_model.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_CF_CF_MODEL_HPP
14 #define MLPACK_METHODS_CF_CF_MODEL_HPP
15 
16 #include <mlpack/core.hpp>
17 #include "cf.hpp"
18 
19 namespace mlpack {
20 namespace cf {
21 
// Enumerator list with three search choices: cosine, Euclidean and Pearson.
// NOTE(review): the declaration line that names this enum (listing line 26) is
// missing from this extract. The definition index at the end of the page puts
// `NeighborSearchTypes` at cf_model.hpp:26, so this is presumably its body --
// confirm against the original header.
27 {
28  COSINE_SEARCH,
29  EUCLIDEAN_SEARCH,
30  PEARSON_SEARCH
31 };
32 
// Enumerator list with three interpolation choices: average, regression and
// similarity.
// NOTE(review): the declaration line that names this enum (listing line 37) is
// missing from this extract. The definition index at the end of the page puts
// `InterpolationTypes` at cf_model.hpp:37, so this is presumably its body --
// confirm against the original header.
38 {
39  AVERAGE_INTERPOLATION,
40  REGRESSION_INTERPOLATION,
41  SIMILARITY_INTERPOLATION
42 };
43 
// Body of the abstract wrapper interface: every operation below is pure
// virtual, so concrete wrappers must implement all of them.
// NOTE(review): the class header (listing line 49) and the constructor
// (listing line 53) are missing from this extract. The definition index at the
// end of the page lists `CFWrapperBase` at cf_model.hpp:49 and
// `CFWrapperBase()` at cf_model.hpp:53 -- confirm against the original header.
50 {
51  public:
54 
// Make a copy of the object. Returns a raw pointer.
// NOTE(review): the caller presumably takes ownership of the copy -- confirm.
56  virtual CFWrapperBase* Clone() const = 0;
57 
// Delete the object. Virtual, so derived wrappers can be destroyed through a
// CFWrapperBase pointer.
59  virtual ~CFWrapperBase() { }
60 
// Compute predictions for users.
// nsType / interpolationType select the neighbor search and interpolation
// choices; results are written to `predictions`.
// NOTE(review): the layout of `combinations` is not visible here; presumably
// it holds the user/item pairs to predict -- confirm.
62  virtual void Predict(const NeighborSearchTypes nsType,
63  const InterpolationTypes interpolationType,
64  const arma::Mat<size_t>& combinations,
65  arma::vec& predictions) = 0;
66 
// Compute recommendations for all users.
// Results are written to `recommendations`.
// NOTE(review): `numRecs` is presumably the number of recommendations per
// user -- confirm.
68  virtual void GetRecommendations(
69  const NeighborSearchTypes nsType,
70  const InterpolationTypes interpolationType,
71  const size_t numRecs,
72  arma::Mat<size_t>& recommendations) = 0;
73 
// Overload of GetRecommendations() that additionally takes `users`.
// NOTE(review): presumably restricts the computation to the listed users --
// confirm.
75  virtual void GetRecommendations(
76  const NeighborSearchTypes nsType,
77  const InterpolationTypes interpolationType,
78  const size_t numRecs,
79  arma::Mat<size_t>& recommendations,
80  const arma::Col<size_t>& users) = 0;
81 };
82 
// The CFWrapper class wraps the functionality of all CF types: it holds one
// CF object (`cf`) and implements the CFWrapperBase interface on top of it.
//
// Template parameters:
//   DecompositionPolicy  - type of the decomposition object passed on to the
//                          held CF object.
//   NormalizationPolicy  - normalization type parameter.
// NOTE(review): `CFModelType` is used below but its declaration (listing line
// 91) is missing from this extract; presumably it is a typedef of the wrapped
// CF type built from these two template parameters -- confirm.
87 template<typename DecompositionPolicy, typename NormalizationPolicy>
88 class CFWrapper : public CFWrapperBase
89 {
90  protected:
92 
93  public:
// Create the CFWrapper object, using default parameters to initialize the
// held CF object.
96  CFWrapper() { }
97 
// Create the CFWrapper object, initializing the held CF object. Every
// argument is forwarded unchanged, in order, to the constructor of `cf`.
//
// minResidue is a double, matching CFModel::Train(). Declaring it as size_t
// would truncate a fractional residue threshold (e.g. 1e-5 becomes 0) before
// it reaches the held CF object. Integer arguments still convert implicitly.
99  CFWrapper(const arma::mat& data,
100  const DecompositionPolicy& decomposition,
101  const size_t numUsersForSimilarity,
102  const size_t rank,
103  const size_t maxIterations,
104  const double minResidue,
105  const bool mit) :
106  cf(data,
107  decomposition,
108  numUsersForSimilarity,
109  rank,
110  maxIterations,
111  minResidue,
112  mit)
113  {
114  // Nothing else to do.
115  }
116 
// Clone the CFWrapper object. This handles polymorphism correctly: the copy
// is made with the concrete CFWrapper copy constructor and returned through
// a covariant return type.
118  virtual CFWrapper* Clone() const { return new CFWrapper(*this); }
119 
// Destroy the CFWrapper object.
121  virtual ~CFWrapper() { }
122 
// Get the held CF object (mutable reference).
124  CFModelType& CF() { return cf; }
125 
// Compute predictions for users. Declared here; the definition is not part
// of this class body.
127  virtual void Predict(const NeighborSearchTypes nsType,
128  const InterpolationTypes interpolationType,
129  const arma::Mat<size_t>& combinations,
130  arma::vec& predictions);
131 
// Compute recommendations for all users. Declared here; the definition is
// not part of this class body.
133  virtual void GetRecommendations(
134  const NeighborSearchTypes nsType,
135  const InterpolationTypes interpolationType,
136  const size_t numRecs,
137  arma::Mat<size_t>& recommendations);
138 
// Overload of GetRecommendations() that additionally takes `users`.
// Declared here; the definition is not part of this class body.
140  virtual void GetRecommendations(
141  const NeighborSearchTypes nsType,
142  const InterpolationTypes interpolationType,
143  const size_t numRecs,
144  arma::Mat<size_t>& recommendations,
145  const arma::Col<size_t>& users);
146 
// Serialize the model: the held CF object is the only archived field.
148  template<typename Archive>
149  void serialize(Archive& ar, const uint32_t /* version */)
150  {
151  ar(CEREAL_NVP(cf));
152  }
153 
154  protected:
// This is the CF object that we are wrapping.
156  CFModelType cf;
157 };
158 
// The model to save to disk. Holds a decomposition type tag, a normalization
// type tag and a pointer to a CFWrapperBase. Member functions that are only
// declared here are defined elsewhere (this header includes
// "cf_model_impl.hpp" at the end of the file).
162 class CFModel
163 {
164  public:
// Tags for the decomposition choices; the current choice is stored in
// `decompositionType`.
165  enum DecompositionTypes
166  {
167  NMF,
168  BATCH_SVD,
169  RANDOMIZED_SVD,
170  REG_SVD,
171  SVD_COMPLETE,
172  SVD_INCOMPLETE,
173  BIAS_SVD,
174  SVD_PLUS_PLUS
175  };
176 
// Tags for the normalization choices; the current choice is stored in
// `normalizationType`.
177  enum NormalizationTypes
178  {
179  NO_NORMALIZATION,
180  ITEM_MEAN_NORMALIZATION,
181  USER_MEAN_NORMALIZATION,
182  OVERALL_MEAN_NORMALIZATION,
183  Z_SCORE_NORMALIZATION
184  };
185 
186  private:
// The decomposition type tag.
188  DecompositionTypes decompositionType;
// The normalization type tag.
190  NormalizationTypes normalizationType;
191 
// The wrapped CF object, accessed only through the CFWrapperBase interface.
// NOTE(review): the user-declared destructor and copy/move operations below
// suggest CFModel owns this pointer -- confirm in cf_model_impl.hpp.
197  CFWrapperBase* cf;
198 
199  public:
// Default constructor (declaration only).
201  CFModel();
202 
// Copy constructor (declaration only).
204  CFModel(const CFModel& other);
205 
// Move constructor (declaration only).
207  CFModel(CFModel&& other);
208 
// Copy assignment (declaration only).
210  CFModel& operator=(const CFModel& other);
211 
// Move assignment (declaration only).
213  CFModel& operator=(CFModel&& other);
214 
// Destructor (declaration only).
216  ~CFModel();
217 
// Get the CFWrapperBase object. (Be careful!) This returns the raw internal
// pointer, even from a const CFModel.
219  CFWrapperBase* CF() const { return cf; }
220 
// Get the decomposition type.
222  const DecompositionTypes& DecompositionType() const
223  {
224  return decompositionType;
225  }
// Set the decomposition type (returns a modifiable reference).
227  DecompositionTypes& DecompositionType()
228  {
229  return decompositionType;
230  }
231 
// Get the normalization type.
233  const NormalizationTypes& NormalizationType() const
234  {
235  return normalizationType;
236  }
// Set the normalization type (returns a modifiable reference).
238  NormalizationTypes& NormalizationType()
239  {
240  return normalizationType;
241  }
242 
// Train on `data` with the given settings (declaration only).
// Note that minResidue is a double here.
// NOTE(review): presumably builds the wrapped CF object according to the
// stored decomposition/normalization tags -- confirm in cf_model_impl.hpp.
244  void Train(const arma::mat& data,
245  const size_t numUsersForSimilarity,
246  const size_t rank,
247  const size_t maxIterations,
248  const double minResidue,
249  const bool mit);
250 
// Compute predictions; results are written to `predictions` (declaration
// only). Same parameter list as CFWrapperBase::Predict().
252  void Predict(const NeighborSearchTypes nsType,
253  const InterpolationTypes interpolationType,
254  const arma::Mat<size_t>& combinations,
255  arma::vec& predictions);
256 
// Compute recommendations, additionally taking `users`; results are written
// to `recommendations` (declaration only).
// NOTE(review): presumably restricted to the listed users -- confirm.
258  void GetRecommendations(const NeighborSearchTypes nsType,
259  const InterpolationTypes interpolationType,
260  const size_t numRecs,
261  arma::Mat<size_t>& recommendations,
262  const arma::Col<size_t>& users);
263 
// Compute recommendations without a `users` argument; results are written to
// `recommendations` (declaration only).
265  void GetRecommendations(const NeighborSearchTypes nsType,
266  const InterpolationTypes interpolationType,
267  const size_t numRecs,
268  arma::Mat<size_t>& recommendations);
269 
// Serialize the model (declaration only; the version argument is unnamed).
271  template<typename Archive>
272  void serialize(Archive& ar, const uint32_t /* version */);
273 };
274 
275 } // namespace cf
276 } // namespace mlpack
277 
278 // Include implementation.
279 #include "cf_model_impl.hpp"
280 
281 #endif
CFModelType cf
This is the CF object that we are wrapping.
Definition: cf_model.hpp:156
CFWrapper(const arma::mat &data, const DecompositionPolicy &decomposition, const size_t numUsersForSimilarity, const size_t rank, const size_t maxIterations, const size_t minResidue, const bool mit)
Create the CFWrapper object, initializing the held CF object.
Definition: cf_model.hpp:99
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The CFWrapperBase class provides a unified interface that can be used by the CFModel class to interact with all different CF types at runtime.
Definition: cf_model.hpp:49
const NormalizationTypes & NormalizationType() const
Get the normalization type.
Definition: cf_model.hpp:233
CFWrapper()
Create the CFWrapper object, using default parameters to initialize the held CF object.
Definition: cf_model.hpp:96
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: cf_model.hpp:149
virtual ~CFWrapper()
Destroy the CFWrapper object.
Definition: cf_model.hpp:121
InterpolationTypes
InterpolationTypes contains the set of InterpolationPolicy classes that are usable by CFModel at prediction time.
Definition: cf_model.hpp:37
NeighborSearchTypes
NeighborSearchTypes contains the set of NeighborSearchPolicy classes that are usable by CFModel at prediction time.
Definition: cf_model.hpp:26
virtual CFWrapperBase * Clone() const =0
Make a copy of the object.
The model to save to disk.
Definition: cf_model.hpp:162
CFWrapperBase * CF() const
Get the CFWrapperBase object. (Be careful!)
Definition: cf_model.hpp:219
NormalizationTypes & NormalizationType()
Set the normalization type.
Definition: cf_model.hpp:238
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Definition: hmm_train_main.cpp:300
const DecompositionTypes & DecompositionType() const
Get the decomposition type.
Definition: cf_model.hpp:222
CFWrapperBase()
Create the object. The base class has nothing to hold.
Definition: cf_model.hpp:53
virtual void Predict(const NeighborSearchTypes nsType, const InterpolationTypes interpolationType, const arma::Mat< size_t > &combinations, arma::vec &predictions)=0
Compute predictions for users.
virtual void GetRecommendations(const NeighborSearchTypes nsType, const InterpolationTypes interpolationType, const size_t numRecs, arma::Mat< size_t > &recommendations)=0
Compute recommendations for all users.
DecompositionTypes & DecompositionType()
Set the decomposition type.
Definition: cf_model.hpp:227
The CFWrapper class wraps the functionality of all CF types.
Definition: cf_model.hpp:88
virtual ~CFWrapperBase()
Delete the object.
Definition: cf_model.hpp:59
virtual CFWrapper * Clone() const
Clone the CFWrapper object. This handles polymorphism correctly.
Definition: cf_model.hpp:118
CFModelType & CF()
Get the CFType object.
Definition: cf_model.hpp:124