17 #ifndef MLPACK_METHODS_ANN_LAYER_FLEXIBLERELU_IMPL_HPP 18 #define MLPACK_METHODS_ANN_LAYER_FLEXIBLERELU_IMPL_HPP 26 template<
typename InputDataType,
typename OutputDataType>
28 const double alpha) : userAlpha(alpha)
30 this->alpha.set_size(1, 1);
31 this->alpha(0) = userAlpha;
34 template<
typename InputDataType,
typename OutputDataType>
41 template<
typename InputDataType,
typename OutputDataType>
42 template<
typename InputType,
typename OutputType>
44 const InputType& input, OutputType& output)
46 output = arma::clamp(input, 0.0, DBL_MAX) + alpha(0);
49 template<
typename InputDataType,
typename OutputDataType>
50 template<
typename DataType>
52 const DataType& input,
const DataType& gy, DataType& g)
55 g = gy % arma::clamp(arma::sign(input), 0.0, 1.0);
58 template<
typename InputDataType,
typename OutputDataType>
61 const arma::Mat<eT>& input,
62 const arma::Mat<eT>& error,
63 arma::Mat<eT>& gradient)
65 if (gradient.n_elem == 0)
67 gradient.set_size(1, 1);
70 gradient(0) = arma::accu(error) / input.n_cols;
74 template<
typename InputDataType,
typename OutputDataType>
75 template<
typename Archive>
80 ar(CEREAL_NVP(alpha));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Reset()
Reset the layer parameter.
Definition: flexible_relu_impl.hpp:35
void Forward(const InputType &input, OutputType &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: flexible_relu_impl.hpp:43
FlexibleReLU(const double alpha=0)
Create the FlexibleReLU object, initializing the trainable offset alpha to the given value (default 0).
Definition: flexible_relu_impl.hpp:27
OutputDataType const & Gradient() const
Get the gradient.
Definition: flexible_relu.hpp:129
void Backward(const DataType &input, const DataType &gy, DataType &g)
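Ordinary backward pass of a neural network, computing the gradient with respect to the input by propagating the error gy backwards through f (g = gy multiplied element-wise by the ReLU step mask of the input).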
Definition: flexible_relu_impl.hpp:51
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: flexible_relu_impl.hpp:76