mlpack
log_softmax_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_LOG_SOFTMAX_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_LOG_SOFTMAX_IMPL_HPP
14 
15 // In case it hasn't yet been included.
16 #include "log_softmax.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 template<typename InputDataType, typename OutputDataType>
23 {
24  // Nothing to do here.
25 }
26 
27 template<typename InputDataType, typename OutputDataType>
28 template<typename InputType, typename OutputType>
30  const InputType& input, OutputType& output)
31 {
32  arma::mat maxInput = arma::repmat(arma::max(input), input.n_rows, 1);
33  output = (maxInput - input);
34 
35  // Approximation of the base-e exponential function. The acuracy however is
36  // about 0.00001 lower as using exp. Credits go to Leon Bottou.
37  output.transform([](double x)
38  {
40  static constexpr double A0 = 1.0;
41  static constexpr double A1 = 0.125;
42  static constexpr double A2 = 0.0078125;
43  static constexpr double A3 = 0.00032552083;
44  static constexpr double A4 = 1.0172526e-5;
45 
46  if (x < 13.0)
47  {
48  double y = A0 + x * (A1 + x * (A2 + x * (A3 + x * A4)));
49  y *= y;
50  y *= y;
51  y *= y;
52  y = 1 / y;
53 
54  return y;
55  }
56 
57  return 0.0;
58  });
59 
60  maxInput.each_row() += arma::log(arma::sum(output));
61  output = input - maxInput;
62 }
63 
64 template<typename InputDataType, typename OutputDataType>
65 template<typename eT>
67  const arma::Mat<eT>& input,
68  const arma::Mat<eT>& gy,
69  arma::Mat<eT>& g)
70 {
71  g = arma::exp(input) + gy;
72 }
73 
74 template<typename InputDataType, typename OutputDataType>
75 template<typename Archive>
77  Archive& /* ar */,
78  const uint32_t /* version */)
79 {
80  // Nothing to do here.
81 }
82 
83 } // namespace ann
84 } // namespace mlpack
85 
86 #endif
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
Definition: log_softmax_impl.hpp:29
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void serialize(Archive &, const uint32_t)
Serialize the layer.
Definition: log_softmax_impl.hpp:76
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
Definition: log_softmax_impl.hpp:66
LogSoftMax()
Create the LogSoftmax object.
Definition: log_softmax_impl.hpp:22