mlpack
backward_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_BACKWARD_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_BACKWARD_VISITOR_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "backward_visitor.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
22 inline BackwardVisitor::BackwardVisitor(const arma::mat& input,
23  const arma::mat& error,
24  arma::mat& delta) :
25  input(input),
26  error(error),
27  delta(delta),
28  index(0),
29  hasIndex(false)
30 {
31  /* Nothing to do here. */
32 }
33 
34 inline BackwardVisitor::BackwardVisitor(const arma::mat& input,
35  const arma::mat& error,
36  arma::mat& delta,
37  const size_t index) :
38  input(input),
39  error(error),
40  delta(delta),
41  index(index),
42  hasIndex(true)
43 {
44  /* Nothing to do here. */
45 }
46 
47 template<typename LayerType>
48 inline void BackwardVisitor::operator()(LayerType* layer) const
49 {
50  LayerBackward(layer, layer->OutputParameter());
51 }
52 
53 inline void BackwardVisitor::operator()(MoreTypes layer) const
54 {
55  layer.apply_visitor(*this);
56 }
57 
58 template<typename T>
59 inline typename std::enable_if<
60  !HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
61 BackwardVisitor::LayerBackward(T* layer, arma::mat& /* input */) const
62 {
63  layer->Backward(input, error, delta);
64 }
65 
66 template<typename T>
67 inline typename std::enable_if<
68  HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
69 BackwardVisitor::LayerBackward(T* layer, arma::mat& /* input */) const
70 {
71  if (!hasIndex)
72  {
73  layer->Backward(input, error, delta);
74  }
75  else
76  {
77  layer->Backward(input, error, delta, index);
78  }
79 }
80 
81 } // namespace ann
82 } // namespace mlpack
83 
84 #endif
BackwardVisitor(const arma::mat &input, const arma::mat &error, arma::mat &delta)
Execute the Backward() function given the input, error and delta parameters.
Definition: backward_visitor_impl.hpp:22
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void operator()(LayerType *layer) const
Execute the Backward() function.
Definition: backward_visitor_impl.hpp:48