mlpack
hard_tanh_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_HARD_TANH_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_HARD_TANH_IMPL_HPP
14 
15 // In case it hasn't yet been included.
16 #include "hard_tanh.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 template<typename InputDataType, typename OutputDataType>
23  const double maxValue,
24  const double minValue) :
25  maxValue(maxValue),
26  minValue(minValue)
27 {
28  // Nothing to do here.
29 }
30 
31 template<typename InputDataType, typename OutputDataType>
32 template<typename InputType, typename OutputType>
34  const InputType& input, OutputType& output)
35 {
36  output = input;
37  for (size_t i = 0; i < input.n_elem; ++i)
38  {
39  output(i) = (output(i) > maxValue ? maxValue :
40  (output(i) < minValue ? minValue : output(i)));
41  }
42 }
43 
44 template<typename InputDataType, typename OutputDataType>
45 template<typename DataType>
47  const DataType& input, const DataType& gy, DataType& g)
48 {
49  g = gy;
50  for (size_t i = 0; i < input.n_elem; ++i)
51  {
52  if (input(i) < minValue || input(i) > maxValue)
53  {
54  g(i) = 0;
55  }
56  }
57 }
58 
59 template<typename InputDataType, typename OutputDataType>
60 template<typename Archive>
62  Archive& ar,
63  const uint32_t /* version */)
64 {
65  ar(CEREAL_NVP(maxValue));
66  ar(CEREAL_NVP(minValue));
67 }
68 
69 } // namespace ann
70 } // namespace mlpack
71 
72 #endif
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: hard_tanh_impl.hpp:33
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: hard_tanh_impl.hpp:61
void Backward(const DataType &input, const DataType &gy, DataType &g)
Ordinary backward pass of a neural network, computing the gradient by propagating the error backwards through f, using the results from the feed forward pass.
Definition: hard_tanh_impl.hpp:46
HardTanH(const double maxValue=1, const double minValue=-1)
Create the HardTanH object using the specified parameters.
Definition: hard_tanh_impl.hpp:22