mlpack
multiply_merge_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "multiply_merge.hpp"
18 
19 #include "../visitor/forward_visitor.hpp"
20 #include "../visitor/backward_visitor.hpp"
21 #include "../visitor/gradient_visitor.hpp"
22 
23 namespace mlpack {
24 namespace ann {
25 
26 template<typename InputDataType, typename OutputDataType,
27  typename... CustomLayers>
29  const bool model, const bool run) :
30  model(model), run(run), ownsLayer(!model)
31 {
32  // Nothing to do here.
33 }
34 
35 template<typename InputDataType, typename OutputDataType,
36  typename... CustomLayers>
38  const MultiplyMerge& layer) :
39  model(layer.model),
40  run(layer.run),
41  ownsLayer(layer.ownsLayer),
42  network(layer.network),
43  weights(layer.weights)
44 {
45  // Nothing to do here.
46 }
47 
48 template<typename InputDataType, typename OutputDataType,
49  typename... CustomLayers>
51  MultiplyMerge&& layer) :
52  model(std::move(layer.model)),
53  run(std::move(layer.run)),
54  ownsLayer(std::move(layer.ownsLayer)),
55  network(std::move(layer.network)),
56  weights(std::move(layer.weights))
57 {
58  // Nothing to do here.
59 }
60 
61 template<typename InputDataType, typename OutputDataType,
62  typename... CustomLayers>
63 MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>&
65  const MultiplyMerge& layer)
66 {
67  if (this != &layer)
68  {
69  model = layer.model;
70  run = layer.run;
71  ownsLayer = layer.ownsLayer;
72  network = layer.network;
73  weights = layer.weights;
74  }
75  return *this;
76 }
77 
78 template<typename InputDataType, typename OutputDataType,
79  typename... CustomLayers>
80 MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>&
82  MultiplyMerge&& layer)
83 {
84  if (this != &layer)
85  {
86  model = std::move(layer.model);
87  run = std::move(layer.run);
88  ownsLayer = std::move(layer.ownsLayer);
89  network = std::move(layer.network);
90  weights = std::move(layer.weights);
91  }
92  return *this;
93 }
94 
95 template<typename InputDataType, typename OutputDataType,
96  typename... CustomLayers>
98 {
99  if (ownsLayer)
100  {
101  std::for_each(network.begin(), network.end(),
102  boost::apply_visitor(deleteVisitor));
103  }
104 }
105 
106 template <typename InputDataType, typename OutputDataType,
107  typename... CustomLayers>
108 template<typename InputType, typename OutputType>
110  const InputType& input, OutputType& output)
111 {
112  if (run)
113  {
114  for (size_t i = 0; i < network.size(); ++i)
115  {
116  boost::apply_visitor(ForwardVisitor(input,
117  boost::apply_visitor(outputParameterVisitor, network[i])),
118  network[i]);
119  }
120  }
121 
122  output = boost::apply_visitor(outputParameterVisitor, network.front());
123  for (size_t i = 1; i < network.size(); ++i)
124  {
125  output %= boost::apply_visitor(outputParameterVisitor, network[i]);
126  }
127 }
128 
129 template<typename InputDataType, typename OutputDataType,
130  typename... CustomLayers>
131 template<typename eT>
133  const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
134 {
135  if (run)
136  {
137  for (size_t i = 0; i < network.size(); ++i)
138  {
139  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
140  outputParameterVisitor, network[i]), gy,
141  boost::apply_visitor(deltaVisitor, network[i])), network[i]);
142  }
143 
144  g = boost::apply_visitor(deltaVisitor, network[0]);
145  for (size_t i = 1; i < network.size(); ++i)
146  {
147  g += boost::apply_visitor(deltaVisitor, network[i]);
148  }
149  }
150  else
151  g = gy;
152 }
153 
154 template<typename InputDataType, typename OutputDataType,
155  typename... CustomLayers>
156 template<typename eT>
158  const arma::Mat<eT>& input,
159  const arma::Mat<eT>& error,
160  arma::Mat<eT>& /* gradient */ )
161 {
162  if (run)
163  {
164  for (size_t i = 0; i < network.size(); ++i)
165  {
166  boost::apply_visitor(GradientVisitor(input, error), network[i]);
167  }
168  }
169 }
170 
171 template<typename InputDataType, typename OutputDataType,
172  typename... CustomLayers>
173 template<typename Archive>
175  Archive& ar, const uint32_t /* version */)
176 {
177  // Be sure to clear other layers before loading.
178  if (cereal::is_loading<Archive>())
179  network.clear();
180 
181  ar(CEREAL_VECTOR_VARIANT_POINTER(network));
182  ar(CEREAL_NVP(model));
183  ar(CEREAL_NVP(run));
184  ar(CEREAL_NVP(ownsLayer));
185 }
186 
187 } // namespace ann
188 } // namespace mlpack
189 
190 #endif
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Definition: backward_visitor.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const InputType &, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
Definition: multiply_merge_impl.hpp:109
MultiplyMerge & operator=(const MultiplyMerge &layer)
Copy assignment operator.
Definition: multiply_merge_impl.hpp:64
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
Definition: multiply_merge_impl.hpp:132
Definition: pointer_wrapper.hpp:23
OutputDataType const & Gradient() const
Get the gradient.
Definition: multiply_merge.hpp:130
Implementation of the MultiplyMerge module class.
Definition: layer_types.hpp:209
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
GradientVisitor executes the Gradient() method of the given module using the input and delta parame...
Definition: gradient_visitor.hpp:28
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: multiply_merge_impl.hpp:174
MultiplyMerge(const bool model=false, const bool run=true)
Create the MultiplyMerge object using the specified parameters.
Definition: multiply_merge_impl.hpp:28
~MultiplyMerge()
Destructor to release allocated memory.
Definition: multiply_merge_impl.hpp:97
#define CEREAL_VECTOR_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_variant_wrapper.hpp:92