13 #ifndef MLPACK_METHODS_ANN_LAYER_ADD_MERGE_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_ADD_MERGE_IMPL_HPP 19 #include "../visitor/forward_visitor.hpp" 20 #include "../visitor/backward_visitor.hpp" 21 #include "../visitor/gradient_visitor.hpp" 26 template<
typename InputDataType,
typename OutputDataType,
27 typename... CustomLayers>
29 const bool model,
const bool run) :
30 model(model), run(run), ownsLayers(!model)
35 template<
typename InputDataType,
typename OutputDataType,
36 typename... CustomLayers>
38 const bool model,
const bool run,
const bool ownsLayers) :
39 model(model), run(run), ownsLayers(ownsLayers)
44 template<
typename InputDataType,
typename OutputDataType,
45 typename... CustomLayers>
48 if (!model && ownsLayers)
50 std::for_each(network.begin(), network.end(),
51 boost::apply_visitor(deleteVisitor));
55 template <
typename InputDataType,
typename OutputDataType,
56 typename... CustomLayers>
57 template<
typename InputType,
typename OutputType>
59 const InputType& input, OutputType& output)
63 for (
size_t i = 0; i < network.size(); ++i)
66 boost::apply_visitor(outputParameterVisitor, network[i])),
71 output = boost::apply_visitor(outputParameterVisitor, network.front());
72 for (
size_t i = 1; i < network.size(); ++i)
74 output += boost::apply_visitor(outputParameterVisitor, network[i]);
78 template<
typename InputDataType,
typename OutputDataType,
79 typename... CustomLayers>
82 const arma::Mat<eT>& ,
83 const arma::Mat<eT>& gy,
88 for (
size_t i = 0; i < network.size(); ++i)
91 outputParameterVisitor, network[i]), gy,
92 boost::apply_visitor(deltaVisitor, network[i])), network[i]);
95 g = boost::apply_visitor(deltaVisitor, network[0]);
96 for (
size_t i = 1; i < network.size(); ++i)
98 g += boost::apply_visitor(deltaVisitor, network[i]);
105 template<
typename InputDataType,
typename OutputDataType,
106 typename... CustomLayers>
107 template<
typename eT>
109 const arma::Mat<eT>& ,
110 const arma::Mat<eT>& gy,
115 outputParameterVisitor, network[index]), gy,
116 boost::apply_visitor(deltaVisitor, network[index])), network[index]);
117 g = boost::apply_visitor(deltaVisitor, network[index]);
120 template<
typename InputDataType,
typename OutputDataType,
121 typename... CustomLayers>
122 template<
typename eT>
124 const arma::Mat<eT>& input,
125 const arma::Mat<eT>& error,
130 for (
size_t i = 0; i < network.size(); ++i)
137 template<
typename InputDataType,
typename OutputDataType,
138 typename... CustomLayers>
139 template<
typename eT>
141 const arma::Mat<eT>& input,
142 const arma::Mat<eT>& error,
149 template<
typename InputDataType,
typename OutputDataType,
150 typename... CustomLayers>
151 template<
typename Archive>
153 Archive& ar,
const uint32_t )
156 if (cereal::is_loading<Archive>())
160 ar(CEREAL_NVP(model));
162 ar(CEREAL_NVP(ownsLayers));
BackwardVisitor executes the Backward() function given the input, error and delta parameters.
Definition: backward_visitor.hpp:28
Implementation of the AddMerge module class.
Definition: add_merge.hpp:42
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: add_merge_impl.hpp:81
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
AddMerge(const bool model=false, const bool run=true)
Create the AddMerge object using the specified parameters.
Definition: add_merge_impl.hpp:28
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: add_merge_impl.hpp:152
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
Definition: gradient_visitor.hpp:28
~AddMerge()
Destructor to release allocated memory.
Definition: add_merge_impl.hpp:46
void Forward(const InputType &, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: add_merge_impl.hpp:58
#define CEREAL_VECTOR_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_variant_wrapper.hpp:92