mlpack
dropconnect_impl.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_ANN_LAYER_DROPCONNECT_IMPL_HPP
15 #define MLPACK_METHODS_ANN_LAYER_DROPCONNECT_IMPL_HPP
16 
17 // In case it hasn't yet been included.
18 #include "dropconnect.hpp"
19 
20 #include "../visitor/delete_visitor.hpp"
21 #include "../visitor/forward_visitor.hpp"
22 #include "../visitor/backward_visitor.hpp"
23 #include "../visitor/gradient_visitor.hpp"
24 #include "../visitor/parameters_set_visitor.hpp"
25 #include "../visitor/parameters_visitor.hpp"
26 
27 namespace mlpack {
28 namespace ann {
29 
// Default constructor: drop ratio 0.5 with the matching output scale
// 2.0 (= 1 / (1 - 0.5)). It starts in deterministic mode, so Forward() passes
// the input straight through baseLayer without applying a mask.
// NOTE(review): baseLayer is not initialized in this list. It presumably has
// to be populated later (e.g. by loading through serialize()) before
// Forward() is called — TODO confirm.
30 template<typename InputDataType, typename OutputDataType>
32  ratio(0.5),
33  scale(2.0),
34  deterministic(true)
35 {
36  // Nothing to do here.
37 }
38 
// Construct a DropConnect layer that wraps a newly allocated
// Linear<InputDataType, OutputDataType>(inSize, outSize) layer.
//
// inSize / outSize: forwarded unchanged to the wrapped Linear layer's
//   constructor.
// ratio: probability that each weight is zeroed during a training-mode
//   Forward(). The output is rescaled by scale = 1 / (1 - ratio).
//   NOTE(review): ratio == 1 makes this a division by zero (scale becomes
//   infinite); consider validating ratio.
//
// NOTE(review): unlike the default constructor, `deterministic` is not set in
// this initializer list. Unless the class declaration gives it a default
// member initializer, Forward() reads an indeterminate value. TODO confirm
// against dropconnect.hpp.
// NOTE(review): baseLayer is allocated with `new` and also appended to
// `network`. Its release is not visible here (serialize() deletes it through
// DeleteVisitor when loading); presumably the destructor handles it — verify.
39 template<typename InputDataType, typename OutputDataType>
41  const size_t inSize,
42  const size_t outSize,
43  const double ratio) :
44  ratio(ratio),
45  scale(1.0 / (1 - ratio)),
46  baseLayer(new Linear<InputDataType, OutputDataType>(inSize, outSize))
47 {
48  network.push_back(baseLayer);
49 }
50 
// Forward pass of the DropConnect layer.
//
// Deterministic mode: the input is forwarded through baseLayer unchanged.
// Training mode (deterministic == false) runs these steps:
//   1. Copy the current baseLayer parameters into `denoise`.
//   2. Draw a fresh mask of the same shape that keeps each entry when a
//      uniform sample exceeds `ratio`, i.e. with probability (1 - ratio).
//   3. Install the masked parameters in baseLayer.
//   4. Forward the input.
//   5. Multiply the output by `scale` to compensate, in expectation, for the
//      dropped connections.
//
// NOTE(review): after a training-mode call, baseLayer is left holding the
// MASKED parameters. Only Gradient() restores them from `denoise`. If
// Forward() runs twice in training mode without an intervening Gradient(),
// the second call saves the already-masked parameters into `denoise`, and
// the dropped weights are lost.
51 template<typename InputDataType, typename OutputDataType>
52 template<typename eT>
54  const arma::Mat<eT>& input,
55  arma::Mat<eT>& output)
56 {
57  // The DropConnect mask will not be multiplied in the deterministic mode
58  // (during testing).
59  if (deterministic)
60  {
61  boost::apply_visitor(ForwardVisitor(input, output), baseLayer);
62  }
63  else
64  {
65  // Save the unmasked parameters so that Gradient() can restore them.
66  boost::apply_visitor(ParametersVisitor(denoise), baseLayer);
67 
68  // Zero each parameter with probability ratio (keep it when u > ratio);
69  // the output is rescaled by `scale` after the forward pass below.
70  mask = arma::randu<arma::Mat<eT> >(denoise.n_rows, denoise.n_cols);
71  mask.transform([&](double val) { return (val > ratio); });
72 
// NOTE(review): tmp is arma::mat regardless of eT, presumably because
// ParametersSetVisitor expects an arma::mat — confirm before changing.
73  arma::mat tmp = denoise % mask;
74  boost::apply_visitor(ParametersSetVisitor(tmp), baseLayer);
75 
76  boost::apply_visitor(ForwardVisitor(input, output), baseLayer);
77 
78  output = output * scale;
79  }
80 }
81 
// Backward pass: delegates entirely to baseLayer. Given the layer input and
// the upstream error gy, baseLayer writes the propagated error into g.
// After a training-mode Forward(), baseLayer still holds the masked
// parameters, so the backward pass runs through the same masked connections
// that produced the output.
82 template<typename InputDataType, typename OutputDataType>
83 template<typename eT>
85  const arma::Mat<eT>& input,
86  const arma::Mat<eT>& gy,
87  arma::Mat<eT>& g)
88 {
89  boost::apply_visitor(BackwardVisitor(input, gy, g), baseLayer);
90 }
91 
// Gradient step: delegates gradient computation to baseLayer, then restores
// the unmasked parameters that Forward() saved in `denoise`.
// The `gradient` argument is not used here. Presumably baseLayer writes into
// its own gradient storage — TODO confirm.
// NOTE(review): the restore is unconditional, but `denoise` is only refreshed
// by a training-mode (non-deterministic) Forward(). Calling Gradient() after
// a deterministic Forward() overwrites baseLayer's parameters with a stale or
// never-initialized copy.
92 template<typename InputDataType, typename OutputDataType>
93 template<typename eT>
95  const arma::Mat<eT>& input,
96  const arma::Mat<eT>& error,
97  arma::Mat<eT>& /* gradient */)
98 {
99  boost::apply_visitor(GradientVisitor(input, error),
100  baseLayer);
101 
102  // Restore the unmasked weights saved by Forward().
103  boost::apply_visitor(ParametersSetVisitor(denoise), baseLayer);
104 }
105 
// Serialize the layer with cereal.
//
// Saved/loaded fields: ratio, scale and the wrapped baseLayer. baseLayer goes
// through CEREAL_VARIANT_POINTER because cereal cannot serialize raw pointers.
// When loading, the existing baseLayer is destroyed first. `network` is then
// rebuilt so that it contains only the freshly loaded baseLayer.
// NOTE(review): `deterministic` (along with `denoise` and `mask`) is not
// serialized. After loading, it keeps whatever value the object already had.
106 template<typename InputDataType, typename OutputDataType>
107 template<typename Archive>
109  Archive& ar, const uint32_t /* version */)
110 {
111  // Destroy the current base layer before it is replaced by the loaded one.
112  if (cereal::is_loading<Archive>())
113  {
114  boost::apply_visitor(DeleteVisitor(), baseLayer);
115  }
116 
117  ar(CEREAL_NVP(ratio));
118  ar(CEREAL_NVP(scale));
119  ar(CEREAL_VARIANT_POINTER(baseLayer));
120 
121  if (cereal::is_loading<Archive>())
122  {
123  network.clear();
124  network.push_back(baseLayer);
125  }
126 }
127 
128 } // namespace ann
129 } // namespace mlpack
130 
131 #endif
DeleteVisitor executes the destructor of the instantiated object.
Definition: delete_visitor.hpp:27
BackwardVisitor executes the Backward() function given the input, error and delta parameters.
Definition: backward_visitor.hpp:28
OutputDataType const & Gradient() const
Get the gradient.
Definition: dropconnect.hpp:133
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
ParametersSetVisitor updates the parameters set using the given matrix.
Definition: parameters_set_visitor.hpp:27
Implementation of the Linear layer class.
Definition: layer_types.hpp:93
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of the DropConnect layer.
Definition: dropconnect_impl.hpp:84
DropConnect()
Create the DropConnect object.
Definition: dropconnect_impl.hpp:31
ForwardVisitor executes the Forward() function given the input and output parameters.
Definition: forward_visitor.hpp:28
#define CEREAL_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_variant_wrapper.hpp:155
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
Definition: gradient_visitor.hpp:28
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of the DropConnect layer.
Definition: dropconnect_impl.hpp:53
ParametersVisitor exposes the parameters set of the given module and stores the parameters set into the given matrix.
Definition: parameters_visitor.hpp:28
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: dropconnect_impl.hpp:108