mlpack
set_input_height_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_SET_INPUT_HEIGHT_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_SET_INPUT_HEIGHT_VISITOR_IMPL_HPP
14 
15 // In case it hasn't been included yet.
17 
18 namespace mlpack {
19 namespace ann {
20 
22 inline SetInputHeightVisitor::SetInputHeightVisitor(const size_t inputHeight,
23  const bool reset) :
24  inputHeight(inputHeight),
25  reset(reset)
26 {
27  /* Nothing to do here. */
28 }
29 
30 template<typename LayerType>
31 inline bool SetInputHeightVisitor::operator()(LayerType* layer) const
32 {
33  return LayerInputHeight(layer);
34 }
35 
36 inline bool SetInputHeightVisitor::operator()(MoreTypes layer) const
37 {
38  return layer.apply_visitor(*this);
39 }
40 
41 template<typename T>
42 inline typename std::enable_if<
43  !HasInputHeight<T, size_t&(T::*)()>::value &&
44  !HasModelCheck<T>::value, bool>::type
45 SetInputHeightVisitor::LayerInputHeight(T* /* layer */) const
46 {
47  return false;
48 }
49 
50 template<typename T>
51 inline typename std::enable_if<
52  HasInputHeight<T, size_t&(T::*)()>::value &&
53  !HasModelCheck<T>::value, bool>::type
54 SetInputHeightVisitor::LayerInputHeight(T* layer) const
55 {
56  if (layer->InputHeight() == 0 || reset)
57  {
58  layer->InputHeight() = inputHeight;
59  }
60 
61  return true;
62 }
63 
64 template<typename T>
65 inline typename std::enable_if<
66  !HasInputHeight<T, size_t&(T::*)()>::value &&
67  HasModelCheck<T>::value, bool>::type
68 SetInputHeightVisitor::LayerInputHeight(T* layer) const
69 {
70  for (size_t i = 0; i < layer->Model().size(); ++i)
71  {
72  boost::apply_visitor(SetInputHeightVisitor(inputHeight, reset),
73  layer->Model()[i]);
74  }
75 
76  return true;
77 }
78 
79 template<typename T>
80 inline typename std::enable_if<
81  HasInputHeight<T, size_t&(T::*)()>::value &&
82  HasModelCheck<T>::value, bool>::type
83 SetInputHeightVisitor::LayerInputHeight(T* layer) const
84 {
85  if (layer->InputHeight() == 0 || reset)
86  {
87  layer->InputHeight() = inputHeight;
88  }
89 
90  for (size_t i = 0; i < layer->Model().size(); ++i)
91  {
92  boost::apply_visitor(SetInputHeightVisitor(inputHeight, reset),
93  layer->Model()[i]);
94  }
95 
96  return true;
97 }
98 
99 } // namespace ann
100 } // namespace mlpack
101 
102 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
bool operator()(LayerType *layer) const
Update the input height parameter.
Definition: set_input_height_visitor_impl.hpp:31
SetInputHeightVisitor(const size_t inputHeight=0, const bool reset=false)
Update the input height parameter with the given input height.
Definition: set_input_height_visitor_impl.hpp:22