mlpack
mlpack::regression::LinearRegression Class Reference

A simple linear regression algorithm using ordinary least squares.

#include <linear_regression.hpp>

Public Member Functions

 LinearRegression (const arma::mat &predictors, const arma::rowvec &responses, const double lambda=0, const bool intercept=true)
 Creates the model.

 LinearRegression (const arma::mat &predictors, const arma::rowvec &responses, const arma::rowvec &weights, const double lambda=0, const bool intercept=true)
 Creates the model with weighted learning.

 LinearRegression ()
 Empty constructor.

double Train (const arma::mat &predictors, const arma::rowvec &responses, const bool intercept=true)
 Train the LinearRegression model on the given data.

double Train (const arma::mat &predictors, const arma::rowvec &responses, const arma::rowvec &weights, const bool intercept=true)
 Train the LinearRegression model on the given data and weights.

void Predict (const arma::mat &points, arma::rowvec &predictions) const
 Calculate y_i for each data point in points.

double ComputeError (const arma::mat &points, const arma::rowvec &responses) const
 Calculate the L2 squared error on the given predictors and responses using this linear regression model.
 
const arma::vec & Parameters () const
 Return the parameters (the b vector).
 
arma::vec & Parameters ()
 Modify the parameters (the b vector).
 
double Lambda () const
 Return the Tikhonov regularization parameter for ridge regression.
 
double & Lambda ()
 Modify the Tikhonov regularization parameter for ridge regression.
 
bool Intercept () const
 Return whether or not an intercept term is used in the model.
 
template<typename Archive >
void serialize (Archive &ar, const uint32_t)
 Serialize the model.
 

Detailed Description

A simple linear regression algorithm using ordinary least squares.

Optionally, this class performs ridge regression (Tikhonov regularization) when the lambda parameter is set to a value greater than zero.
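To make the two objectives concrete, here is a minimal standalone sketch (standard C++ only, not mlpack's implementation) that solves the normal equations (Z Z^T + lambda I) b = Z y with Gaussian elimination. Here Z is X with a row of ones prepended when an intercept is used. The names Matrix, Solve, and FitRidge are hypothetical. Two further assumptions are specific to this sketch: the data uses a one-column-per-point layout, and the intercept entry is left out of the lambda penalty. mlpack's actual solver and penalty details may differ.

```cpp
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// X[j][i] holds feature j of point i, i.e. one column per data point
// (assumed to mirror the predictors layout).
using Matrix = std::vector<std::vector<double>>;

// Solves the square system A b = c by Gaussian elimination with partial
// pivoting. A and c are taken by value and modified locally.
std::vector<double> Solve(Matrix A, std::vector<double> c) {
  const std::size_t m = c.size();
  for (std::size_t k = 0; k < m; ++k) {
    std::size_t p = k;  // row with the largest pivot candidate
    for (std::size_t r = k + 1; r < m; ++r)
      if (std::fabs(A[r][k]) > std::fabs(A[p][k])) p = r;
    std::swap(A[k], A[p]);
    std::swap(c[k], c[p]);
    for (std::size_t r = k + 1; r < m; ++r) {
      const double f = A[r][k] / A[k][k];
      for (std::size_t s = k; s < m; ++s) A[r][s] -= f * A[k][s];
      c[r] -= f * c[k];
    }
  }
  std::vector<double> b(m);
  for (std::size_t k = m; k-- > 0;) {  // back substitution
    double s = c[k];
    for (std::size_t j = k + 1; j < m; ++j) s -= A[k][j] * b[j];
    b[k] = s / A[k][k];
  }
  return b;
}

// Hypothetical helper: returns b minimizing ||y - Z^T b||^2 + lambda * ||w||^2,
// where w is b without its intercept entry. With an intercept, b[0] is the
// intercept and b[j + 1] is the coefficient of feature j.
std::vector<double> FitRidge(const Matrix& X, const std::vector<double>& y,
                             double lambda, bool intercept) {
  const std::size_t off = intercept ? 1 : 0;
  const std::size_t m = X.size() + off;
  // Entry (r, i) of the design matrix Z.
  auto z = [&](std::size_t r, std::size_t i) {
    return (intercept && r == 0) ? 1.0 : X[r - off][i];
  };
  Matrix A(m, std::vector<double>(m, 0.0));  // Z Z^T, plus lambda below
  std::vector<double> c(m, 0.0);             // Z y
  for (std::size_t r = 0; r < m; ++r) {
    for (std::size_t i = 0; i < y.size(); ++i) {
      for (std::size_t s = 0; s < m; ++s) A[r][s] += z(r, i) * z(s, i);
      c[r] += z(r, i) * y[i];
    }
    if (r >= off) A[r][r] += lambda;  // the intercept is not penalized here
  }
  return Solve(A, c);
}
```

On the points x = 1, 2, 3, 4 with y = 1 + 2x, FitRidge with lambda = 0 and an intercept returns b = (1, 2), while lambda = 1 shrinks the slope to 10/6.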

Constructor & Destructor Documentation

◆ LinearRegression() [1/3]

LinearRegression::LinearRegression ( const arma::mat &  predictors,
const arma::rowvec &  responses,
const double  lambda = 0,
const bool  intercept = true 
)

Creates the model.

Parameters
predictors: X, the matrix of data points.
responses: y, the measured response for each point in X.
lambda: Regularization constant for ridge regression.
intercept: Whether or not to include an intercept term.

◆ LinearRegression() [2/3]

LinearRegression::LinearRegression ( const arma::mat &  predictors,
const arma::rowvec &  responses,
const arma::rowvec &  weights,
const double  lambda = 0,
const bool  intercept = true 
)

Creates the model with weighted learning.

Parameters
predictors: X, the matrix of data points.
responses: y, the measured response for each point in X.
weights: Observation weights (for boosting).
lambda: Regularization constant for ridge regression.
intercept: Whether or not to include an intercept term.

◆ LinearRegression() [3/3]

mlpack::regression::LinearRegression::LinearRegression ( )
inline

Empty constructor.

This creates an untrained model, so call Train() (or set the model parameters through Parameters()) before calling Predict().

Member Function Documentation

◆ ComputeError()

double LinearRegression::ComputeError ( const arma::mat &  points,
const arma::rowvec &  responses 
) const

Calculate the L2 squared error on the given predictors and responses using this linear regression model.

This calculation returns

\[ \frac{1}{n} \| y - X B \|_2^2 \]

where \( y \) is the responses vector, \( X \) is the matrix of predictors, \( B \) is the parameter vector of the trained linear regression model, and \( n \) is the number of points.

The closer this value is to 0, the better the model fits the given data.

Parameters
points: Matrix of predictors (X).
responses: Transposed vector of responses (y^T).

◆ Predict()

void LinearRegression::Predict ( const arma::mat &  points,
arma::rowvec &  predictions 
) const

Calculate the predicted response y_i for each data point in points.

Parameters
points: The data points to compute predictions for.
predictions: y, which will contain the predicted values on completion.
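A prediction is the dot product of each point with the coefficients, plus the intercept if one is used. The minimal standalone sketch below (standard C++ only, not mlpack's implementation) shows this. It makes two assumptions: points are stored one column per point, and with an intercept the intercept is the first parameter, followed by one coefficient per feature. The function name PredictSketch is hypothetical. Like Predict(), it writes into an output argument.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for Predict(): points[j][i] is feature j of point i
// (one column per point, assumed). If intercept is true, params[0] is taken
// to be the intercept and params[j + 1] the coefficient of feature j;
// otherwise params[j] is the coefficient of feature j.
void PredictSketch(const std::vector<std::vector<double>>& points,
                   const std::vector<double>& params, bool intercept,
                   std::vector<double>& predictions) {
  const std::size_t n = points.empty() ? 0 : points[0].size();
  const std::size_t off = intercept ? 1 : 0;
  // Start every prediction at the intercept (or 0 without one).
  predictions.assign(n, intercept ? params[0] : 0.0);
  for (std::size_t j = 0; j < points.size(); ++j)
    for (std::size_t i = 0; i < n; ++i)
      predictions[i] += params[j + off] * points[j][i];
}
```

With two features, points (1, 0) and (2, 1), and parameters (1, 2, 3) including an intercept, the predictions are 1 + 2 * 1 + 3 * 0 = 3 and 1 + 2 * 2 + 3 * 1 = 8.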

◆ Train() [1/2]

double LinearRegression::Train ( const arma::mat &  predictors,
const arma::rowvec &  responses,
const bool  intercept = true 
)

Train the LinearRegression model on the given data.

Careful! This will completely ignore and overwrite the existing model. This particular implementation does not have an incremental training algorithm. To change the regularization parameter lambda, modify it through Lambda() before calling Train(), or pass a different value to the constructor.

Parameters
predictors: X, the matrix of data points to train the model on.
responses: y, the responses to the data points.
intercept: Whether or not to fit an intercept term.
Returns
The least squares error after training.

◆ Train() [2/2]

double LinearRegression::Train ( const arma::mat &  predictors,
const arma::rowvec &  responses,
const arma::rowvec &  weights,
const bool  intercept = true 
)

Train the LinearRegression model on the given data and weights.

Careful! This will completely ignore and overwrite the existing model. This particular implementation does not have an incremental training algorithm. To change the regularization parameter lambda, modify it through Lambda() before calling Train(), or pass a different value to the constructor.

Parameters
predictors: X, the matrix of data points to train the model on.
responses: y, the responses to the data points.
weights: Observation weights (for boosting).
intercept: Whether or not to fit an intercept term.
Returns
The least squares error after training.
