mlpack
hinge_loss_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_HINGE_LOSS_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_HINGE_LOSS_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "hinge_loss.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename InputDataType, typename OutputDataType>
24  reduction(reduction)
25 {
26  // Nothing to do here.
27 }
28 
29 template<typename InputDataType, typename OutputDataType>
30 template<typename PredictionType, typename TargetType>
31 typename PredictionType::elem_type
33  const PredictionType& prediction,
34  const TargetType& target)
35 {
36  TargetType temp = target - (target == 0);
37  TargetType temp_zeros(size(target), arma::fill::zeros);
38 
39  PredictionType loss = arma::max(temp_zeros, 1 - prediction % temp);
40 
41  typename PredictionType::elem_type lossSum = arma::accu(loss);
42 
43  if (reduction)
44  return lossSum;
45 
46  return lossSum / loss.n_elem;
47 }
48 
49 template<typename InputDataType, typename OutputDataType>
50 template<typename PredictionType, typename TargetType, typename LossType>
52  const PredictionType& prediction,
53  const TargetType& target,
54  LossType& loss)
55 {
56  TargetType temp = target - (target == 0);
57  loss = (prediction < (1 / temp)) % -temp;
58 
59  if (!reduction)
60  loss /= target.n_elem;
61 }
62 
63 template<typename InputDataType, typename OutputDataType>
64 template<typename Archive>
66  Archive& ar,
67  const uint32_t /* version */)
68 {
69  ar(CEREAL_NVP(reduction));
70 }
71 
72 } // namespace ann
73 } // namespace mlpack
74 
75 #endif
HingeLoss(const bool reduction=true)
Create HingeLoss object.
Definition: hinge_loss_impl.hpp:23
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network.
Definition: hinge_loss_impl.hpp:51
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the Hinge loss function.
Definition: hinge_loss_impl.hpp:32
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: hinge_loss_impl.hpp:65