mlpack
linear_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_LINEAR_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_LINEAR_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "linear.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename InputDataType, typename OutputDataType,
23  typename RegularizerType>
25  inSize(0),
26  outSize(0)
27 {
28  // Nothing to do here.
29 }
30 
31 template<typename InputDataType, typename OutputDataType,
32  typename RegularizerType>
34  const size_t inSize,
35  const size_t outSize,
36  RegularizerType regularizer) :
37  inSize(inSize),
38  outSize(outSize),
39  regularizer(regularizer)
40 {
41  weights.set_size(WeightSize(), 1);
42 }
43 
44 template<typename InputDataType, typename OutputDataType,
45  typename RegularizerType>
47  const Linear& layer) :
48  inSize(layer.inSize),
49  outSize(layer.outSize),
50  weights(layer.weights),
51  regularizer(layer.regularizer)
52 {
53  // Nothing to do here.
54 }
55 
56 template<typename InputDataType, typename OutputDataType,
57  typename RegularizerType>
59  Linear&& layer) :
60  inSize(0),
61  outSize(0),
62  weights(std::move(layer.weights)),
63  regularizer(std::move(layer.regularizer))
64 {
65  // Nothing to do here.
66 }
67 
68 template<typename InputDataType, typename OutputDataType,
69  typename RegularizerType>
72 operator=(const Linear& layer)
73 {
74  if (this != &layer)
75  {
76  inSize = layer.inSize;
77  outSize = layer.outSize;
78  weights = layer.weights;
79  regularizer = layer.regularizer;
80  }
81  return *this;
82 }
83 
84 template<typename InputDataType, typename OutputDataType,
85  typename RegularizerType>
88 operator=(Linear&& layer)
89 {
90  if (this != &layer)
91  {
92  inSize = layer.inSize;
93  outSize = layer.outSize;
94  weights = std::move(layer.weights);
95  regularizer = std::move(layer.regularizer);
96  }
97  return *this;
98 }
99 
100 template<typename InputDataType, typename OutputDataType,
101  typename RegularizerType>
103 {
104  weight = arma::mat(weights.memptr(), outSize, inSize, false, false);
105  bias = arma::mat(weights.memptr() + weight.n_elem,
106  outSize, 1, false, false);
107 }
108 
109 template<typename InputDataType, typename OutputDataType,
110  typename RegularizerType>
111 template<typename eT>
113  const arma::Mat<eT>& input, arma::Mat<eT>& output)
114 {
115  output = weight * input;
116  output.each_col() += bias;
117 }
118 
119 template<typename InputDataType, typename OutputDataType,
120  typename RegularizerType>
121 template<typename eT>
123  const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
124 {
125  g = weight.t() * gy;
126 }
127 
128 template<typename InputDataType, typename OutputDataType,
129  typename RegularizerType>
130 template<typename eT>
132  const arma::Mat<eT>& input,
133  const arma::Mat<eT>& error,
134  arma::Mat<eT>& gradient)
135 {
136  gradient.submat(0, 0, weight.n_elem - 1, 0) = arma::vectorise(
137  error * input.t());
138  gradient.submat(weight.n_elem, 0, gradient.n_elem - 1, 0) =
139  arma::sum(error, 1);
140  regularizer.Evaluate(weights, gradient);
141 }
142 
143 template<typename InputDataType, typename OutputDataType,
144  typename RegularizerType>
145 template<typename Archive>
147  Archive& ar, const uint32_t /* version */)
148 {
149  ar(CEREAL_NVP(inSize));
150  ar(CEREAL_NVP(outSize));
151  ar(CEREAL_NVP(weights));
152 }
153 
154 } // namespace ann
155 } // namespace mlpack
156 
157 #endif
OutputDataType const & Gradient() const
Get the gradient.
Definition: linear.hpp:135
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Implementation of the Linear layer class.
Definition: layer_types.hpp:93
Definition: pointer_wrapper.hpp:23
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: linear_impl.hpp:146
size_t WeightSize() const
Get the size of the weights.
Definition: linear.hpp:150
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: linear_impl.hpp:122
Linear()
Create the Linear object.
Definition: linear_impl.hpp:24
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: linear_impl.hpp:112
Linear & operator=(const Linear &layer)
Copy assignment operator.
Definition: linear_impl.hpp:72