mlpack
regularized_svd_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_REGULARIZED_SVD_REGULARIZED_SVD_IMPL_HPP
14 #define MLPACK_METHODS_REGULARIZED_SVD_REGULARIZED_SVD_IMPL_HPP
15 
16 namespace mlpack {
17 namespace svd {
18 
19 template<typename OptimizerType>
21  const double alpha,
22  const double lambda) :
23  iterations(iterations),
24  alpha(alpha),
25  lambda(lambda)
26 {
27  // Nothing to do.
28 }
29 
30 template<typename OptimizerType>
32  const size_t rank,
33  arma::mat& u,
34  arma::mat& v)
35 {
36  // batchSize is 1 in our implementation of Regularized SVD.
37  // batchSize other than 1 has not been supported yet.
38  const int batchSize = 1;
39  Log::Warn << "The batch size for optimizing RegularizedSVD is 1."
40  << std::endl;
41 
42  // Make the optimizer object using a RegularizedSVDFunction object.
43  RegularizedSVDFunction<arma::mat> rSVDFunc(data, rank, lambda);
44  ens::StandardSGD optimizer(alpha, batchSize,
45  iterations * data.n_cols);
46 
47  // Get optimized parameters.
48  arma::mat parameters = rSVDFunc.GetInitialPoint();
49  optimizer.Optimize(rSVDFunc, parameters);
50 
51  // Constants for extracting user and item matrices.
52  const size_t numUsers = max(data.row(0)) + 1;
53  const size_t numItems = max(data.row(1)) + 1;
54 
55  // Extract user and item matrices from the optimized parameters.
56  u = parameters.submat(0, numUsers, rank - 1, numUsers + numItems - 1).t();
57  v = parameters.submat(0, 0, rank - 1, numUsers - 1);
58 }
59 
60 } // namespace svd
61 } // namespace mlpack
62 
63 #endif
const arma::mat & GetInitialPoint() const
Return the initial point for the optimization.
Definition: regularized_svd_function.hpp:98
RegularizedSVD(const size_t iterations=10, const double alpha=0.01, const double lambda=0.02)
Constructor for Regularized SVD.
Definition: regularized_svd_impl.hpp:20
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
void Apply(const arma::mat &data, const size_t rank, arma::mat &u, arma::mat &v)
Obtains the user and item matrices using the provided data and rank.
Definition: regularized_svd_impl.hpp:31
The data is stored in a matrix of type MatType, so that this class can be used with both dense and sparse matrices.
Definition: regularized_svd_function.hpp:29