mlpack
gradient_set_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_SET_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_SET_VISITOR_IMPL_HPP
14 
// In case it hasn't been included yet.
#include "gradient_set_visitor.hpp"

#include <stdexcept>
17 
18 namespace mlpack {
19 namespace ann {
20 
22 inline GradientSetVisitor::GradientSetVisitor(arma::mat& gradient,
23  size_t offset) :
24  gradient(gradient),
25  offset(offset)
26 {
27  /* Nothing to do here. */
28 }
29 
30 template<typename LayerType>
31 inline size_t GradientSetVisitor::operator()(LayerType* layer) const
32 {
33  return LayerGradients(layer, layer->OutputParameter());
34 }
35 
36 inline size_t GradientSetVisitor::operator()(MoreTypes layer) const
37 {
38  return layer.apply_visitor(*this);
39 }
40 
41 template<typename T>
42 inline typename std::enable_if<
43  HasGradientCheck<T, arma::mat&(T::*)()>::value &&
44  !HasModelCheck<T>::value, size_t>::type
45 GradientSetVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
46 {
47  layer->Gradient() = arma::mat(gradient.memptr() + offset,
48  layer->Parameters().n_rows, layer->Parameters().n_cols, false, false);
49 
50  return layer->Parameters().n_elem;
51 }
52 
53 template<typename T>
54 inline typename std::enable_if<
55  !HasGradientCheck<T, arma::mat&(T::*)()>::value &&
56  HasModelCheck<T>::value, size_t>::type
57 GradientSetVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
58 {
59  size_t modelOffset = 0;
60  for (size_t i = 0; i < layer->Model().size(); ++i)
61  {
62  modelOffset += boost::apply_visitor(GradientSetVisitor(
63  gradient, modelOffset + offset), layer->Model()[i]);
64  }
65 
66  return modelOffset;
67 }
68 
69 template<typename T>
70 inline typename std::enable_if<
71  HasGradientCheck<T, arma::mat&(T::*)()>::value &&
72  HasModelCheck<T>::value, size_t>::type
73 GradientSetVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
74 {
75  layer->Gradient() = arma::mat(gradient.memptr() + offset,
76  layer->Parameters().n_rows, layer->Parameters().n_cols, false, false);
77 
78  size_t modelOffset = layer->Parameters().n_elem;
79  for (size_t i = 0; i < layer->Model().size(); ++i)
80  {
81  modelOffset += boost::apply_visitor(GradientSetVisitor(
82  gradient, modelOffset + offset), layer->Model()[i]);
83  }
84 
85  return modelOffset;
86 }
87 
88 template<typename T, typename P>
89 inline typename std::enable_if<
90  !HasGradientCheck<T, P&(T::*)()>::value &&
91  !HasModelCheck<T>::value, size_t>::type
92 GradientSetVisitor::LayerGradients(T* /* layer */, P& /* input */) const
93 {
94  return 0;
95 }
96 
97 } // namespace ann
98 } // namespace mlpack
99 
100 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
GradientSetVisitor(arma::mat &gradient, size_t offset=0)
Construct the visitor from the gradient matrix whose memory backs the layer gradients and the element offset at which the first layer's gradient starts.
Definition: gradient_set_visitor_impl.hpp:22
size_t operator()(LayerType *layer) const
Update the gradient parameter.
Definition: gradient_set_visitor_impl.hpp:31