mlpack
gradient_zero_visitor_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_ZERO_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_ZERO_VISITOR_IMPL_HPP
14 
// In case it hasn't been included yet.
#include "gradient_zero_visitor.hpp"

18 namespace mlpack {
19 namespace ann {
20 
23 {
24  /* Nothing to do here. */
25 }
26 
27 template<typename LayerType>
28 inline void GradientZeroVisitor::operator()(LayerType* layer) const
29 {
30  LayerGradients(layer, layer->OutputParameter());
31 }
32 
33 inline void GradientZeroVisitor::operator()(MoreTypes layer) const
34 {
35  layer.apply_visitor(*this);
36 }
37 
38 template<typename T>
39 inline typename std::enable_if<
40  HasGradientCheck<T, arma::mat&(T::*)()>::value, void>::type
41 GradientZeroVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
42 {
43  layer->Gradient().zeros();
44 }
45 
46 template<typename T, typename P>
47 inline typename std::enable_if<
48  !HasGradientCheck<T, P&(T::*)()>::value, void>::type
49 GradientZeroVisitor::LayerGradients(T* /* layer */, P& /* input */) const
50 {
51  /* Nothing to do here. */
52 }
53 
54 } // namespace ann
55 } // namespace mlpack
56 
57 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void operator()(LayerType *layer) const
Set the gradient to zero.
Definition: gradient_zero_visitor_impl.hpp:28
GradientZeroVisitor()
Construct the visitor (the constructor itself does nothing); applying the visitor sets the gradient to zero for the given module.
Definition: gradient_zero_visitor_impl.hpp:22