13 #ifndef MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_IMPL_HPP 19 #include "../visitor/forward_visitor.hpp" 20 #include "../visitor/backward_visitor.hpp" 21 #include "../visitor/gradient_visitor.hpp" 26 template<
typename InputDataType,
typename OutputDataType,
27 typename... CustomLayers>
29 const bool model,
const bool run) :
30 model(model), run(run), ownsLayer(!model)
35 template<
typename InputDataType,
typename OutputDataType,
36 typename... CustomLayers>
41 ownsLayer(layer.ownsLayer),
42 network(layer.network),
43 weights(layer.weights)
48 template<
typename InputDataType,
typename OutputDataType,
49 typename... CustomLayers>
52 model(
std::move(layer.model)),
53 run(
std::move(layer.run)),
54 ownsLayer(
std::move(layer.ownsLayer)),
55 network(
std::move(layer.network)),
56 weights(
std::move(layer.weights))
61 template<
typename InputDataType,
typename OutputDataType,
62 typename... CustomLayers>
63 MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>&
71 ownsLayer = layer.ownsLayer;
72 network = layer.network;
73 weights = layer.weights;
78 template<
typename InputDataType,
typename OutputDataType,
79 typename... CustomLayers>
80 MultiplyMerge<InputDataType, OutputDataType, CustomLayers...>&
86 model = std::move(layer.model);
87 run = std::move(layer.run);
88 ownsLayer = std::move(layer.ownsLayer);
89 network = std::move(layer.network);
90 weights = std::move(layer.weights);
95 template<
typename InputDataType,
typename OutputDataType,
96 typename... CustomLayers>
101 std::for_each(network.begin(), network.end(),
102 boost::apply_visitor(deleteVisitor));
106 template <
typename InputDataType,
typename OutputDataType,
107 typename... CustomLayers>
108 template<
typename InputType,
typename OutputType>
110 const InputType& input, OutputType& output)
114 for (
size_t i = 0; i < network.size(); ++i)
117 boost::apply_visitor(outputParameterVisitor, network[i])),
122 output = boost::apply_visitor(outputParameterVisitor, network.front());
123 for (
size_t i = 1; i < network.size(); ++i)
125 output %= boost::apply_visitor(outputParameterVisitor, network[i]);
129 template<
typename InputDataType,
typename OutputDataType,
130 typename... CustomLayers>
131 template<
typename eT>
133 const arma::Mat<eT>& ,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
137 for (
size_t i = 0; i < network.size(); ++i)
140 outputParameterVisitor, network[i]), gy,
141 boost::apply_visitor(deltaVisitor, network[i])), network[i]);
144 g = boost::apply_visitor(deltaVisitor, network[0]);
145 for (
size_t i = 1; i < network.size(); ++i)
147 g += boost::apply_visitor(deltaVisitor, network[i]);
154 template<
typename InputDataType,
typename OutputDataType,
155 typename... CustomLayers>
156 template<
typename eT>
158 const arma::Mat<eT>& input,
159 const arma::Mat<eT>& error,
164 for (
size_t i = 0; i < network.size(); ++i)
171 template<
typename InputDataType,
typename OutputDataType,
172 typename... CustomLayers>
173 template<
typename Archive>
175 Archive& ar,
const uint32_t )
178 if (cereal::is_loading<Archive>())
182 ar(CEREAL_NVP(model));
184 ar(CEREAL_NVP(ownsLayer));
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Definition: backward_visitor.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const InputType &, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
Definition: multiply_merge_impl.hpp:109
MultiplyMerge & operator=(const MultiplyMerge &layer)
Copy assignment operator.
Definition: multiply_merge_impl.hpp:64
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
Definition: multiply_merge_impl.hpp:132
Definition: pointer_wrapper.hpp:23
OutputDataType const & Gradient() const
Get the gradient.
Definition: multiply_merge.hpp:130
Implementation of the MultiplyMerge module class.
Definition: layer_types.hpp:209
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
Definition: gradient_visitor.hpp:28
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: multiply_merge_impl.hpp:174
MultiplyMerge(const bool model=false, const bool run=true)
Create the MultiplyMerge object using the specified parameters.
Definition: multiply_merge_impl.hpp:28
~MultiplyMerge()
Destructor to release allocated memory.
Definition: multiply_merge_impl.hpp:97
#define CEREAL_VECTOR_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_variant_wrapper.hpp:92