13 #ifndef MLPACK_METHODS_ANN_INIT_RULES_NETWORK_INIT_HPP 14 #define MLPACK_METHODS_ANN_INIT_RULES_NETWORK_INIT_HPP 18 #include "../visitor/reset_visitor.hpp" 19 #include "../visitor/weight_size_visitor.hpp" 20 #include "../visitor/weight_set_visitor.hpp" 32 template<
typename InitializationRuleType,
typename... CustomLayers>
42 const InitializationRuleType& initializeRule = InitializationRuleType()) :
43 initializeRule(initializeRule)
56 template <
typename eT>
57 void Initialize(
const std::vector<LayerTypes<CustomLayers...> >& network,
58 arma::Mat<eT>& parameter,
size_t parameterOffset = 0)
61 if (parameter.is_empty())
64 for (
size_t i = 0; i < network.size(); ++i)
65 weights += boost::apply_visitor(weightSizeVisitor, network[i]);
66 parameter.set_size(weights, 1);
72 for (
size_t i = 0, offset = parameterOffset; i < network.size(); ++i)
76 const size_t weight = boost::apply_visitor(weightSizeVisitor,
78 arma::Mat<eT> tmp = arma::mat(parameter.memptr() + offset,
79 weight, 1,
false,
false);
80 initializeRule.Initialize(tmp, tmp.n_elem, 1);
88 initializeRule.Initialize(parameter, parameter.n_elem, 1);
95 for (
size_t i = 0, offset = parameterOffset; i < network.size(); ++i)
100 boost::apply_visitor(resetVisitor, network[i]);
107 InitializationRuleType initializeRule;
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Initialize(const std::vector< LayerTypes< CustomLayers... > > &network, arma::Mat< eT > &parameter, size_t parameterOffset=0)
Initialize the specified network and store the results in the given parameter.
Definition: network_init.hpp:57
NetworkInitialization(const InitializationRuleType &initializeRule=InitializationRuleType())
Use the given initialization rule to initialize the specified network.
Definition: network_init.hpp:41
WeightSizeVisitor returns the number of weights of the given module.
Definition: weight_size_visitor.hpp:27
This is a template class that can provide information about various initialization methods...
Definition: init_rules_traits.hpp:28
WeightSetVisitor updates the module parameters given the parameter set.
Definition: weight_set_visitor.hpp:26
ResetVisitor executes the Reset() function.
Definition: reset_visitor.hpp:26
This class is used to initialize the network with the given initialization rule.
Definition: network_init.hpp:33