// NOTE(review): this text appears to be a Doxygen source-browser scrape.
// Original source line numbers are fused into the code, and some original
// lines are not present (e.g. line 22, which would hold the constructor's
// qualified name before `const double margin)`). The comments below
// describe only what is visible here.
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_MARGIN_IMPL_LOSS_HPP 13 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_MARGIN_IMPL_LOSS_HPP 21 template<
typename InputDataType,
typename OutputDataType>
// Constructor: stores the margin hyperparameter in the `margin` member.
// That member is added to the hinge argument in both Forward() and
// Backward(). The declaration MarginRankingLoss(const double margin=1.0)
// gives it a default of 1.0.
23 const double margin) : margin(margin)
/**
 * Forward pass of the margin ranking loss.
 *
 * The rows of `prediction` are split into a first half (prediction1) and a
 * second half (prediction2). For every element the hinge term
 *   max(0, -target * (prediction1 - prediction2) + margin)
 * is computed. The terms are summed with arma::accu, and the sum is divided
 * by target.n_cols.
 *
 * @param prediction Stacked scores. Presumably rows [0, n/2) hold the first
 *     input and rows [n/2, n) the second. TODO confirm: the end index of
 *     prediction2's row range (original line 39) is not visible here.
 * @param target Ranking labels. They must be the same size as each half,
 *     because of the elementwise % and arma::max against zeros(size(target)).
 *     Presumably +1/-1, but no check is made here.
 * @return The summed hinge terms divided by target.n_cols.
 */
28 template<
typename InputDataType,
typename OutputDataType>
29 template<
typename PredictionType,
typename TargetType>
30 typename PredictionType::elem_type
32 const PredictionType& prediction,
33 const TargetType& target)
// NOTE(review): n_rows is narrowed to int here. With fewer than 2 rows,
// `predictionRows / 2 - 1` is -1. rows() presumably receives that as a huge
// unsigned index, giving a bounds error. With an odd row count, the two
// halves would differ in size if prediction2 runs to the last row, and the
// elementwise difference below would then fail.
35 const int predictionRows = prediction.n_rows;
// Binding a const reference to rows(...) presumably materializes a
// temporary copy of the selected rows (Armadillo's rows() returns a
// subview). The reference extends the temporary's lifetime.
36 const PredictionType& prediction1 = prediction.rows(0,
37 predictionRows / 2 - 1);
38 const PredictionType& prediction2 = prediction.rows(predictionRows / 2,
40 return arma::accu(arma::max(arma::zeros(size(target)),
41 -target % (prediction1 - prediction2) + margin)) / target.n_cols;
/**
 * Backward pass of the margin ranking loss.
 *
 * Splits `prediction` into halves the same way Forward() does. It then
 * builds a 0/1 mask of the elements where the hinge argument
 *   -target % (prediction1 - prediction2) + margin
 * is >= 0, and finally writes (prediction2 - prediction1) % mask /
 * target.n_cols into `loss`.
 *
 * @param prediction Stacked scores (same layout as in Forward()).
 * @param target Ranking labels, the same size as each half.
 * @param loss Output; overwritten. Its size follows the half-size
 *     expressions, i.e. predictionRows / 2 rows, not the full row count of
 *     `prediction`. NOTE(review): confirm callers expect that shape.
 */
44 template<
typename InputDataType,
typename OutputDataType>
46 typename PredictionType,
51 const PredictionType& prediction,
52 const TargetType& target,
// NOTE(review): same narrowing and row-count caveats as in Forward(). With
// fewer than 2 rows, `predictionRows / 2 - 1` is -1.
55 const int predictionRows = prediction.n_rows;
56 const PredictionType& prediction1 = prediction.rows(0,
57 predictionRows / 2 - 1);
58 const PredictionType& prediction2 = prediction.rows(predictionRows / 2,
// Hinge argument. `loss` is reused as scratch space for the activity mask.
60 loss = -target % (prediction1 - prediction2) + margin;
// Convert the hinge argument into a 0/1 mask. The first line sets entries
// >= 0 to 1, and those new 1s are not < 0, so the second find() zeroes only
// the originally negative entries. Entries exactly at 0 (the hinge kink)
// count as active. NaN entries satisfy neither test and stay NaN.
61 loss.elem(arma::find(loss >= 0)).ones();
62 loss.elem(arma::find(loss < 0)).zeros();
// NOTE(review): the forward term is max(0, -target * (p1 - p2) + margin).
// Where active, its derivative is -target w.r.t. p1 and +target w.r.t. p2.
// This line instead scales the mask by (prediction2 - prediction1), not by
// -target, so it does not match that derivative. Verify whether this is
// intended before relying on it for gradient-based training.
63 loss = (prediction2 - prediction1) % loss / target.n_cols;
/**
 * Serialize the loss object.
 *
 * Only the `margin` member is archived, as a cereal name-value pair
 * (CEREAL_NVP). The declaration is void serialize(Archive &ar, const
 * uint32_t); the version argument is not visibly used.
 */
66 template<
typename InputDataType,
typename OutputDataType>
67 template<
typename Archive>
72 ar(CEREAL_NVP(margin));
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: margin_ranking_loss_impl.hpp:68
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
MarginRankingLoss(const double margin=1.0)
Create the MarginRankingLoss object with the given margin hyperparameter.
Definition: margin_ranking_loss_impl.hpp:22
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network.
Definition: margin_ranking_loss_impl.hpp:50
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the Margin Ranking Loss function.
Definition: margin_ranking_loss_impl.hpp:31