mlpack
linear3d_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_LINEAR3D_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_LINEAR3D_IMPL_HPP
14 
15 // In case it hasn't yet been included.
16 #include "linear3d.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 template<typename InputDataType, typename OutputDataType,
22  typename RegularizerType>
24  inSize(0),
25  outSize(0)
26 {
27  // Nothing to do here.
28 }
29 
30 template<typename InputDataType, typename OutputDataType,
31  typename RegularizerType>
33  const size_t inSize,
34  const size_t outSize,
35  RegularizerType regularizer) :
36  inSize(inSize),
37  outSize(outSize),
38  regularizer(regularizer)
39 {
40  weights.set_size(outSize * inSize + outSize, 1);
41 }
42 
43 template<typename InputDataType, typename OutputDataType,
44  typename RegularizerType>
46  const Linear3D& layer) :
47  inSize(layer.inSize),
48  outSize(layer.outSize),
49  weights(layer.weights),
50  regularizer(layer.regularizer)
51 {
52  // Nothing to do here.
53 }
54 
55 template<typename InputDataType, typename OutputDataType,
56  typename RegularizerType>
58  Linear3D&& layer) :
59  inSize(0),
60  outSize(0),
61  weights(std::move(layer.weights)),
62  regularizer(std::move(layer.regularizer))
63 {
64  // Nothing to do here.
65 }
66 
67 template<typename InputDataType, typename OutputDataType,
68  typename RegularizerType>
71 operator=(const Linear3D& layer)
72 {
73  if (this != &layer)
74  {
75  inSize = layer.inSize;
76  outSize = layer.outSize;
77  weights = layer.weights;
78  regularizer = layer.regularizer;
79  }
80  return *this;
81 }
82 
83 template<typename InputDataType, typename OutputDataType,
84  typename RegularizerType>
88 {
89  if (this != &layer)
90  {
91  inSize = 0;
92  outSize = 0;
93  weights = std::move(layer.weights);
94  regularizer = std::move(layer.regularizer);
95  }
96  return *this;
97 }
98 
99 template<typename InputDataType, typename OutputDataType,
100  typename RegularizerType>
102 {
103  typedef typename arma::Mat<typename OutputDataType::elem_type> MatType;
104 
105  weight = MatType(weights.memptr(), outSize, inSize, false, false);
106  bias = MatType(weights.memptr() + weight.n_elem, outSize, 1, false, false);
107 }
108 
109 template<typename InputDataType, typename OutputDataType,
110  typename RegularizerType>
111 template<typename eT>
113  const arma::Mat<eT>& input, arma::Mat<eT>& output)
114 {
115  typedef typename arma::Mat<eT> MatType;
116  typedef typename arma::Cube<eT> CubeType;
117 
118  if (input.n_rows % inSize != 0)
119  {
120  Log::Fatal << "Number of features in the input must be divisible by inSize."
121  << std::endl;
122  }
123 
124  const size_t nPoints = input.n_rows / inSize;
125  const size_t batchSize = input.n_cols;
126 
127  output.set_size(outSize * nPoints, batchSize);
128 
129  const CubeType inputTemp(const_cast<MatType&>(input).memptr(), inSize,
130  nPoints, batchSize, false, false);
131 
132  for (size_t i = 0; i < batchSize; ++i)
133  {
134  // Shape of weight : (outSize, inSize).
135  // Shape of inputTemp : (inSize, nPoints, batchSize).
136  MatType z = weight * inputTemp.slice(i);
137  z.each_col() += bias;
138  output.col(i) = arma::vectorise(z);
139  }
140 }
141 
142 template<typename InputDataType, typename OutputDataType,
143  typename RegularizerType>
144 template<typename eT>
146  const arma::Mat<eT>& /* input */,
147  const arma::Mat<eT>& gy,
148  arma::Mat<eT>& g)
149 {
150  typedef typename arma::Mat<eT> MatType;
151  typedef typename arma::Cube<eT> CubeType;
152 
153  if (gy.n_rows % outSize != 0)
154  {
155  Log::Fatal << "Number of rows in propagated error must be divisible by \
156  outSize." << std::endl;
157  }
158 
159  const size_t nPoints = gy.n_rows / outSize;
160  const size_t batchSize = gy.n_cols;
161 
162  const CubeType gyTemp(const_cast<MatType&>(gy).memptr(), outSize,
163  nPoints, batchSize, false, false);
164 
165  g.set_size(inSize * nPoints, batchSize);
166 
167  for (size_t i = 0; i < gyTemp.n_slices; ++i)
168  {
169  // Shape of weight : (outSize, inSize).
170  // Shape of gyTemp : (outSize, nPoints, batchSize).
171  g.col(i) = arma::vectorise(weight.t() * gyTemp.slice(i));
172  }
173 }
174 
175 template<typename InputDataType, typename OutputDataType,
176  typename RegularizerType>
177 template<typename eT>
179  const arma::Mat<eT>& input,
180  const arma::Mat<eT>& error,
181  arma::Mat<eT>& gradient)
182 {
183  typedef typename arma::Mat<eT> MatType;
184  typedef typename arma::Cube<eT> CubeType;
185 
186  if (error.n_rows % outSize != 0)
187  Log::Fatal << "Propagated error matrix has invalid dimension!" << std::endl;
188 
189  const size_t nPoints = input.n_rows / inSize;
190  const size_t batchSize = input.n_cols;
191 
192  const CubeType inputTemp(const_cast<MatType&>(input).memptr(), inSize,
193  nPoints, batchSize, false, false);
194  const CubeType errorTemp(const_cast<MatType&>(error).memptr(), outSize,
195  nPoints, batchSize, false, false);
196 
197  CubeType dW(outSize, inSize, batchSize);
198  for (size_t i = 0; i < batchSize; ++i)
199  {
200  // Shape of errorTemp : (outSize, nPoints, batchSize).
201  // Shape of inputTemp : (inSize, nPoints, batchSize).
202  dW.slice(i) = errorTemp.slice(i) * inputTemp.slice(i).t();
203  }
204 
205  gradient.set_size(arma::size(weights));
206 
207  gradient.submat(0, 0, weight.n_elem - 1, 0)
208  = arma::vectorise(arma::sum(dW, 2));
209 
210  gradient.submat(weight.n_elem, 0, weights.n_elem - 1, 0)
211  = arma::vectorise(arma::sum(arma::sum(errorTemp, 2), 1));
212 
213  regularizer.Evaluate(weights, gradient);
214 }
215 
216 template<typename InputDataType, typename OutputDataType,
217  typename RegularizerType>
218 template<typename Archive>
220  Archive& ar, const uint32_t /* version */)
221 {
222  ar(CEREAL_NVP(inSize));
223  ar(CEREAL_NVP(outSize));
224 
225  // This is inefficient, but we have to allocate this memory so that
226  // WeightSetVisitor gets the right size.
227  if (cereal::is_loading<Archive>())
228  weights.set_size(outSize * inSize + outSize, 1);
229 }
230 
231 } // namespace ann
232 } // namespace mlpack
233 
234 #endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
OutputDataType const & Gradient() const
Get the gradient.
Definition: linear3d.hpp:137
Implementation of the Linear3D layer class.
Definition: layer_types.hpp:112
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: linear3d_impl.hpp:145
Linear3D & operator=(const Linear3D &layer)
Copy assignment operator.
Definition: linear3d_impl.hpp:71
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: linear3d_impl.hpp:112
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: linear3d_impl.hpp:219
Linear3D()
Create the Linear3D object.
Definition: linear3d_impl.hpp:23