mlpack
relu6_impl.hpp
Go to the documentation of this file.
1 
23 #ifndef MLPACK_METHODS_ANN_LAYER_RELU6_IMPL_HPP
24 #define MLPACK_METHODS_ANN_LAYER_RELU6_IMPL_HPP
25 
26 // In case it hasn't yet been included.
27 #include "relu6.hpp"
28 
29 namespace mlpack {
30 namespace ann {
31 
32 template<typename InputDataType, typename OutputDataType>
34 {
35  // Nothing to do here.
36 }
37 
38 template<typename InputDataType, typename OutputDataType>
39 template<typename InputType, typename OutputType>
41  const InputType& input, OutputType& output)
42 {
43  OutputType outputTemp(arma::size(input));
44  outputTemp.fill(6.0);
45  output = arma::zeros<OutputType>(arma::size(input));
46  output = arma::min(arma::max(output, input), outputTemp);
47 }
48 
49 template<typename InputDataType, typename OutputDataType>
50 template<typename DataType>
52  const DataType& input, const DataType& gy, DataType& g)
53 {
54  DataType derivative(arma::size(gy));
55  derivative.fill(0.0);
56  for (size_t i = 0; i < input.n_elem; ++i)
57  {
58  if (input(i) < 6 && input(i) > 0)
59  derivative(i) = 1.0;
60  }
61 
62  g = gy % derivative;
63 }
64 
65 template<typename InputDataType, typename OutputDataType>
66 template<typename Archive>
68  Archive& ar,
69  const uint32_t /* version */)
70 {
71  // Nothing to do here.
72 }
73 
74 } // namespace ann
75 } // namespace mlpack
76 
77 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
ReLU6()
Create the ReLU6 object.
Definition: relu6_impl.hpp:33
void Backward(const DataType &input, const DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: relu6_impl.hpp:51
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: relu6_impl.hpp:67
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: relu6_impl.hpp:40