mlpack
hardshrink_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_HARDSHRINK_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_HARDSHRINK_IMPL_HPP
14 
15 // In case it hasn't yet been included
16 #include "hardshrink.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 // Constructor for the Hard Shrink activation layer; it only stores 'lambda'.
22 // 'lambda' is the hyperparameter giving the shrink threshold (default 0.5 per the class declaration).
// NOTE(review): the signature line `HardShrink(const double lambda) :` (file
// line 24) is missing from this listing; only the initializer list survives.
23 template<typename InputDataType, typename OutputDataType>
25  lambda(lambda)
26 {
27  // Nothing to do here.
28 }
29 
// Forward pass: element-wise hard shrink.
// Elements with input > lambda or input < -lambda are copied unchanged; all
// other elements (-lambda <= x <= lambda) become 0.  The two comparisons build
// a 0/1 mask that is Schur-multiplied (%) with the input.
// NOTE(review): this assumes lambda >= 0; for a negative lambda both
// comparisons can be true at once, giving a mask value of 2.
// NOTE(review): the signature line for Forward (file line 32) is missing from
// this listing.
30 template<typename InputDataType, typename OutputDataType>
31 template<typename InputType, typename OutputType>
33  const InputType& input, OutputType& output)
34 {
35  output = ((input > lambda) + (input < -lambda)) % input;
36 }
37 
// Backward pass: g = gy % derivative, where derivative is 1 wherever
// input != 0 and 0 wherever input == 0.
// NOTE(review): testing for zero yields the HardShrink derivative only if
// `input` is the forward output f(x), which is exactly 0 where |x| <= lambda.
// That is presumably mlpack's convention for Backward; confirm against the
// declaration in hardshrink.hpp.
// NOTE(review): the signature line for Backward (file line 40) is missing from
// this listing.
38 template<typename InputDataType, typename OutputDataType>
39 template<typename DataType>
41  const DataType& input, DataType& gy, DataType& g)
42 {
43  DataType derivative;
// 1 for nonzero entries, 0 for entries equal to zero.
44  derivative = (arma::ones(arma::size(input)) - (input == 0));
45  g = gy % derivative;
46 }
47 
// Serialize the layer: only 'lambda' is archived; the version argument is
// ignored.
// NOTE(review): the signature line for serialize (file line 50) is missing
// from this listing.
48 template<typename InputDataType, typename OutputDataType>
49 template<typename Archive>
51  Archive& ar,
52  const uint32_t /* version */)
53 {
54  ar(CEREAL_NVP(lambda));
55 }
56 
57 } // namespace ann
58 } // namespace mlpack
59 
60 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: hardshrink_impl.hpp:50
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
HardShrink(const double lambda=0.5)
Create HardShrink object using specified hyperparameter lambda.
Definition: hardshrink_impl.hpp:24
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: hardshrink_impl.hpp:32
void Backward(const DataType &input, DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: hardshrink_impl.hpp:40