mlpack
loss_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_LOSS_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_LOSS_VISITOR_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "loss_visitor.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
22 template<typename LayerType>
23 inline double LossVisitor::operator()(LayerType* layer) const
24 {
25  return LayerLoss(layer);
26 }
27 
28 inline double LossVisitor::operator()(MoreTypes layer) const
29 {
30  return layer.apply_visitor(*this);
31 }
32 
33 template<typename T>
34 inline typename std::enable_if<
35  !HasLoss<T, double(T::*)()>::value &&
36  !HasModelCheck<T>::value, double>::type
37 LossVisitor::LayerLoss(T* /* layer */) const
38 {
39  return 0;
40 }
41 
42 template<typename T>
43 inline typename std::enable_if<
44  HasLoss<T, double(T::*)()>::value &&
45  !HasModelCheck<T>::value, double>::type
46 LossVisitor::LayerLoss(T* layer) const
47 {
48  return layer->Loss();
49 }
50 
51 template<typename T>
52 inline typename std::enable_if<
53  !HasLoss<T, double(T::*)()>::value &&
54  HasModelCheck<T>::value, double>::type
55 LossVisitor::LayerLoss(T* layer) const
56 {
57  for (size_t i = 0; i < layer->Model().size(); ++i)
58  {
59  double loss = boost::apply_visitor(LossVisitor(),
60  layer->Model()[layer->Model().size() - 1 - i]);
61 
62  if (loss != 0)
63  {
64  return loss;
65  }
66  }
67 
68  return 0;
69 }
70 
71 template<typename T>
72 inline typename std::enable_if<
73  HasLoss<T, double(T::*)()>::value &&
74  HasModelCheck<T>::value, double>::type
75 LossVisitor::LayerLoss(T* layer) const
76 {
77  double loss = layer->Loss();
78 
79  if (loss == 0)
80  {
81  for (size_t i = 0; i < layer->Model().size(); ++i)
82  {
83  loss = boost::apply_visitor(LossVisitor(),
84  layer->Model()[layer->Model().size() - 1 - i]);
85 
86  if (loss != 0)
87  {
88  return loss;
89  }
90  }
91  }
92 
93  return loss;
94 }
95 
96 } // namespace ann
97 } // namespace mlpack
98 
99 #endif
double operator()(LayerType *layer) const
Return the Loss.
Definition: loss_visitor_impl.hpp:23
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
LossVisitor exposes the Loss() method of the given module.
Definition: loss_visitor.hpp:26