13 #ifndef MLPACK_METHODS_ANN_LAYER_LINEAR_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_LINEAR_IMPL_HPP 22 template<
typename InputDataType,
typename OutputDataType,
23 typename RegularizerType>
31 template<
typename InputDataType,
typename OutputDataType,
32 typename RegularizerType>
36 RegularizerType regularizer) :
39 regularizer(regularizer)
44 template<
typename InputDataType,
typename OutputDataType,
45 typename RegularizerType>
49 outSize(layer.outSize),
50 weights(layer.weights),
51 regularizer(layer.regularizer)
56 template<
typename InputDataType,
typename OutputDataType,
57 typename RegularizerType>
62 weights(
std::move(layer.weights)),
63 regularizer(
std::move(layer.regularizer))
68 template<
typename InputDataType,
typename OutputDataType,
69 typename RegularizerType>
76 inSize = layer.inSize;
77 outSize = layer.outSize;
78 weights = layer.weights;
79 regularizer = layer.regularizer;
84 template<
typename InputDataType,
typename OutputDataType,
85 typename RegularizerType>
92 inSize = layer.inSize;
93 outSize = layer.outSize;
94 weights = std::move(layer.weights);
95 regularizer = std::move(layer.regularizer);
100 template<
typename InputDataType,
typename OutputDataType,
101 typename RegularizerType>
104 weight = arma::mat(weights.memptr(), outSize, inSize,
false,
false);
105 bias = arma::mat(weights.memptr() + weight.n_elem,
106 outSize, 1,
false,
false);
109 template<
typename InputDataType,
typename OutputDataType,
110 typename RegularizerType>
111 template<
typename eT>
113 const arma::Mat<eT>& input, arma::Mat<eT>& output)
115 output = weight * input;
116 output.each_col() += bias;
119 template<
typename InputDataType,
typename OutputDataType,
120 typename RegularizerType>
121 template<
typename eT>
123 const arma::Mat<eT>& ,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
128 template<
typename InputDataType,
typename OutputDataType,
129 typename RegularizerType>
130 template<
typename eT>
132 const arma::Mat<eT>& input,
133 const arma::Mat<eT>& error,
134 arma::Mat<eT>& gradient)
136 gradient.submat(0, 0, weight.n_elem - 1, 0) = arma::vectorise(
138 gradient.submat(weight.n_elem, 0, gradient.n_elem - 1, 0) =
140 regularizer.Evaluate(weights, gradient);
143 template<
typename InputDataType,
typename OutputDataType,
144 typename RegularizerType>
145 template<
typename Archive>
147 Archive& ar,
const uint32_t )
149 ar(CEREAL_NVP(inSize));
150 ar(CEREAL_NVP(outSize));
151 ar(CEREAL_NVP(weights));
OutputDataType const & Gradient() const
Get the gradient.
Definition: linear.hpp:135
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Implementation of the Linear layer class.
Definition: layer_types.hpp:93
Definition: pointer_wrapper.hpp:23
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: linear_impl.hpp:146
size_t WeightSize() const
Get the size of the weights.
Definition: linear.hpp:150
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
Definition: linear_impl.hpp:122
Linear()
Create the Linear object.
Definition: linear_impl.hpp:24
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
Definition: linear_impl.hpp:112
Linear & operator=(const Linear &layer)
Copy assignment operator.
Definition: linear_impl.hpp:72