mlpack
soft_margin_loss_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_SOFT_MARGIN_LOSS_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_SOFT_MARGIN_LOSS_IMPL_HPP
14 
15 // In case it hasn't been included.
16 #include "soft_margin_loss.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 template<typename InputDataType, typename OutputDataType>
23 SoftMarginLoss(const bool reduction) : reduction(reduction)
24 {
25  // Nothing to do here.
26 }
27 
28 template<typename InputDataType, typename OutputDataType>
29 template<typename PredictionType, typename TargetType>
30 typename PredictionType::elem_type
32  const PredictionType& prediction, const TargetType& target)
33 {
34  PredictionType loss = arma::log(1 + arma::exp(-target % prediction));
35  typename PredictionType::elem_type lossSum = arma::accu(loss);
36 
37  if (reduction)
38  return lossSum;
39 
40  return lossSum / prediction.n_elem;
41 }
42 
43 template<typename InputDataType, typename OutputDataType>
44 template<typename PredictionType, typename TargetType, typename LossType>
46  const PredictionType& prediction,
47  const TargetType& target,
48  LossType& loss)
49 {
50  loss.set_size(size(prediction));
51  PredictionType temp = arma::exp(-target % prediction);
52  PredictionType numerator = -target % temp;
53  PredictionType denominator = 1 + temp;
54  loss = numerator / denominator;
55 
56  if (!reduction)
57  loss = loss / prediction.n_elem;
58 }
59 
60 template<typename InputDataType, typename OutputDataType>
61 template<typename Archive>
63  Archive& ar, const uint32_t /* version */)
64 {
65  ar(CEREAL_NVP(reduction));
66 }
67 
68 } // namespace ann
69 } // namespace mlpack
70 
71 #endif
SoftMarginLoss(const bool reduction=true)
Create the SoftMarginLoss object.
Definition: soft_margin_loss_impl.hpp:23
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the Soft Margin Loss function.
Definition: soft_margin_loss_impl.hpp:31
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network: computes the gradient of the loss with respect to the prediction.
Definition: soft_margin_loss_impl.hpp:45
void serialize(Archive &ar, const uint32_t version)
Serialize the loss function (its reduction flag).
Definition: soft_margin_loss_impl.hpp:62