mlpack
gradient_update_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_UPDATE_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_UPDATE_VISITOR_IMPL_HPP
14 
15 // In case it hasn't been included yet.
17 
18 namespace mlpack {
19 namespace ann {
20 
23  size_t offset) :
24  gradient(gradient),
25  offset(offset)
26 {
27  /* Nothing to do here. */
28 }
29 
30 template<typename LayerType>
31 inline size_t GradientUpdateVisitor::operator()(LayerType* layer) const
32 {
33  return LayerGradients(layer, layer->OutputParameter());
34 }
35 
36 inline size_t GradientUpdateVisitor::operator()(MoreTypes layer) const
37 {
38  return layer.apply_visitor(*this);
39 }
40 
41 template<typename T>
42 inline typename std::enable_if<
43  HasGradientCheck<T, arma::mat&(T::*)()>::value &&
44  !HasModelCheck<T>::value, size_t>::type
45 GradientUpdateVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
46 {
47  if (layer->Parameters().n_elem != 0)
48  {
49  layer->Gradient() = gradient.submat(offset, 0,
50  offset + layer->Parameters().n_elem - 1, 0);;
51  }
52 
53  return layer->Parameters().n_elem;
54 }
55 
56 template<typename T>
57 inline typename std::enable_if<
58  !HasGradientCheck<T, arma::mat&(T::*)()>::value &&
59  HasModelCheck<T>::value, size_t>::type
60 GradientUpdateVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
61 {
62  size_t modelOffset = 0;
63  for (size_t i = 0; i < layer->Model().size(); ++i)
64  {
65  modelOffset += boost::apply_visitor(GradientUpdateVisitor(
66  gradient, modelOffset + offset), layer->Model()[i]);
67  }
68 
69  return modelOffset;
70 }
71 
72 template<typename T>
73 inline typename std::enable_if<
74  HasGradientCheck<T, arma::mat&(T::*)()>::value &&
75  HasModelCheck<T>::value, size_t>::type
76 GradientUpdateVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
77 {
78  if (layer->Parameters().n_elem != 0)
79  {
80  layer->Gradient() = gradient.submat(offset, 0,
81  offset + layer->Parameters().n_elem - 1, 0);;
82  }
83 
84  size_t modelOffset = layer->Parameters().n_elem;
85  for (size_t i = 0; i < layer->Model().size(); ++i)
86  {
87  modelOffset += boost::apply_visitor(GradientUpdateVisitor(
88  gradient, modelOffset + offset), layer->Model()[i]);
89  }
90 
91  return modelOffset;
92 }
93 
94 template<typename T, typename P>
95 inline typename std::enable_if<
96  !HasGradientCheck<T, P&(T::*)()>::value &&
97  !HasModelCheck<T>::value, size_t>::type
98 GradientUpdateVisitor::LayerGradients(T* /* layer */, P& /* input */) const
99 {
100  return 0;
101 }
102 
103 } // namespace ann
104 } // namespace mlpack
105 
106 #endif
GradientUpdateVisitor(arma::mat &gradient, size_t offset=0)
Update the gradient parameter given the gradient set.
Definition: gradient_update_visitor_impl.hpp:22
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t operator()(LayerType *layer) const
Update the gradient parameter.
Definition: gradient_update_visitor_impl.hpp:31