elu_impl.hpp
#ifndef MLPACK_METHODS_ANN_LAYER_ELU_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_ELU_IMPL_HPP

// In case it hasn't yet been included.
#include "elu.hpp"

namespace mlpack {
namespace ann {

// This constructor is called for the SELU activation function. The values of
// alpha and lambda are constant for normalized inputs.
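// These fixed values are the self-normalizing parameters derived in
// Klambauer et al. (2017), "Self-Normalizing Neural Networks".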
template<typename InputDataType, typename OutputDataType>
ELU<InputDataType, OutputDataType>::ELU() :
    alpha(1.6732632423543774),
    lambda(1.0507009873554802),
    deterministic(false)
{
  // Nothing to do here.
}

// This constructor is called for the ELU activation function. The value of
// lambda is fixed and equal to 1. 'alpha' is a hyperparameter.
template<typename InputDataType, typename OutputDataType>
ELU<InputDataType, OutputDataType>::ELU(const double alpha) :
    alpha(alpha),
    lambda(1),
    deterministic(false)
{
  // Nothing to do here.
}

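// Forward pass of the (S)ELU activation:
//   f(x) = lambda * x                   for x > 0
//   f(x) = lambda * alpha * (e^x - 1)   for x <= 0
// When the layer is not in deterministic mode, the derivative f'(x) is also
// cached in `derivative` so that Backward() can reuse it.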
template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename OutputType>
void ELU<InputDataType, OutputDataType>::Forward(
    const InputType& input, OutputType& output)
{
  output = arma::ones<OutputDataType>(arma::size(input));
  for (size_t i = 0; i < input.n_elem; ++i)
  {
    if (input(i) < DBL_MAX)
    {
      output(i) = (input(i) > 0) ? lambda * input(i) : lambda *
          alpha * (std::exp(input(i)) - 1);
    }
  }

  if (!deterministic)
  {
    derivative.set_size(arma::size(input));
    for (size_t i = 0; i < input.n_elem; ++i)
    {
      derivative(i) = (input(i) > 0) ? lambda : output(i) +
          lambda * alpha;
    }
  }
}

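// Backward pass: apply the chain rule by multiplying the upstream gradient gy
// element-wise (Armadillo's % operator) with the derivative cached by
// Forward().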
template<typename InputDataType, typename OutputDataType>
template<typename DataType>
void ELU<InputDataType, OutputDataType>::Backward(
    const DataType& /* input */, const DataType& gy, DataType& g)
{
  g = gy % derivative;
}

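// Only the hyperparameters alpha and lambda are serialized; the cached
// derivative is transient state that Forward() rebuilds whenever the layer is
// not in deterministic mode.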
template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void ELU<InputDataType, OutputDataType>::serialize(
    Archive& ar,
    const uint32_t /* version */)
{
  ar(CEREAL_NVP(alpha));
  ar(CEREAL_NVP(lambda));
}

} // namespace ann
} // namespace mlpack

#endif
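A minimal usage sketch (not part of the header above): it calls Forward() and
Backward() on the layer directly, using the default arma::mat template
arguments. The Deterministic() accessor and the convention of passing the
layer's output as the (unused) first argument of Backward() are assumptions
based on the wider mlpack ann module, not guaranteed by this file alone.

#include <mlpack/methods/ann/layer/elu.hpp>

int main()
{
  arma::mat input = arma::randn<arma::mat>(5, 1);
  arma::mat output, delta;

  // ELU with alpha = 1.0; the default constructor would yield SELU instead.
  mlpack::ann::ELU<> elu(1.0);
  elu.Deterministic() = false;  // Cache the derivative for Backward().

  elu.Forward(input, output);

  // Propagate a gradient of ones backwards: delta = gy % f'(input).
  arma::mat gy = arma::ones<arma::mat>(arma::size(output));
  elu.Backward(output, gy, delta);

  delta.print("delta");

  return 0;
}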