mlpack
softshrink_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_SOFTSHRINK_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_SOFTSHRINK_IMPL_HPP
14 
15 // In case it hasn't yet been included
16 #include "softshrink.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 // This constructor is called for Soft Shrink activation function.
22 // lambda is a hyperparameter.
23 template<typename InputDataType, typename OutputDataType>
25  lambda(lambda)
26 {
27  // Nothing to do here.
28 }
29 
30 template<typename InputDataType, typename OutputDataType>
31 template<typename InputType, typename OutputType>
33  const InputType& input, OutputType& output)
34 {
35  output = (input > lambda) % (input - lambda) + (
36  input < -lambda) % (input + lambda);
37 }
38 
39 template<typename InputDataType, typename OutputDataType>
40 template<typename DataType>
42  const DataType& input, DataType& gy, DataType& g)
43 {
44  DataType derivative;
45  derivative = (arma::ones(arma::size(input)) - (input == 0));
46  g = gy % derivative;
47 }
48 
49 template<typename InputDataType, typename OutputDataType>
50 template<typename Archive>
52  Archive& ar,
53  const uint32_t /* version */)
54 {
55  ar(CEREAL_NVP(lambda));
56 }
57 
58 } // namespace ann
59 } // namespace mlpack
60 
61 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: softshrink_impl.hpp:32
SoftShrink(const double lambda=0.5)
Create Soft Shrink object using specified hyperparameter lambda.
Definition: softshrink_impl.hpp:24
void Backward(const DataType &input, DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: softshrink_impl.hpp:41
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: softshrink_impl.hpp:51