mlpack
margin_ranking_loss_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_MARGIN_IMPL_LOSS_HPP
13 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_MARGIN_IMPL_LOSS_HPP
14 
15 // In case it hasn't been included.
16 #include "margin_ranking_loss.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 template<typename InputDataType, typename OutputDataType>
23  const double margin) : margin(margin)
24 {
25  // Nothing to do here.
26 }
27 
28 template<typename InputDataType, typename OutputDataType>
29 template<typename PredictionType, typename TargetType>
30 typename PredictionType::elem_type
32  const PredictionType& prediction,
33  const TargetType& target)
34 {
35  const int predictionRows = prediction.n_rows;
36  const PredictionType& prediction1 = prediction.rows(0,
37  predictionRows / 2 - 1);
38  const PredictionType& prediction2 = prediction.rows(predictionRows / 2,
39  predictionRows - 1);
40  return arma::accu(arma::max(arma::zeros(size(target)),
41  -target % (prediction1 - prediction2) + margin)) / target.n_cols;
42 }
43 
44 template<typename InputDataType, typename OutputDataType>
45 template <
46  typename PredictionType,
47  typename TargetType,
48  typename LossType
49 >
51  const PredictionType& prediction,
52  const TargetType& target,
53  LossType& loss)
54 {
55  const int predictionRows = prediction.n_rows;
56  const PredictionType& prediction1 = prediction.rows(0,
57  predictionRows / 2 - 1);
58  const PredictionType& prediction2 = prediction.rows(predictionRows / 2,
59  predictionRows - 1);
60  loss = -target % (prediction1 - prediction2) + margin;
61  loss.elem(arma::find(loss >= 0)).ones();
62  loss.elem(arma::find(loss < 0)).zeros();
63  loss = (prediction2 - prediction1) % loss / target.n_cols;
64 }
65 
66 template<typename InputDataType, typename OutputDataType>
67 template<typename Archive>
69  Archive& ar,
70  const uint32_t /* version */)
71 {
72  ar(CEREAL_NVP(margin));
73 }
74 
75 } // namespace ann
76 } // namespace mlpack
77 
78 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: margin_ranking_loss_impl.hpp:68
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
MarginRankingLoss(const double margin=1.0)
Create the MarginRankingLoss object with the given margin hyperparameter.
Definition: margin_ranking_loss_impl.hpp:22
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network.
Definition: margin_ranking_loss_impl.hpp:50
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the Margin Ranking Loss function.
Definition: margin_ranking_loss_impl.hpp:31