mlpack
gradient_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "gradient_visitor.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
22 inline GradientVisitor::GradientVisitor(const arma::mat& input,
23  const arma::mat& delta) :
24  input(input),
25  delta(delta),
26  index(0),
27  hasIndex(false)
28 {
29  /* Nothing to do here. */
30 }
31 
32 inline GradientVisitor::GradientVisitor(const arma::mat& input,
33  const arma::mat& delta,
34  const size_t index) :
35  input(input),
36  delta(delta),
37  index(index),
38  hasIndex(true)
39 {
40  /* Nothing to do here. */
41 }
42 
43 template<typename LayerType>
44 inline void GradientVisitor::operator()(LayerType* layer) const
45 {
46  LayerGradients(layer, layer->OutputParameter());
47 }
48 
49 inline void GradientVisitor::operator()(MoreTypes layer) const
50 {
51  layer.apply_visitor(*this);
52 }
53 
54 template<typename T>
55 inline typename std::enable_if<
56  HasGradientCheck<T, arma::mat&(T::*)()>::value &&
57  !HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
58 GradientVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
59 {
60  layer->Gradient(input, delta, layer->Gradient());
61 }
62 
63 template<typename T>
64 inline typename std::enable_if<
65  HasGradientCheck<T, arma::mat&(T::*)()>::value &&
66  HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
67 GradientVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
68 {
69  if (!hasIndex)
70  {
71  layer->Gradient(input, delta, layer->Gradient());
72  }
73  else
74  {
75  layer->Gradient(input, delta, layer->Gradient(), index);
76  }
77 }
78 
79 template<typename T, typename P>
80 inline typename std::enable_if<
81  !HasGradientCheck<T, P&(T::*)()>::value, void>::type
82 GradientVisitor::LayerGradients(T* /* layer */, P& /* input */) const
83 {
84  /* Nothing to do here. */
85 }
86 
87 } // namespace ann
88 } // namespace mlpack
89 
90 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void operator()(LayerType *layer) const
Executes the Gradient() method.
Definition: gradient_visitor_impl.hpp:44
GradientVisitor(const arma::mat &input, const arma::mat &delta)
Executes the Gradient() method of the given module using the input and delta parameters.
Definition: gradient_visitor_impl.hpp:22