// Include guard for this implementation header, followed by the templated
// call operator of GradientSetVisitor. The operator's signature line is absent
// from this listing; the reference entry near the end of the listing gives it
// as "size_t operator()(LayerType *layer) const".
// The visible body forwards the layer and the layer's OutputParameter() to
// LayerGradients() and returns whatever that call returns.
// NOTE(review): the leading numbers (12, 13, 30, 33) look like line numbers of
// the original header carried over by the listing -- confirm against the real
// file before compiling.
12 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_SET_VISITOR_IMPL_HPP 13 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_SET_VISITOR_IMPL_HPP 30 template<
typename LayerType>
33 return LayerGradients(layer, layer->OutputParameter());
// Return statement of a function whose signature is absent from this listing
// (presumably another operator() overload -- confirm against the class
// header). It passes this visitor to layer.apply_visitor() and returns the
// result of that call.
38 return layer.apply_visitor(*
this);
42 inline typename std::enable_if<
43 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
44 !HasModelCheck<T>::value,
size_t>::type
45 GradientSetVisitor::LayerGradients(T* layer, arma::mat& )
const 47 layer->Gradient() = arma::mat(gradient.memptr() + offset,
48 layer->Parameters().n_rows, layer->Parameters().n_cols,
false,
false);
50 return layer->Parameters().n_elem;
// LayerGradients overload enabled only when
// HasGradientCheck<T, arma::mat&(T::*)()> is false and HasModelCheck<T> is
// true (presumably: layer types that provide Model() but no Gradient()).
// The second parameter is unnamed in the signature.
//
// The visible code starts a running modelOffset at 0 and loops over every
// index of layer->Model(). The visible tail of the loop-body statement passes
// (gradient, modelOffset + offset) to an inner call and layer->Model()[i] as
// the last argument of an outer call.
//
// NOTE(review): this listing omits the beginning of the loop-body statement,
// the braces and the return statement, so how modelOffset is updated and what
// the function returns cannot be confirmed from here. The inner call
// presumably constructs a GradientSetVisitor (the constructor shown in the
// reference section takes (arma::mat &gradient, size_t offset)) -- verify
// against the original header.
54 inline typename std::enable_if<
55 !HasGradientCheck<T, arma::mat&(T::*)()>::value &&
56 HasModelCheck<T>::value,
size_t>::type
57 GradientSetVisitor::LayerGradients(T* layer, arma::mat& )
const 59 size_t modelOffset = 0;
60 for (
size_t i = 0; i < layer->Model().size(); ++i)
63 gradient, modelOffset + offset), layer->Model()[i]);
// LayerGradients overload enabled only when both
// HasGradientCheck<T, arma::mat&(T::*)()> and HasModelCheck<T> are true
// (presumably: layer types that provide both Gradient() and Model()).
// The second parameter is unnamed in the signature.
//
// First, layer->Gradient() is assigned an arma::mat constructed on
// gradient.memptr() + offset with the row and column counts of
// layer->Parameters(); the two trailing `false` arguments are copy_aux_mem and
// strict in Armadillo's auxiliary-memory constructor, so the memory is not
// copied by that constructor.
//
// Then modelOffset starts at layer->Parameters().n_elem (the size of the
// layer's own parameters) and the code loops over every index of
// layer->Model(). The visible tail of the loop-body statement passes
// (gradient, modelOffset + offset) to an inner call and layer->Model()[i] as
// the last argument of an outer call.
//
// NOTE(review): this listing omits the beginning of the loop-body statement,
// the braces and the return statement, so how modelOffset is updated and what
// the function returns cannot be confirmed from here -- verify against the
// original header.
70 inline typename std::enable_if<
71 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
72 HasModelCheck<T>::value,
size_t>::type
73 GradientSetVisitor::LayerGradients(T* layer, arma::mat& )
const 75 layer->Gradient() = arma::mat(gradient.memptr() + offset,
76 layer->Parameters().n_rows, layer->Parameters().n_cols,
false,
false);
78 size_t modelOffset = layer->Parameters().n_elem;
79 for (
size_t i = 0; i < layer->Model().size(); ++i)
82 gradient, modelOffset + offset), layer->Model()[i]);
// LayerGradients overload enabled only when both
// HasGradientCheck<T, P&(T::*)()> and HasModelCheck<T> are false. Unlike the
// other overloads in this listing it is templated on the second parameter's
// type P as well as on the layer type T, and both parameters are unnamed in
// the signature, so the body cannot refer to either argument.
//
// NOTE(review): the body of this overload is absent from this listing, so its
// return value cannot be confirmed from here. The text following `const` on
// the last line below is reference text from the listing (a namespace
// description), not part of the function.
88 template<
typename T,
typename P>
89 inline typename std::enable_if<
90 !HasGradientCheck<T, P&(T::*)()>::value &&
91 !HasModelCheck<T>::value,
size_t>::type
92 GradientSetVisitor::LayerGradients(T* , P& )
const Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
GradientSetVisitor(arma::mat &gradient, size_t offset=0)
Update the gradient parameter given the gradient set.
Definition: gradient_set_visitor_impl.hpp:22
size_t operator()(LayerType *layer) const
Update the gradient parameter.
Definition: gradient_set_visitor_impl.hpp:31