12 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_IMPL_HPP 13 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_IMPL_HPP 23 const arma::mat& delta) :
33 const arma::mat& delta,
43 template<
typename LayerType>
46 LayerGradients(layer, layer->OutputParameter());
51 layer.apply_visitor(*
this);
55 inline typename std::enable_if<
56 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
57 !HasRunCheck<T, bool&(T::*)(void)>::value,
void>::type
58 GradientVisitor::LayerGradients(T* layer, arma::mat& )
const 60 layer->Gradient(input, delta, layer->Gradient());
64 inline typename std::enable_if<
65 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
66 HasRunCheck<T, bool&(T::*)(void)>::value,
void>::type
67 GradientVisitor::LayerGradients(T* layer, arma::mat& )
const 71 layer->Gradient(input, delta, layer->Gradient());
75 layer->Gradient(input, delta, layer->Gradient(), index);
79 template<
typename T,
typename P>
80 inline typename std::enable_if<
81 !HasGradientCheck<T, P&(T::*)()>::value,
void>::type
82 GradientVisitor::LayerGradients(T* , P& )
const Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void operator()(LayerType *layer) const
Executes the Gradient() method.
Definition: gradient_visitor_impl.hpp:44
GradientVisitor(const arma::mat &input, const arma::mat &delta)
Executes the Gradient() method of the given module using the input and delta parameters.
Definition: gradient_visitor_impl.hpp:22