mlpack
svdplusplus_impl.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_SVDPLUSPLUS_SVDPLUSPLUS_IMPL_HPP
15 #define MLPACK_METHODS_SVDPLUSPLUS_SVDPLUSPLUS_IMPL_HPP
16 
17 namespace mlpack {
18 namespace svd {
19 
20 template<typename OptimizerType>
22  const double alpha,
23  const double lambda) :
24  iterations(iterations),
25  alpha(alpha),
26  lambda(lambda)
27 {
28  // Nothing to do.
29 }
30 
31 template<typename OptimizerType>
33  const arma::mat& implicitData,
34  const size_t rank,
35  arma::mat& u,
36  arma::mat& v,
37  arma::vec& p,
38  arma::vec& q,
39  arma::mat& y)
40 {
41  // batchSize is 1 in our implementation of SVDPlusPlus.
42  // batchSize other than 1 has not been supported yet.
43  const int batchSize = 1;
44  Log::Warn << "The batch size for optimizing SVDPlusPlus is 1."
45  << std::endl;
46 
47  // Converts implicitData to the form of sparse matrix.
48  arma::sp_mat cleanedData;
49  CleanData(implicitData, cleanedData, data);
50 
51  // Make the optimizer object using a SVDPlusPlusFunction object.
52  SVDPlusPlusFunction<arma::mat> svdPPFunc(data, cleanedData, rank, lambda);
53  ens::StandardSGD optimizer(alpha, batchSize,
54  iterations * data.n_cols);
55 
56  // Get optimized parameters.
57  arma::mat parameters = svdPPFunc.GetInitialPoint();
58  optimizer.Optimize(svdPPFunc, parameters);
59 
60  // Constants for extracting user and item matrices.
61  const size_t numUsers = max(data.row(0)) + 1;
62  const size_t numItems = max(data.row(1)) + 1;
63 
64  // Extract user and item matrices, user and item bias, item implicit matrix
65  // from the optimized parameters.
66  u = parameters.submat(0, numUsers, rank - 1, numUsers + numItems - 1).t();
67  v = parameters.submat(0, 0, rank - 1, numUsers - 1);
68  p = parameters.row(rank).subvec(numUsers, numUsers + numItems - 1).t();
69  q = parameters.row(rank).subvec(0, numUsers - 1).t();
70  y = parameters.submat(0, numUsers + numItems, rank - 1,
71  numUsers + 2 * numItems - 1);
72 }
73 
74 // Use whether a user rates an item as binary implicit data when implicitData
75 // is not given.
76 template<typename OptimizerType>
78  const size_t rank,
79  arma::mat& u,
80  arma::mat& v,
81  arma::vec& p,
82  arma::vec& q,
83  arma::mat& y)
84 {
85  arma::mat implicitData = data.submat(0, 0, 1, data.n_cols - 1);
86  Apply(data, implicitData, rank, u, v, p, q, y);
87 }
88 
89 template<typename OptimizerType>
90 void SVDPlusPlus<OptimizerType>::CleanData(const arma::mat& implicitData,
91  arma::sp_mat& cleanedData,
92  const arma::mat& data)
93 {
94  // Generate list of locations for batch insert constructor for sparse
95  // matrices.
96  arma::umat locations(2, implicitData.n_cols);
97  arma::vec values(implicitData.n_cols);
98  for (size_t i = 0; i < implicitData.n_cols; ++i)
99  {
100  // We have to transpose it because items are rows, and users are columns.
101  locations(1, i) = ((arma::uword) implicitData(0, i));
102  locations(0, i) = ((arma::uword) implicitData(1, i));
103  values(i) = 1;
104  }
105 
106  // Find maximum user and item IDs.
107  const size_t maxItemID = (size_t) max(data.row(1)) + 1;
108  const size_t maxUserID = (size_t) max(data.row(0)) + 1;
109 
110  // Fill sparse matrix.
111  cleanedData = arma::sp_mat(locations, values, maxItemID, maxUserID);
112 }
113 
114 } // namespace svd
115 } // namespace mlpack
116 
117 #endif
This class contains methods which are used to calculate the cost of SVD++'s objective function...
Definition: svdplusplus_function.hpp:31
SVDPlusPlus(const size_t iterations=10, const double alpha=0.001, const double lambda=0.1)
Constructor of SVDPlusPlus.
Definition: svdplusplus_impl.hpp:21
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
const arma::mat & GetInitialPoint() const
Return the initial point for the optimization.
Definition: svdplusplus_function.hpp:107
void Apply(const arma::mat &data, const arma::mat &implicitData, const size_t rank, arma::mat &u, arma::mat &v, arma::vec &p, arma::vec &q, arma::mat &y)
Trains the model and obtains user/item matrices, user/item bias, and item implicit matrix...
Definition: svdplusplus_impl.hpp:32
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
static void CleanData(const arma::mat &implicitData, arma::sp_mat &cleanedData, const arma::mat &data)
Converts the User, Item matrix of implicit data to Item-User Table.
Definition: svdplusplus_impl.hpp:90