mlpack
weight_norm_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_WEIGHTNORM_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_WEIGHTNORM_IMPL_HPP
15 
16 // In case it is not included.
17 #include "weight_norm.hpp"
18 
19 #include "../visitor/forward_visitor.hpp"
20 #include "../visitor/backward_visitor.hpp"
21 #include "../visitor/gradient_visitor.hpp"
22 #include "../visitor/bias_set_visitor.hpp"
23 
24 namespace mlpack {
25 namespace ann {
// Constructor: wrap `layer` so its non-bias weights are reparameterized by a
// direction vector and a scalar (weight normalization).
// NOTE(review): the signature line (original line 29) is not visible in this
// listing; per the member index it is the WeightNorm(LayerTypes<...>) ctor.
27 template<typename InputDataType, typename OutputDataType,
28  typename... CustomLayers>
30  LayerTypes<CustomLayers...> layer) :
31  wrappedLayer(layer)
32 {
  // Query the wrapped layer for its total number of trainable parameters.
33  layerWeightSize = boost::apply_visitor(weightSizeVisitor, wrappedLayer);
  // One extra element is reserved for the scalar parameter (see Reset(),
  // which aliases weights.memptr() + layerWeightSize as scalarParameter).
34  weights.set_size(layerWeightSize + 1, 1);
35 
  // Working buffers: the weights actually handed to the wrapped layer, and
  // the raw gradients it produces (see ResetGradients()/Gradient()).
36  layerWeights.set_size(layerWeightSize, 1);
37  layerGradients.set_size(layerWeightSize, 1);
38 }
39 
// Destructor: release the wrapped layer, which is held through the
// LayerTypes variant, via the deleteVisitor.
// NOTE(review): the destructor's signature line (original line 42) is not
// visible in this listing.
40 template<typename InputDataType, typename OutputDataType,
41  typename... CustomLayers>
43 {
44  boost::apply_visitor(deleteVisitor, wrappedLayer);
45 }
46 
// Reset(): bind the wrapped layer's parameter memory and carve this layer's
// `weights` vector into [bias | vector parameter | scalar parameter].
// NOTE(review): the signature line (original line 49) is not visible in this
// listing.
47 template<typename InputDataType, typename OutputDataType,
48  typename... CustomLayers>
50 {
51  // Set the weights of the inside layer to layerWeights.
52  // This is done to set the non-bias terms correctly.
53  boost::apply_visitor(WeightSetVisitor(layerWeights, 0), wrappedLayer);
54 
  // Let the wrapped layer rebuild any internal aliases over layerWeights.
55  boost::apply_visitor(resetVisitor, wrappedLayer);
56 
  // Point the wrapped layer's bias terms at the front of `weights`; the
  // visitor returns how many elements are biases.
57  biasWeightSize = boost::apply_visitor(BiasSetVisitor(weights, 0),
58  wrappedLayer);
59 
  // Alias (no copy: copy_aux_mem = false, strict = false) the direction
  // vector v, stored after the biases in `weights`.
60  vectorParameter = arma::mat(weights.memptr() + biasWeightSize,
61  layerWeightSize - biasWeightSize, 1, false, false);
62 
  // Alias the single scalar parameter stored in the last element of `weights`.
63  scalarParameter = arma::mat(weights.memptr() + layerWeightSize, 1, 1, false,
64  false);
65 }
66 
// Forward(input, output): recompute the wrapped layer's non-bias weights as
// scalar * v / ||v||_2, then run the wrapped layer's forward pass.
// NOTE(review): the signature line (original line 70) is not visible in this
// listing.
67 template<typename InputDataType, typename OutputDataType,
68  typename... CustomLayers>
69 template<typename eT>
71  const arma::Mat<eT>& input, arma::Mat<eT>& output)
72 {
73  // Initialize the non-bias weights of wrapped layer.
  // NOTE(review): no guard against ||v|| == 0 here — division by zero would
  // produce non-finite weights; presumably initialization avoids this.
74  const double normVectorParameter = arma::norm(vectorParameter, 2);
75  layerWeights.rows(0, layerWeightSize - biasWeightSize - 1) =
76  scalarParameter(0) * vectorParameter / normVectorParameter;
77 
  // Forward the input through the wrapped layer into its output parameter.
78  boost::apply_visitor(ForwardVisitor(input,
79  boost::apply_visitor(outputParameterVisitor, wrappedLayer)),
80  wrappedLayer);
81 
  // Expose the wrapped layer's output as this layer's output.
82  output = boost::apply_visitor(outputParameterVisitor, wrappedLayer);
83 }
84 
// Backward(input, gy, g): delegate the backward pass to the wrapped layer
// (its own output is used as the input to Backward; the `input` argument is
// unused) and return its delta as g.
// NOTE(review): the signature line (original line 88) is not visible in this
// listing.
85 template<typename InputDataType, typename OutputDataType,
86  typename... CustomLayers>
87 template<typename eT>
89  const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
90 {
91  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
92  outputParameterVisitor, wrappedLayer), gy,
93  boost::apply_visitor(deltaVisitor, wrappedLayer)), wrappedLayer);
94 
  // The wrapped layer's delta is this layer's downstream gradient.
