mlpack
bias_set_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_BIAS_SET_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_BIAS_SET_VISITOR_IMPL_HPP
14 
// In case it hasn't been included yet.
#include "bias_set_visitor.hpp"

#include <stdexcept>
17 
18 namespace mlpack {
19 namespace ann {
20 
22 inline BiasSetVisitor::BiasSetVisitor(arma::mat& weight, const size_t offset) :
23  weight(weight),
24  offset(offset)
25 {
26  /* Nothing to do here. */
27 }
28 
29 template<typename LayerType>
30 inline size_t BiasSetVisitor::operator()(LayerType* layer) const
31 {
32  return LayerSize(layer);
33 }
34 
35 inline size_t BiasSetVisitor::operator()(MoreTypes layer) const
36 {
37  return layer.apply_visitor(*this);
38 }
39 
40 template<typename T>
41 inline typename std::enable_if<
42  !HasBiasCheck<T, arma::mat&(T::*)()>::value &&
43  !HasModelCheck<T>::value, size_t>::type
44 BiasSetVisitor::LayerSize(T* /* layer */) const
45 {
46  return 0;
47 }
48 
49 template<typename T>
50 inline typename std::enable_if<
51  !HasBiasCheck<T, arma::mat&(T::*)()>::value &&
52  HasModelCheck<T>::value, size_t>::type
53 BiasSetVisitor::LayerSize(T* layer) const
54 {
55  size_t modelOffset = 0;
56 
57  for (size_t i = 0; i < layer->Model().size(); ++i)
58  {
59  modelOffset += boost::apply_visitor(BiasSetVisitor(
60  weight, modelOffset + offset), layer->Model()[i]);
61  }
62 
63  return modelOffset;
64 }
65 
66 template<typename T>
67 inline typename std::enable_if<
68  HasBiasCheck<T, arma::mat&(T::*)()>::value &&
69  !HasModelCheck<T>::value, size_t>::type
70 BiasSetVisitor::LayerSize(T* layer) const
71 {
72  layer->Bias() = arma::mat(weight.memptr() + offset,
73  layer->Bias().n_rows, layer->Bias().n_cols, false, false);
74 
75  return layer->Bias().n_elem;
76 }
77 
78 template<typename T>
79 inline typename std::enable_if<
80  HasBiasCheck<T, arma::mat&(T::*)()>::value &&
81  HasModelCheck<T>::value, size_t>::type
82 BiasSetVisitor::LayerSize(T* layer) const
83 {
84  layer->Bias() = arma::mat(weight.memptr() + offset,
85  layer->Bias().n_rows, layer->Bias().n_cols, false, false);
86 
87  size_t modelOffset = layer->Bias().n_elem;
88 
89  for (size_t i = 0; i < layer->Model().size(); ++i)
90  {
91  modelOffset += boost::apply_visitor(BiasSetVisitor(
92  weight, modelOffset + offset), layer->Model()[i]);
93  }
94 
95  return modelOffset;
96 }
97 
98 } // namespace ann
99 } // namespace mlpack
100 
101 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
BiasSetVisitor(arma::mat &weight, const size_t offset=0)
Update the bias parameters given the parameters' set and offset.
Definition: bias_set_visitor_impl.hpp:22
size_t operator()(LayerType *layer) const
Update the parameters' set.
Definition: bias_set_visitor_impl.hpp:30