concat_impl.hpp

/**
 * @file concat_impl.hpp
 *
 * Implementation of the Concat class, which concatenates the outputs of a
 * set of layers along a given axis.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If you did not, see
 * http://www.mlpack.org/license.txt.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_CONCAT_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_CONCAT_IMPL_HPP

// In case it hasn't yet been included.
#include "concat.hpp"

#include "../visitor/forward_visitor.hpp"
#include "../visitor/backward_visitor.hpp"
#include "../visitor/gradient_visitor.hpp"

namespace mlpack {
namespace ann {

template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
Concat<InputDataType, OutputDataType, CustomLayers...>::Concat(
    const bool model, const bool run) :
    axis(0),
    useAxis(false),
    model(model),
    run(run),
    channels(1)
{
  weights.set_size(0, 0);
}
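// With this constructor no concatenation axis is given: channels stays 1, so
// Forward() will simply stack the sub-layer outputs row-wise (concatenation
// along the first dimension).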

template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
Concat<InputDataType, OutputDataType, CustomLayers...>::Concat(
    arma::Row<size_t>& inputSize,
    const size_t axis,
    const bool model,
    const bool run) :
    inputSize(inputSize),
    axis(axis),
    useAxis(true),
    model(model),
    run(run)
{
  weights.set_size(0, 0);

  // Parameters to help calculate the number of channels.
  size_t oldColSize = 1, newColSize = 1;
  // An axis has been specified, so useAxis is true.
  if (useAxis)
  {
    // The input dimensions are required to compute the number of channels;
    // if they were not given, throw an error.
    if (inputSize.n_elem > 0)
    {
      // Calculate newColSize from the dimensions that follow the axis of
      // concatenation.  Forward() then concatenates along columns and
      // reshapes back to the original format, i.e. (input, batch_size).
      size_t i = std::min(axis + 1, (size_t) inputSize.n_elem);
      for (; i < inputSize.n_elem; ++i)
      {
        newColSize *= inputSize[i];
      }
    }
    else
    {
      throw std::logic_error("Input dimensions not specified.");
    }
  }
  else
  {
    channels = 1;
  }
  if (newColSize <= 0)
  {
    throw std::logic_error("Col size is zero.");
  }
  channels = newColSize / oldColSize;
  inputSize.clear();
}
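// Worked example (illustrative): for inputSize = {4, 5, 2} and axis = 1, the
// loop above starts at i = 2, so newColSize = inputSize[2] = 2 and
// channels = 2 / 1 = 2.  Each batch column is then treated as two stacked
// channels of 4 * 5 = 20 elements, giving concatenation along axis 1.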

template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
Concat<InputDataType, OutputDataType, CustomLayers...>::~Concat()
{
  if (!model)
  {
    // Clear memory.
    std::for_each(network.begin(), network.end(),
        boost::apply_visitor(deleteVisitor));
  }
}

template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename eT>
void Concat<InputDataType, OutputDataType, CustomLayers...>::Forward(
    const arma::Mat<eT>& input, arma::Mat<eT>& output)
{
  if (run)
  {
    for (size_t i = 0; i < network.size(); ++i)
    {
      boost::apply_visitor(ForwardVisitor(input,
          boost::apply_visitor(outputParameterVisitor, network[i])),
          network[i]);
    }
  }

  output = boost::apply_visitor(outputParameterVisitor, network.front());

  // Reshape output to incorporate the channels.
  output.reshape(output.n_rows / channels, output.n_cols * channels);

  for (size_t i = 1; i < network.size(); ++i)
  {
    arma::Mat<eT> out = boost::apply_visitor(outputParameterVisitor,
        network[i]);

    out.reshape(out.n_rows / channels, out.n_cols * channels);

    // Vertically concatenate the output from each layer.
    output = arma::join_cols(output, out);
  }
  // Reshape output to its original shape.
  output.reshape(output.n_rows * channels, output.n_cols / channels);
}
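// Worked example of the reshape trick (illustrative): with channels = 3 and
// two sub-layers that each produce a 6 x 1 output (a flattened 2 x 3 slice
// per batch column), each output is reshaped to 2 x 3, join_cols() stacks
// them into a 4 x 3 matrix, and the final reshape restores the flattened
// 12 x 1 layout; the concatenation thus happens along the chosen axis rather
// than plain row-wise.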

template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename eT>
void Concat<InputDataType, OutputDataType, CustomLayers...>::Backward(
    const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
{
  size_t rowCount = 0;
  if (run)
  {
    arma::Mat<eT> delta;
    arma::Mat<eT> gyTmp(((arma::Mat<eT>&) gy).memptr(), gy.n_rows / channels,
        gy.n_cols * channels, false, false);
    for (size_t i = 0; i < network.size(); ++i)
    {
      // Use rows from the error corresponding to the output from each layer.
      size_t rows = boost::apply_visitor(
          outputParameterVisitor, network[i]).n_rows;

      // Extract from gy the part corresponding to the i-th network.
      delta = gyTmp.rows(rowCount / channels, (rowCount + rows) / channels - 1);
      delta.reshape(delta.n_rows * channels, delta.n_cols / channels);

      boost::apply_visitor(BackwardVisitor(
          boost::apply_visitor(outputParameterVisitor,
          network[i]), delta,
          boost::apply_visitor(deltaVisitor, network[i])), network[i]);
      rowCount += rows;
    }

    g = boost::apply_visitor(deltaVisitor, network[0]);
    for (size_t i = 1; i < network.size(); ++i)
    {
      g += boost::apply_visitor(deltaVisitor, network[i]);
    }
  }
  else
  {
    g = gy;
  }
}
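// Note: gyTmp above aliases gy's memory without copying (Armadillo's advanced
// constructor with copy_aux_mem = false), reshaped so that each sub-layer's
// share of the error occupies a contiguous block of rows.  Since every
// sub-layer received the same input in Forward(), the input gradient g is the
// sum of the per-layer deltas.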

template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename eT>
void Concat<InputDataType, OutputDataType, CustomLayers...>::Backward(
    const arma::Mat<eT>& /* input */,
    const arma::Mat<eT>& gy,
    arma::Mat<eT>& g,
    const size_t index)
{
  size_t rowCount = 0, rows = 0;

  for (size_t i = 0; i < index; ++i)
  {
    rowCount += boost::apply_visitor(
        outputParameterVisitor, network[i]).n_rows;
  }
  rows = boost::apply_visitor(outputParameterVisitor, network[index]).n_rows;

  // Reshape gy to extract the error for the index-th layer.
  arma::Mat<eT> gyTmp(((arma::Mat<eT>&) gy).memptr(), gy.n_rows / channels,
      gy.n_cols * channels, false, false);

  arma::Mat<eT> delta = gyTmp.rows(rowCount / channels, (rowCount + rows) /
      channels - 1);
  delta.reshape(delta.n_rows * channels, delta.n_cols / channels);

  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
      outputParameterVisitor, network[index]), delta,
      boost::apply_visitor(deltaVisitor, network[index])), network[index]);

  g = boost::apply_visitor(deltaVisitor, network[index]);
}

template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename eT>
void Concat<InputDataType, OutputDataType, CustomLayers...>::Gradient(
    const arma::Mat<eT>& input,
    const arma::Mat<eT>& error,
    arma::Mat<eT>& /* gradient */)
{
  if (run)
  {
    size_t rowCount = 0;
    // Reshape error so that each sub-layer's share occupies contiguous rows.
    arma::Mat<eT> errorTmp(((arma::Mat<eT>&) error).memptr(),
        error.n_rows / channels, error.n_cols * channels, false, false);
    for (size_t i = 0; i < network.size(); ++i)
    {
      size_t rows = boost::apply_visitor(
          outputParameterVisitor, network[i]).n_rows;

      // Extract from error the part corresponding to the i-th network.
      arma::Mat<eT> err = errorTmp.rows(rowCount / channels, (rowCount + rows) /
          channels - 1);
      err.reshape(err.n_rows * channels, err.n_cols / channels);

      boost::apply_visitor(GradientVisitor(input, err), network[i]);
      rowCount += rows;
    }
  }
}
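// This mirrors Backward(): the same row-slicing of the aliased, reshaped
// error is applied, but each slice is handed to GradientVisitor so that every
// sub-layer updates its own weight gradient from the shared input.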

template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename eT>
void Concat<InputDataType, OutputDataType, CustomLayers...>::Gradient(
    const arma::Mat<eT>& input,
    const arma::Mat<eT>& error,
    arma::Mat<eT>& /* gradient */,
    const size_t index)
{
  size_t rowCount = 0;
  for (size_t i = 0; i < index; ++i)
  {
    rowCount += boost::apply_visitor(outputParameterVisitor,
        network[i]).n_rows;
  }
  size_t rows = boost::apply_visitor(
      outputParameterVisitor, network[index]).n_rows;

  arma::Mat<eT> errorTmp(((arma::Mat<eT>&) error).memptr(),
      error.n_rows / channels, error.n_cols * channels, false, false);
  arma::Mat<eT> err = errorTmp.rows(rowCount / channels, (rowCount + rows) /
      channels - 1);
  err.reshape(err.n_rows * channels, err.n_cols / channels);

  boost::apply_visitor(GradientVisitor(input, err), network[index]);
}

template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename Archive>
void Concat<InputDataType, OutputDataType, CustomLayers...>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(model));
  ar(CEREAL_NVP(run));

  // Do we have to load or save a model?
  if (model)
  {
    // Clear memory first, if needed.
    if (cereal::is_loading<Archive>())
    {
      std::for_each(network.begin(), network.end(),
          boost::apply_visitor(deleteVisitor));
    }
    ar(CEREAL_VECTOR_VARIANT_POINTER(network));
  }
}

} // namespace ann
} // namespace mlpack

#endif
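A minimal usage sketch (not part of the original file; assumes the mlpack 3.x
boost::variant-based ann API, where FFN, Linear, and Concat live in
mlpack::ann and Add() takes ownership of layer pointers):

    #include <mlpack/methods/ann/ffn.hpp>
    #include <mlpack/methods/ann/layer/layer.hpp>

    using namespace mlpack::ann;

    int main()
    {
      // Two branches that see the same 10-dimensional input; their outputs
      // (5 and 3 rows) are stacked row-wise, so Concat emits 8 rows.
      Concat<>* concat = new Concat<>(true /* model */, true /* run */);
      concat->Add<Linear<> >(10, 5);
      concat->Add<Linear<> >(10, 3);

      FFN<> model;
      model.Add<Linear<> >(20, 10);
      model.Add(concat);
      model.Add<Linear<> >(8, 1);
    }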