12 #ifndef MLPACK_METHODS_ANN_LAYER_CELU_IMPL_HPP 13 #define MLPACK_METHODS_ANN_LAYER_CELU_IMPL_HPP 21 template<
typename InputDataType,
typename OutputDataType>
28 Log::Fatal <<
"The value of alpha cannot be equal to 0, " 29 <<
"terminating the program." << std::endl;
33 template<
typename InputDataType,
typename OutputDataType>
34 template<
typename InputType,
typename OutputType>
36 const InputType& input, OutputType& output)
38 output = arma::ones<OutputDataType>(arma::size(input));
39 for (
size_t i = 0; i < input.n_elem; ++i)
41 output(i) = (input(i) >= 0) ? input(i) : alpha *
42 (std::exp(input(i) / alpha) - 1);
47 derivative.set_size(arma::size(input));
48 for (
size_t i = 0; i < input.n_elem; ++i)
50 derivative(i) = (input(i) >= 0) ? 1 :
51 (output(i) / alpha) + 1;
56 template<
typename InputDataType,
typename OutputDataType>
57 template<
typename DataType>
59 const DataType& ,
const DataType& gy, DataType& g)
64 template<
typename InputDataType,
typename OutputDataType>
65 template<
typename Archive>
70 ar(CEREAL_NVP(alpha));
CELU(const double alpha=1.0)
Create the CELU object using the specified parameter.
Definition: celu_impl.hpp:22
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
Definition: celu_impl.hpp:35
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: celu_impl.hpp:66
void Backward(const DataType &input, const DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
Definition: celu_impl.hpp:58