mlpack
add_merge_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_ADD_MERGE_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_ADD_MERGE_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "add_merge.hpp"
18 
19 #include "../visitor/forward_visitor.hpp"
20 #include "../visitor/backward_visitor.hpp"
21 #include "../visitor/gradient_visitor.hpp"
22 
23 namespace mlpack {
24 namespace ann {
25 
26 template<typename InputDataType, typename OutputDataType,
27  typename... CustomLayers>
29  const bool model, const bool run) :
30  model(model), run(run), ownsLayers(!model)
31 {
32  // Nothing to do here.
33 }
34 
35 template<typename InputDataType, typename OutputDataType,
36  typename... CustomLayers>
38  const bool model, const bool run, const bool ownsLayers) :
39  model(model), run(run), ownsLayers(ownsLayers)
40 {
41  // Nothing to do here.
42 }
43 
44 template<typename InputDataType, typename OutputDataType,
45  typename... CustomLayers>
47 {
48  if (!model && ownsLayers)
49  {
50  std::for_each(network.begin(), network.end(),
51  boost::apply_visitor(deleteVisitor));
52  }
53 }
54 
55 template <typename InputDataType, typename OutputDataType,
56  typename... CustomLayers>
57 template<typename InputType, typename OutputType>
59  const InputType& input, OutputType& output)
60 {
61  if (run)
62  {
63  for (size_t i = 0; i < network.size(); ++i)
64  {
65  boost::apply_visitor(ForwardVisitor(input,
66  boost::apply_visitor(outputParameterVisitor, network[i])),
67  network[i]);
68  }
69  }
70 
71  output = boost::apply_visitor(outputParameterVisitor, network.front());
72  for (size_t i = 1; i < network.size(); ++i)
73  {
74  output += boost::apply_visitor(outputParameterVisitor, network[i]);
75  }
76 }
77 
78 template<typename InputDataType, typename OutputDataType,
79  typename... CustomLayers>
80 template<typename eT>
82  const arma::Mat<eT>& /* input */,
83  const arma::Mat<eT>& gy,
84  arma::Mat<eT>& g)
85 {
86  if (run)
87  {
88  for (size_t i = 0; i < network.size(); ++i)
89  {
90  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
91  outputParameterVisitor, network[i]), gy,
92  boost::apply_visitor(deltaVisitor, network[i])), network[i]);
93  }
94 
95  g = boost::apply_visitor(deltaVisitor, network[0]);
96  for (size_t i = 1; i < network.size(); ++i)
97  {
98  g += boost::apply_visitor(deltaVisitor, network[i]);
99  }
100  }
101  else
102  g = gy;
103 }
104 
105 template<typename InputDataType, typename OutputDataType,
106  typename... CustomLayers>
107 template<typename eT>
109  const arma::Mat<eT>& /* input */,
110  const arma::Mat<eT>& gy,
111  arma::Mat<eT>& g,
112  const size_t index)
113 {
114  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
115  outputParameterVisitor, network[index]), gy,
116  boost::apply_visitor(deltaVisitor, network[index])), network[index]);
117  g = boost::apply_visitor(deltaVisitor, network[index]);
118 }
119 
120 template<typename InputDataType, typename OutputDataType,
121  typename... CustomLayers>
122 template<typename eT>
124  const arma::Mat<eT>& input,
125  const arma::Mat<eT>& error,
126  arma::Mat<eT>& /* gradient */ )
127 {
128  if (run)
129  {
130  for (size_t i = 0; i < network.size(); ++i)
131  {
132  boost::apply_visitor(GradientVisitor(input, error), network[i]);
133  }
134  }
135 }
136 
137 template<typename InputDataType, typename OutputDataType,
138  typename... CustomLayers>
139 template<typename eT>
141  const arma::Mat<eT>& input,
142  const arma::Mat<eT>& error,
143  arma::Mat<eT>& /* gradient */,
144  const size_t index)
145 {
146  boost::apply_visitor(GradientVisitor(input, error), network[index]);
147 }
148 
149 template<typename InputDataType, typename OutputDataType,
150  typename... CustomLayers>
151 template<typename Archive>
153  Archive& ar, const uint32_t /* version */)
154 {
155  // Be sure to clear other layers before loading.
156  if (cereal::is_loading<Archive>())
157  network.clear();
158 
159  ar(CEREAL_VECTOR_VARIANT_POINTER(network));
160  ar(CEREAL_NVP(model));
161  ar(CEREAL_NVP(run));
162  ar(CEREAL_NVP(ownsLayers));
163 }
164 
165 } // namespace ann
166 } // namespace mlpack
167 
168 #endif
BackwardVisitor executes the Backward() function given the input, error and delta parameter.
Definition: backward_visitor.hpp:28
Implementation of the AddMerge module class.
Definition: add_merge.hpp:42
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: add_merge_impl.hpp:81
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
AddMerge(const bool model=false, const bool run=true)
Create the AddMerge object using the specified parameters.
Definition: add_merge_impl.hpp:28
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: add_merge_impl.hpp:152
SearchModeVisitor executes the Gradient() method of the given module using the input and delta parameter.
Definition: gradient_visitor.hpp:28
~AddMerge()
Destructor to release allocated memory.
Definition: add_merge_impl.hpp:46
void Forward(const InputType &, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: add_merge_impl.hpp:58
#define CEREAL_VECTOR_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointer.
Definition: pointer_vector_variant_wrapper.hpp:92