mlpack
celu_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_CELU_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_CELU_IMPL_HPP
14 
15 // In case it hasn't yet been included.
16 #include "celu.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 template<typename InputDataType, typename OutputDataType>
23  alpha(alpha),
24  deterministic(false)
25 {
26  if (alpha == 0)
27  {
28  Log::Fatal << "The value of alpha cannot be equal to 0, "
29  << "terminating the program." << std::endl;
30  }
31 }
32 
33 template<typename InputDataType, typename OutputDataType>
34 template<typename InputType, typename OutputType>
36  const InputType& input, OutputType& output)
37 {
38  output = arma::ones<OutputDataType>(arma::size(input));
39  for (size_t i = 0; i < input.n_elem; ++i)
40  {
41  output(i) = (input(i) >= 0) ? input(i) : alpha *
42  (std::exp(input(i) / alpha) - 1);
43  }
44 
45  if (!deterministic)
46  {
47  derivative.set_size(arma::size(input));
48  for (size_t i = 0; i < input.n_elem; ++i)
49  {
50  derivative(i) = (input(i) >= 0) ? 1 :
51  (output(i) / alpha) + 1;
52  }
53  }
54 }
55 
56 template<typename InputDataType, typename OutputDataType>
57 template<typename DataType>
59  const DataType& /* input */, const DataType& gy, DataType& g)
60 {
61  g = gy % derivative;
62 }
63 
64 template<typename InputDataType, typename OutputDataType>
65 template<typename Archive>
67  Archive& ar,
68  const uint32_t /* version */)
69 {
70  ar(CEREAL_NVP(alpha));
71 }
72 
73 } // namespace ann
74 } // namespace mlpack
75 
76 #endif
CELU(const double alpha=1.0)
Create the CELU object using the specified parameter.
Definition: celu_impl.hpp:22
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: celu_impl.hpp:35
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: celu_impl.hpp:66
void Backward(const DataType &input, const DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: celu_impl.hpp:58