13 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_FUNCTION_HPP 14 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_FUNCTION_HPP 40 static double Fn(
const double x)
42 return 0.5 * x * (1 + std::tanh(std::sqrt(2 / M_PI) *
43 (x + 0.044715 * std::pow(x, 3))));
52 template<
typename InputVecType,
typename OutputVecType>
53 static void Fn(
const InputVecType& x, OutputVecType& y)
55 y = 0.5 * x % (1 + arma::tanh(std::sqrt(2 / M_PI) *
56 (x + 0.044715 * arma::pow(x, 3))));
/**
 * Computes the first derivative of the GELU function (tanh approximation)
 * at a single point.
 *
 * With u(y) = sqrt(2 / pi) * (y + 0.044715 * y^3), the derivative of
 * 0.5 * y * (1 + tanh(u)) is
 *
 *   0.5 + 0.5 * tanh(u)
 *       + 0.5 * sqrt(2 / pi) * y * (1 + 3 * 0.044715 * y^2) * sech^2(u).
 *
 * The factors are derived from sqrt(2 / pi) and 0.044715 directly instead of
 * from 6-digit rounded literals, so the result is the derivative of exactly
 * the function that Fn() evaluates.
 *
 * @param y Point at which the derivative is evaluated; the formula treats it
 *     as the argument of the GELU function.
 * @return Value of the derivative at y.
 */
static double Deriv(const double y)
{
  const double scale = std::sqrt(2 / M_PI);
  const double cubicCoeff = 0.044715;
  const double ySquared = y * y;

  // tanh argument u(y); computed once and reused for tanh and sech^2.
  const double inner = scale * y * (1 + cubicCoeff * ySquared);
  const double tanhInner = std::tanh(inner);

  // sech^2(u) = 1 - tanh^2(u); avoids a second transcendental call.
  const double sechSquared = 1 - tanhInner * tanhInner;

  return 0.5 + 0.5 * tanhInner +
      0.5 * scale * y * (1 + 3 * cubicCoeff * ySquared) * sechSquared;
}
79 template<
typename InputVecType,
typename OutputVecType>
80 static void Deriv(
const InputVecType& y, OutputVecType& x)
82 x = 0.5 * arma::tanh(0.0356774 * arma::pow(y, 3) + 0.797885 * y) +
83 (0.0535161 * arma::pow(y, 3) + 0.398942 * y) %
84 arma::pow(1 / arma::cosh(0.0356774 * arma::pow(y, 3) +
85 0.797885 * y), 2) + 0.5;
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static double Deriv(const double y)
Computes the first derivative of the GELU function.
Definition: gelu_function.hpp:65
The core includes that mlpack expects; standard C++ includes and Armadillo.
static double Fn(const double x)
Computes the GELU function.
Definition: gelu_function.hpp:40
static void Deriv(const InputVecType &y, OutputVecType &x)
Computes the first derivatives of the GELU function.
Definition: gelu_function.hpp:80
The GELU function, defined by $f(x) = 0.5\,x\,\bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\bigr)$.
Definition: gelu_function.hpp:31
static void Fn(const InputVecType &x, OutputVecType &y)
Computes the GELU function.
Definition: gelu_function.hpp:53