13 #ifndef MLPACK_METHODS_ANN_LAYER_WEIGHTNORM_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_WEIGHTNORM_IMPL_HPP 19 #include "../visitor/forward_visitor.hpp" 20 #include "../visitor/backward_visitor.hpp" 21 #include "../visitor/gradient_visitor.hpp" 22 #include "../visitor/bias_set_visitor.hpp" 27 template<
typename InputDataType,
typename OutputDataType,
28 typename... CustomLayers>
30 LayerTypes<CustomLayers...> layer) :
33 layerWeightSize = boost::apply_visitor(weightSizeVisitor, wrappedLayer);
34 weights.set_size(layerWeightSize + 1, 1);
36 layerWeights.set_size(layerWeightSize, 1);
37 layerGradients.set_size(layerWeightSize, 1);
40 template<
typename InputDataType,
typename OutputDataType,
41 typename... CustomLayers>
44 boost::apply_visitor(deleteVisitor, wrappedLayer);
47 template<
typename InputDataType,
typename OutputDataType,
48 typename... CustomLayers>
55 boost::apply_visitor(resetVisitor, wrappedLayer);
60 vectorParameter = arma::mat(weights.memptr() + biasWeightSize,
61 layerWeightSize - biasWeightSize, 1,
false,
false);
63 scalarParameter = arma::mat(weights.memptr() + layerWeightSize, 1, 1,
false,
67 template<
typename InputDataType,
typename OutputDataType,
68 typename... CustomLayers>
71 const arma::Mat<eT>& input, arma::Mat<eT>& output)
74 const double normVectorParameter = arma::norm(vectorParameter, 2);
75 layerWeights.rows(0, layerWeightSize - biasWeightSize - 1) =
76 scalarParameter(0) * vectorParameter / normVectorParameter;
79 boost::apply_visitor(outputParameterVisitor, wrappedLayer)),
82 output = boost::apply_visitor(outputParameterVisitor, wrappedLayer);
85 template<
typename InputDataType,
typename OutputDataType,
86 typename... CustomLayers>
89 const arma::Mat<eT>& ,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
92 outputParameterVisitor, wrappedLayer), gy,
93 boost::apply_visitor(deltaVisitor, wrappedLayer)), wrappedLayer);
95 g = boost::apply_visitor(deltaVisitor, wrappedLayer);
98 template<
typename InputDataType,
typename OutputDataType,
99 typename... CustomLayers>
100 template<
typename eT>
102 const arma::Mat<eT>& input,
103 const arma::Mat<eT>& error,
104 arma::Mat<eT>& gradient)
106 ResetGradients(layerGradients);
112 const double normVectorParameter = arma::norm(vectorParameter, 2);
115 if (biasWeightSize != 0)
117 gradient.rows(0, biasWeightSize - 1) = arma::mat(layerGradients.memptr() +
118 layerWeightSize - biasWeightSize, biasWeightSize, 1,
false,
false);
122 gradient[gradient.n_rows - 1] = arma::accu(layerGradients.rows(0,
123 layerWeightSize - biasWeightSize - 1) % vectorParameter) /
127 gradient.rows(biasWeightSize, layerWeightSize - 1) =
128 scalarParameter(0) / normVectorParameter * (layerGradients.rows(0,
129 layerWeightSize - biasWeightSize - 1) - gradient[gradient.n_rows - 1] /
130 normVectorParameter * vectorParameter);
133 template<
typename InputDataType,
typename OutputDataType,
134 typename... CustomLayers>
141 template<
typename InputDataType,
typename OutputDataType,
142 typename... CustomLayers>
143 template<
typename Archive>
145 Archive& ar,
const uint32_t )
147 if (cereal::is_loading<Archive>())
149 boost::apply_visitor(deleteVisitor, wrappedLayer);
153 ar(CEREAL_NVP(layerWeightSize));
156 if (cereal::is_loading<Archive>())
158 weights.set_size(layerWeightSize + 1, 1);
BackwardVisitor executes the Backward() function given the input, error and delta parameters.
Definition: backward_visitor.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Reset()
Reset the layer parameters.
Definition: weight_norm_impl.hpp:49
~WeightNorm()
Destructor to release allocated memory.
Definition: weight_norm_impl.hpp:42
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
Definition: weight_norm_impl.hpp:88
GradientSetVisitor update the gradient parameter given the gradient set.
Definition: gradient_set_visitor.hpp:26
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of the WeightNorm layer.
Definition: weight_norm_impl.hpp:70
Declaration of the WeightNorm layer class.
Definition: layer_types.hpp:215
WeightNorm(LayerTypes< CustomLayers... > layer=LayerTypes< CustomLayers... >())
Create the WeightNorm layer object.
Definition: weight_norm_impl.hpp:29
BiasSetVisitor updates the module bias parameters given the parameters set.
Definition: bias_set_visitor.hpp:26
WeightSetVisitor update the module parameters given the parameters set.
Definition: weight_set_visitor.hpp:26
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: weight_norm_impl.hpp:144
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
#define CEREAL_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_variant_wrapper.hpp:155
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
Definition: gradient_visitor.hpp:28
OutputDataType const & Gradient() const
Get the gradient.
Definition: weight_norm.hpp:123