14 #ifndef MLPACK_METHODS_ANN_LAYER_DROPCONNECT_IMPL_HPP 15 #define MLPACK_METHODS_ANN_LAYER_DROPCONNECT_IMPL_HPP 20 #include "../visitor/delete_visitor.hpp" 21 #include "../visitor/forward_visitor.hpp" 22 #include "../visitor/backward_visitor.hpp" 23 #include "../visitor/gradient_visitor.hpp" 24 #include "../visitor/parameters_set_visitor.hpp" 25 #include "../visitor/parameters_visitor.hpp" 30 template<
typename InputDataType,
typename OutputDataType>
39 template<
typename InputDataType,
typename OutputDataType>
45 scale(1.0 / (1 - ratio)),
46 baseLayer(new
Linear<InputDataType, OutputDataType>(inSize, outSize))
48 network.push_back(baseLayer);
51 template<
typename InputDataType,
typename OutputDataType>
54 const arma::Mat<eT>& input,
55 arma::Mat<eT>& output)
70 mask = arma::randu<arma::Mat<eT> >(denoise.n_rows, denoise.n_cols);
71 mask.transform([&](
double val) {
return (val > ratio); });
73 arma::mat tmp = denoise % mask;
78 output = output * scale;
82 template<
typename InputDataType,
typename OutputDataType>
85 const arma::Mat<eT>& input,
86 const arma::Mat<eT>& gy,
92 template<
typename InputDataType,
typename OutputDataType>
95 const arma::Mat<eT>& input,
96 const arma::Mat<eT>& error,
106 template<
typename InputDataType,
typename OutputDataType>
107 template<
typename Archive>
109 Archive& ar,
const uint32_t )
112 if (cereal::is_loading<Archive>())
117 ar(CEREAL_NVP(ratio));
118 ar(CEREAL_NVP(scale));
121 if (cereal::is_loading<Archive>())
124 network.push_back(baseLayer);
DeleteVisitor executes the destructor of the instantiated object.
Definition: delete_visitor.hpp:27
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Definition: backward_visitor.hpp:28
OutputDataType const & Gradient() const
Get the gradient.
Definition: dropconnect.hpp:133
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Implementation of the Linear layer class.
Definition: layer_types.hpp:93
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of the DropConnect layer.
Definition: dropconnect_impl.hpp:84
DropConnect()
Create the DropConnect object.
Definition: dropconnect_impl.hpp:31
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
#define CEREAL_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_variant_wrapper.hpp:155
GradientVisitor executes the Gradient() method of the given module using the input and delta parame...
Definition: gradient_visitor.hpp:28
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of the DropConnect layer.
Definition: dropconnect_impl.hpp:53
ParametersVisitor exposes the parameters set of the given module and stores the parameters set into t...
Definition: parameters_visitor.hpp:28
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: dropconnect_impl.hpp:108