mlpack
flatten_t_swish_impl.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_ANN_LAYER_FLATTEN_T_SWISH_IMPL_HPP
15 #define MLPACK_METHODS_ANN_LAYER_FLATTEN_T_SWISH_IMPL_HPP
16 
17 // In case it hasn't yet been included.
18 #include "flatten_t_swish.hpp"
21 
22 namespace mlpack {
23 namespace ann {
24 
25 template<typename InputDataType, typename OutputDataType>
27  const double T) : t(T)
28 {
29  // Nothing to do here.
30 }
31 
32 template<typename InputDataType, typename OutputDataType>
33 template<typename InputType, typename OutputType>
35  const InputType& input, OutputType& output)
36 {
37  // Placeholder for Relu values.
38  OutputDataType relu;
39  RectifierFunction::Fn(input, relu);
40  LogisticFunction::Fn(input, output);
41  // F(x) = relu * sigmoid + t.
42  output = relu % output + t;
43 }
44 
45 template<typename InputDataType, typename OutputDataType>
46 template<typename DataType>
48  const DataType& input, const DataType& gy, DataType& g)
49 {
50  DataType derivate, sigmoid;
51  LogisticFunction::Fn(input, sigmoid);
52  derivate.set_size(arma::size(input));
53  for (size_t i = 0; i < input.n_elem; ++i)
54  {
55  if (input(i) >= 0)
56  {
57  // F(x) = x * sigmoid(x).
58  // We don't put '+ t' here because this is a derivate.
59  derivate(i) = input(i) * sigmoid(i);
60  derivate(i) = sigmoid(i) * (1.0 - derivate(i)) + derivate(i);
61  }
62  else
63  {
64  derivate(i) = 0;
65  }
66  }
67  g = gy % derivate;
68 }
69 
70 template<typename InputDataType, typename OutputDataType>
71 template<typename Archive>
73  Archive& ar,
74  const uint32_t /* version */)
75 {
76  ar(CEREAL_NVP(t));
77 }
78 
79 } // namespace ann
80 } // namespace mlpack
81 
82 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: flatten_t_swish_impl.hpp:72
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
FlattenTSwish(const double T=-0.20)
Create the Flatten T Swish object using the specified parameters.
Definition: flatten_t_swish_impl.hpp:26
void Backward(const DataType &input, const DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: flatten_t_swish_impl.hpp:47
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: flatten_t_swish_impl.hpp:34
static double Fn(const eT x)
Computes the logistic function.
Definition: logistic_function.hpp:39
static double Fn(const double x)
Computes the rectifier function.
Definition: rectifier_function.hpp:54