mlpack
linear_no_bias_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_LINEAR_NO_BIAS_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_LINEAR_NO_BIAS_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "linear_no_bias.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
// Default constructor: sets inSize and outSize to 0 and leaves the
// regularizer default-constructed. No parameter storage is allocated here.
// NOTE(review): the declarator line (source line 24) is missing from this
// extracted listing.
22 template<typename InputDataType, typename OutputDataType,
23  typename RegularizerType>
25  inSize(0),
26  outSize(0)
27 {
28  // Nothing to do here; `weights` is not sized by this constructor.
29 }
30 
// Sized constructor: stores the input/output dimensions and a copy of the
// given regularizer (it is taken by value).
// Allocates `weights` as a single column with WeightSize() rows. Armadillo's
// set_size() does not initialize element values, so the parameters must be
// filled elsewhere before use. WeightSize() is presumably outSize * inSize
// (serialize() resizes to that); confirm against linear_no_bias.hpp.
// NOTE(review): the declarator line (source line 33) is missing from this
// extracted listing.
31 template<typename InputDataType, typename OutputDataType,
32  typename RegularizerType>
34  const size_t inSize,
35  const size_t outSize,
36  RegularizerType regularizer) :
37  inSize(inSize),
38  outSize(outSize),
39  regularizer(regularizer)
40 {
41  weights.set_size(WeightSize(), 1);
42 }
43 
// Presumably Reset() (the declarator, source line 46, is missing from this
// listing).
// Rebinds `weight` as an outSize x inSize matrix that uses the memory of the
// flat `weights` column directly. The call passes copy_aux_mem = false and
// strict = false, so no copy is made.
// Because of this aliasing, `weight` must be rebound (by calling this again)
// whenever `weights` is reallocated.
// NOTE(review): this hard-codes arma::mat (double elements). It assumes
// `weights` stores doubles; confirm for other OutputDataType choices.
44 template<typename InputDataType, typename OutputDataType,
45  typename RegularizerType>
47 {
48  weight = arma::mat(weights.memptr(), outSize, inSize, false, false);
49 }
50 
// Forward pass: output = weight * input, with no bias term added.
// `weight` is outSize x inSize, so `input` must have inSize rows. `output`
// then gets outSize rows and the same number of columns as `input`.
// NOTE(review): the declarator line (source line 54) is missing from this
// extracted listing.
51 template<typename InputDataType, typename OutputDataType,
52  typename RegularizerType>
53 template<typename eT>
55  const arma::Mat<eT>& input, arma::Mat<eT>& output)
56 {
57  output = weight * input;
58 }
59 
// Backward pass: g = weight^T * gy. This maps the error `gy` (outSize rows)
// back to the input dimension (inSize rows). The forward input is not needed
// by this computation and is therefore ignored.
// NOTE(review): the declarator line (source line 63) is missing from this
// extracted listing.
60 template<typename InputDataType, typename OutputDataType,
61  typename RegularizerType>
62 template<typename eT>
64  const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
65 {
66  g = weight.t() * gy;
67 }
68 
// Presumably Gradient(input, error, gradient) (the declarator, source line 72,
// is missing from this listing).
// Computes the weight gradient error * input^T, an outSize x inSize matrix.
// It flattens the result with vectorise(), which uses Armadillo's column-major
// order and so matches the layout `weight` uses over the `weights` column.
// It then writes the result into rows [0, weight.n_elem) of column 0 of
// `gradient`.
// submat() assignment does not resize, so `gradient` must already have at
// least weight.n_elem rows; presumably the caller preallocates it.
// Finally it calls regularizer.Evaluate(weights, gradient). Presumably this
// adds the regularization term in place; verify against RegularizerType.
69 template<typename InputDataType, typename OutputDataType,
70  typename RegularizerType>
71 template<typename eT>
73  const arma::Mat<eT>& input,
74  const arma::Mat<eT>& error,
75  arma::Mat<eT>& gradient)
76 {
77  gradient.submat(0, 0, weight.n_elem - 1, 0) = arma::vectorise(
78  error * input.t());
79  regularizer.Evaluate(weights, gradient);
80 }
81 
// Serializes the layer dimensions (inSize, outSize) with cereal.
// Only the sizes are archived here; the parameter values in `weights` are not
// written by this function.
// When loading, `weights` is resized to outSize * inSize x 1 so that later
// code setting the weights sees the correct size.
// NOTE(review): `weight` is not rebound here. Presumably the owning network
// calls Reset() after loading; confirm.
// NOTE(review): the declarator line (source line 85) is missing from this
// extracted listing.
82 template<typename InputDataType, typename OutputDataType,
83  typename RegularizerType>
84 template<typename Archive>
86  Archive& ar, const uint32_t /* version */)
87 {
88  ar(CEREAL_NVP(inSize));
89  ar(CEREAL_NVP(outSize));
90 
91  // This is inefficient, but necessary so that WeightSetVisitor sets the right
92  // size.
93  if (cereal::is_loading<Archive>())
94  weights.set_size(outSize * inSize, 1);
95 }
96 
97 } // namespace ann
98 } // namespace mlpack
99 
100 #endif
LinearNoBias()
Create the LinearNoBias object.
Definition: linear_no_bias_impl.hpp:24
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: linear_no_bias_impl.hpp:85
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t WeightSize() const
Get the size of the weights.
Definition: linear_no_bias.hpp:127
OutputDataType const & Gradient() const
Get the gradient.
Definition: linear_no_bias.hpp:122
Implementation of the LinearNoBias class.
Definition: layer_types.hpp:103
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: linear_no_bias_impl.hpp:63
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: linear_no_bias_impl.hpp:54