mlpack
weight_set_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_WEIGHT_SET_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_WEIGHT_SET_VISITOR_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "weight_set_visitor.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
22 inline WeightSetVisitor::WeightSetVisitor(arma::mat& weight,
23  const size_t offset) :
24  weight(weight),
25  offset(offset)
26 {
27  /* Nothing to do here. */
28 }
29 
30 template<typename LayerType>
31 inline size_t WeightSetVisitor::operator()(LayerType* layer) const
32 {
33  return LayerSize(layer, layer->OutputParameter());
34 }
35 
36 inline size_t WeightSetVisitor::operator()(MoreTypes layer) const
37 {
38  return layer.apply_visitor(*this);
39 }
40 
41 template<typename T, typename P>
42 inline typename std::enable_if<
43  !HasParametersCheck<T, P&(T::*)()>::value &&
44  !HasModelCheck<T>::value, size_t>::type
45 WeightSetVisitor::LayerSize(T* /* layer */, P&& /*output */) const
46 {
47  return 0;
48 }
49 
50 template<typename T, typename P>
51 inline typename std::enable_if<
52  !HasParametersCheck<T, P&(T::*)()>::value &&
53  HasModelCheck<T>::value, size_t>::type
54 WeightSetVisitor::LayerSize(T* layer, P&& /*output */) const
55 {
56  size_t modelOffset = 0;
57  for (size_t i = 0; i < layer->Model().size(); ++i)
58  {
59  modelOffset += boost::apply_visitor(WeightSetVisitor(
60  weight, modelOffset + offset), layer->Model()[i]);
61  }
62 
63  return modelOffset;
64 }
65 
66 template<typename T, typename P>
67 inline typename std::enable_if<
68  HasParametersCheck<T, P&(T::*)()>::value &&
69  !HasModelCheck<T>::value, size_t>::type
70 WeightSetVisitor::LayerSize(T* layer, P&& /* output */) const
71 {
72  layer->Parameters() = arma::mat(weight.memptr() + offset,
73  layer->Parameters().n_rows, layer->Parameters().n_cols, false, false);
74 
75  return layer->Parameters().n_elem;
76 }
77 
78 template<typename T, typename P>
79 inline typename std::enable_if<
80  HasParametersCheck<T, P&(T::*)()>::value &&
81  HasModelCheck<T>::value, size_t>::type
82 WeightSetVisitor::LayerSize(T* layer, P&& /* output */) const
83 {
84  layer->Parameters() = arma::mat(weight.memptr() + offset,
85  layer->Parameters().n_rows, layer->Parameters().n_cols, false, false);
86 
87  size_t modelOffset = layer->Parameters().n_elem;
88  for (size_t i = 0; i < layer->Model().size(); ++i)
89  {
90  modelOffset += boost::apply_visitor(WeightSetVisitor(
91  weight, modelOffset + offset), layer->Model()[i]);
92  }
93 
94  return modelOffset;
95 }
96 
97 } // namespace ann
98 } // namespace mlpack
99 
100 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
WeightSetVisitor(arma::mat &weight, const size_t offset=0)
Update the parameters given the parameters set and offset.
Definition: weight_set_visitor_impl.hpp:22
size_t operator()(LayerType *layer) const
Update the parameters set.
Definition: weight_set_visitor_impl.hpp:31