mlpack
weight_size_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_WEIGHT_SIZE_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_WEIGHT_SIZE_VISITOR_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "weight_size_visitor.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
22 template<typename LayerType>
23 inline size_t WeightSizeVisitor::operator()(LayerType* layer) const
24 {
25  return LayerSize(layer, layer->OutputParameter());
26 }
27 
28 inline size_t WeightSizeVisitor::operator()(MoreTypes layer) const
29 {
30  return layer.apply_visitor(*this);
31 }
32 
33 template<typename T, typename P>
34 inline typename std::enable_if<
35  !HasParametersCheck<T, P&(T::*)()>::value &&
36  !HasModelCheck<T>::value, size_t>::type
37 WeightSizeVisitor::LayerSize(T* /* layer */, P& /* output */) const
38 {
39  return 0;
40 }
41 
42 template<typename T, typename P>
43 inline typename std::enable_if<
44  !HasParametersCheck<T, P&(T::*)()>::value &&
45  HasModelCheck<T>::value, size_t>::type
46 WeightSizeVisitor::LayerSize(T* layer, P& /* output */) const
47 {
48  size_t weights = 0;
49  for (size_t i = 0; i < layer->Model().size(); ++i)
50  {
51  weights += boost::apply_visitor(WeightSizeVisitor(), layer->Model()[i]);
52  }
53 
54  return weights;
55 }
56 
57 template<typename T, typename P>
58 inline typename std::enable_if<
59  HasParametersCheck<T, P&(T::*)()>::value &&
60  !HasModelCheck<T>::value, size_t>::type
61 WeightSizeVisitor::LayerSize(T* layer, P& /* output */) const
62 {
63  return layer->Parameters().n_elem;
64 }
65 
66 template<typename T, typename P>
67 inline typename std::enable_if<
68  HasParametersCheck<T, P&(T::*)()>::value &&
69  HasModelCheck<T>::value, size_t>::type
70 WeightSizeVisitor::LayerSize(T* layer, P& /* output */) const
71 {
72  size_t weights = layer->Parameters().n_elem;
73  for (size_t i = 0; i < layer->Model().size(); ++i)
74  {
75  weights += boost::apply_visitor(WeightSizeVisitor(), layer->Model()[i]);
76  }
77 
78  return weights;
79 }
80 
81 } // namespace ann
82 } // namespace mlpack
83 
84 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
WeightSizeVisitor returns the number of weights of the given module.
Definition: weight_size_visitor.hpp:27
size_t operator()(LayerType *layer) const
Return the number of weights.
Definition: weight_size_visitor_impl.hpp:23