12 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_UPDATE_VISITOR_IMPL_HPP 13 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_UPDATE_VISITOR_IMPL_HPP 30 template<
typename LayerType>
33 return LayerGradients(layer, layer->OutputParameter());
38 return layer.apply_visitor(*
this);
42 inline typename std::enable_if<
43 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
44 !HasModelCheck<T>::value,
size_t>::type
45 GradientUpdateVisitor::LayerGradients(T* layer, arma::mat& )
const 47 if (layer->Parameters().n_elem != 0)
49 layer->Gradient() = gradient.submat(offset, 0,
50 offset + layer->Parameters().n_elem - 1, 0);;
53 return layer->Parameters().n_elem;
57 inline typename std::enable_if<
58 !HasGradientCheck<T, arma::mat&(T::*)()>::value &&
59 HasModelCheck<T>::value,
size_t>::type
60 GradientUpdateVisitor::LayerGradients(T* layer, arma::mat& )
const 62 size_t modelOffset = 0;
63 for (
size_t i = 0; i < layer->Model().size(); ++i)
66 gradient, modelOffset + offset), layer->Model()[i]);
73 inline typename std::enable_if<
74 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
75 HasModelCheck<T>::value,
size_t>::type
76 GradientUpdateVisitor::LayerGradients(T* layer, arma::mat& )
const 78 if (layer->Parameters().n_elem != 0)
80 layer->Gradient() = gradient.submat(offset, 0,
81 offset + layer->Parameters().n_elem - 1, 0);;
84 size_t modelOffset = layer->Parameters().n_elem;
85 for (
size_t i = 0; i < layer->Model().size(); ++i)
88 gradient, modelOffset + offset), layer->Model()[i]);
94 template<
typename T,
typename P>
95 inline typename std::enable_if<
96 !HasGradientCheck<T, P&(T::*)()>::value &&
97 !HasModelCheck<T>::value,
size_t>::type
98 GradientUpdateVisitor::LayerGradients(T* , P& )
const GradientUpdateVisitor(arma::mat &gradient, size_t offset=0)
Update the gradient parameter given the gradient set.
Definition: gradient_update_visitor_impl.hpp:22
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t operator()(LayerType *layer) const
Update the gradient parameter.
Definition: gradient_update_visitor_impl.hpp:31