// NOTE(review): this listing was extracted from a rendered documentation page.
// The leading numbers are original source line numbers, and several source
// lines (function-name lines, braces, parts of bodies) are missing from this
// view. The first line below also carries the include guard (#ifndef/#define).
//
// PReLU constructor fragment: initialises the member `userAlpha` from the
// constructor argument of the same name (declared as
// PReLU(const double userAlpha=0.03) in the definition index of this page).
15 #ifndef MLPACK_METHODS_ANN_LAYER_PReLU_IMPL_HPP 16 #define MLPACK_METHODS_ANN_LAYER_PReLU_IMPL_HPP 24 template<
typename InputDataType,
typename OutputDataType>
26 const double userAlpha) : userAlpha(userAlpha)
// Reset() fragment: only the class template header survives here. The
// signature and body (starting at source line 33, per the
// "parametric_relu_impl.hpp:33" definition entry) are not visible in this
// extraction.
32 template<
typename InputDataType,
typename OutputDataType>
// Forward(input, output) fragment: PReLU forward pass.
// Visible logic: collect the indices of negative input entries and write
// input * alpha(0) into those positions of `output`. The line that fills the
// non-negative entries of `output` is not visible here (presumably
// `output = input;` -- TODO confirm against the full source).
39 template<
typename InputDataType,
typename OutputDataType>
40 template<
typename InputType,
typename OutputType>
42 const InputType& input, OutputType& output)
// Indices of strictly negative inputs; zero is treated as non-negative.
45 arma::uvec negative = arma::find(input < 0);
// Scale only the negative entries by the single coefficient alpha(0).
46 output(negative) = input(negative) * alpha(0);
// Backward(input, gy, g) fragment: builds the element-wise PReLU derivative
// with respect to the input -- 1 where input >= 0, alpha(0) otherwise.
// Not visible in this extraction: the declaration of `derivative` and the
// statement that combines it with `gy` to produce `g`.
49 template<
typename InputDataType,
typename OutputDataType>
50 template<
typename DataType>
52 const DataType& input,
const DataType& gy, DataType& g)
55 derivative.set_size(arma::size(input));
// Element-wise slope. input == 0 takes slope 1, which matches Forward,
// where only entries with input < 0 are rescaled by alpha(0).
56 for (
size_t i = 0; i < input.n_elem; ++i)
58 derivative(i) = (input(i) >= 0) ? 1 : alpha(0);
// Gradient(input, error, gradient) fragment: computes a scalar gradient and
// stores it in gradient(0). This is presumably the gradient w.r.t. alpha:
// per Forward, d(output)/d(alpha) equals the input on negative entries and
// 0 elsewhere, which is what min(0, input) yields. Visible logic:
//   - if `gradient` is empty, allocate it as a 1x1 zero matrix;
//   - gradient(0) = accu(error % min(0, input)) / input.n_cols,
// so only negative inputs contribute, and the sum is divided by the number
// of columns (presumably the batch size -- TODO confirm).
// The inner `template<typename eT>` line and the function-name line are not
// visible in this extraction.
64 template<
typename InputDataType,
typename OutputDataType>
67 const arma::Mat<eT>& input,
68 const arma::Mat<eT>& error,
69 arma::Mat<eT>& gradient)
71 if (gradient.n_elem == 0)
73 gradient = arma::zeros<arma::mat>(1, 1);
// NOTE(review): `zeros` (and the 1x1 buffer above) are arma::mat, i.e.
// double, while the arguments are arma::Mat<eT>. For eT other than double,
// the two operands of arma::min have different element types -- confirm this
// compiles / is intended for float instantiations (arma::Mat<eT> would match).
76 arma::mat zeros = arma::zeros<arma::mat>(input.n_rows, input.n_cols);
77 gradient(0) = arma::accu(error % arma::min(zeros, input)) / input.n_cols;
// serialize(ar, version) fragment: the visible body archives the `alpha`
// member through the CEREAL_NVP name-value-pair macro. The signature and
// braces (source lines 82-85) are not visible in this extraction.
80 template<
typename InputDataType,
typename OutputDataType>
81 template<
typename Archive>
86 ar(CEREAL_NVP(alpha));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t WeightSize() const
Get size of weights.
Definition: parametric_relu.hpp:123
void Backward(const DataType &input, const DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: parametric_relu_impl.hpp:51
OutputDataType const & Gradient() const
Get the gradient.
Definition: parametric_relu.hpp:113
PReLU(const double userAlpha=0.03)
Create the PReLU object using the specified parameters.
Definition: parametric_relu_impl.hpp:25
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: parametric_relu_impl.hpp:41
void Reset()
Definition: parametric_relu_impl.hpp:33
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: parametric_relu_impl.hpp:82