mlpack
parametric_relu_impl.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_ANN_LAYER_PReLU_IMPL_HPP
16 #define MLPACK_METHODS_ANN_LAYER_PReLU_IMPL_HPP
17 
18 // In case it hasn't yet been included.
19 #include "parametric_relu.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
24 template<typename InputDataType, typename OutputDataType>
26  const double userAlpha) : userAlpha(userAlpha)
27 {
28  alpha.set_size(WeightSize(), 1);
29  alpha(0) = userAlpha;
30 }
31 
32 template<typename InputDataType, typename OutputDataType>
34 {
36  alpha(0) = userAlpha;
37 }
38 
39 template<typename InputDataType, typename OutputDataType>
40 template<typename InputType, typename OutputType>
42  const InputType& input, OutputType& output)
43 {
44  output = input;
45  arma::uvec negative = arma::find(input < 0);
46  output(negative) = input(negative) * alpha(0);
47 }
48 
49 template<typename InputDataType, typename OutputDataType>
50 template<typename DataType>
52  const DataType& input, const DataType& gy, DataType& g)
53 {
54  DataType derivative;
55  derivative.set_size(arma::size(input));
56  for (size_t i = 0; i < input.n_elem; ++i)
57  {
58  derivative(i) = (input(i) >= 0) ? 1 : alpha(0);
59  }
60 
61  g = gy % derivative;
62 }
63 
64 template<typename InputDataType, typename OutputDataType>
65 template<typename eT>
67  const arma::Mat<eT>& input,
68  const arma::Mat<eT>& error,
69  arma::Mat<eT>& gradient)
70 {
71  if (gradient.n_elem == 0)
72  {
73  gradient = arma::zeros<arma::mat>(1, 1);
74  }
75 
76  arma::mat zeros = arma::zeros<arma::mat>(input.n_rows, input.n_cols);
77  gradient(0) = arma::accu(error % arma::min(zeros, input)) / input.n_cols;
78 }
79 
80 template<typename InputDataType, typename OutputDataType>
81 template<typename Archive>
83  Archive& ar,
84  const uint32_t /* version */)
85 {
86  ar(CEREAL_NVP(alpha));
87 }
88 
89 } // namespace ann
90 } // namespace mlpack
91 
92 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t WeightSize() const
Get size of weights.
Definition: parametric_relu.hpp:123
void Backward(const DataType &input, const DataType &gy, DataType &g)
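Ordinary feed backward pass of a neural network: computes the gradient g with respect to the layer input by multiplying the incoming error gy element-wise with the PReLU derivative (1 where the input is non-negative, alpha otherwise).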
Definition: parametric_relu_impl.hpp:51
OutputDataType const & Gradient() const
Get the gradient.
Definition: parametric_relu.hpp:113
PReLU(const double userAlpha=0.03)
Create the PReLU object using the specified parameters.
Definition: parametric_relu_impl.hpp:25
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating f(x) = x for x >= 0 and f(x) = alpha * x for x < 0 element-wise.
Definition: parametric_relu_impl.hpp:41
void Reset()
Reset the learnable slope alpha to the user-specified initial value userAlpha.
Definition: parametric_relu_impl.hpp:33
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: parametric_relu_impl.hpp:82