95  g = boost::apply_visitor(deltaVisitor, wrappedLayer);
96 }
97 
// Gradient(input, error, gradient): compute gradients w.r.t. this layer's
// reparameterized weights [bias | v | scalar] from the wrapped layer's raw
// weight gradients, using the weight-norm chain rule
//   dL/dscalar = (dL/dw . v) / ||v||
//   dL/dv      = scalar/||v|| * (dL/dw - (dL/dscalar) * v / ||v||).
// NOTE(review): the signature line (original line 101) is not visible in
// this listing. Assumes the wrapped layer lays out non-bias weights before
// bias weights in its flat parameter vector (the bias gradients are read
// from the tail of layerGradients) — TODO confirm against the visitors.
98 template<typename InputDataType, typename OutputDataType,
99  typename... CustomLayers>
100 template<typename eT>
102  const arma::Mat<eT>& input,
103  const arma::Mat<eT>& error,
104  arma::Mat<eT>& gradient)
105 {
  // Point the wrapped layer's gradient storage at layerGradients.
106  ResetGradients(layerGradients);
107 
108  // Calculate the gradients of the wrapped layer.
109  boost::apply_visitor(GradientVisitor(input, error), wrappedLayer);
110 
111  // Store the norm of vector parameter temporarily.
112  const double normVectorParameter = arma::norm(vectorParameter, 2);
113 
114  // Set the gradients of the bias terms.
  // Bias gradients are copied from the last biasWeightSize elements of
  // layerGradients into the front of `gradient` (aliased view, no copy of
  // the source buffer).
115  if (biasWeightSize != 0)
116  {
117  gradient.rows(0, biasWeightSize - 1) = arma::mat(layerGradients.memptr() +
118  layerWeightSize - biasWeightSize, biasWeightSize, 1, false, false);
119  }
120 
121  // Calculate the gradients of the scalar parameter.
  // Stored in the last row of `gradient`, mirroring the `weights` layout.
122  gradient[gradient.n_rows - 1] = arma::accu(layerGradients.rows(0,
123  layerWeightSize - biasWeightSize - 1) % vectorParameter) /
124  normVectorParameter;
125 
126  // Calculate the gradients of the vector parameter.
  // Note: this reads gradient[n_rows - 1], so the scalar gradient above must
  // be computed first.
127  gradient.rows(biasWeightSize, layerWeightSize - 1) =
128  scalarParameter(0) / normVectorParameter * (layerGradients.rows(0,
129  layerWeightSize - biasWeightSize - 1) - gradient[gradient.n_rows - 1] /
130  normVectorParameter * vectorParameter);
131 }
132 
// ResetGradients(gradient): point the wrapped layer's gradient storage at
// the given buffer (offset 0), so GradientVisitor writes into it.
// NOTE(review): the signature line (original line 135) is not visible in
// this listing.
133 template<typename InputDataType, typename OutputDataType,
134  typename... CustomLayers>
136  arma::mat& gradient)
137 {
138  boost::apply_visitor(GradientSetVisitor(gradient, 0), wrappedLayer);
139 }
140 
// serialize(ar, version): cereal (de)serialization of the wrapped layer
// pointer and the weight count.
// NOTE(review): the signature line (original line 144) is not visible in
// this listing.
141 template<typename InputDataType, typename OutputDataType,
142  typename... CustomLayers>
143 template<typename Archive>
145  Archive& ar, const uint32_t /* version */)
146 {
  // When loading, free the currently-held wrapped layer before cereal
  // replaces the variant pointer, to avoid leaking it.
147  if (cereal::is_loading<Archive>())
148  {
149  boost::apply_visitor(deleteVisitor, wrappedLayer);
150  }
151 
  // Raw-pointer variants need the CEREAL_VARIANT_POINTER wrapper.
152  ar(CEREAL_VARIANT_POINTER(wrappedLayer));
153  ar(CEREAL_NVP(layerWeightSize));
154 
155  // If we are loading, we need to initialize the weights.
  // Same sizing as the constructor: layer weights plus one scalar parameter.
156  if (cereal::is_loading<Archive>())
157  {
158  weights.set_size(layerWeightSize + 1, 1);
159  }
160 }
161 
162 } // namespace ann
163 } // namespace mlpack
164 
165 #endif
BackwardVisitor executes the Backward() function given the input, error, and delta parameters.
Definition: backward_visitor.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Reset()
Reset the layer parameters.
Definition: weight_norm_impl.hpp:49
~WeightNorm()
Destructor to release allocated memory.
Definition: weight_norm_impl.hpp:42
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
Definition: weight_norm_impl.hpp:88
GradientSetVisitor update the gradient parameter given the gradient set.
Definition: gradient_set_visitor.hpp:26
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of the WeightNorm layer.
Definition: weight_norm_impl.hpp:70
Declaration of the WeightNorm layer class.
Definition: layer_types.hpp:215
WeightNorm(LayerTypes< CustomLayers... > layer=LayerTypes< CustomLayers... >())
Create the WeightNorm layer object.
Definition: weight_norm_impl.hpp:29
BiasSetVisitor updates the module bias parameters given the parameters set.
Definition: bias_set_visitor.hpp:26
WeightSetVisitor update the module parameters given the parameters set.
Definition: weight_set_visitor.hpp:26
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: weight_norm_impl.hpp:144
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
#define CEREAL_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointer.
Definition: pointer_variant_wrapper.hpp:155
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
Definition: gradient_visitor.hpp:28
OutputDataType const & Gradient() const
Get the gradient.
Definition: weight_norm.hpp:123