13 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP 14 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP 22 template<
typename InputDataType,
typename OutputDataType>
24 const double margin) : margin(margin)
29 template<
typename InputDataType,
typename OutputDataType>
30 template<
typename PredictionType,
typename TargetType>
31 typename PredictionType::elem_type
33 const PredictionType& prediction,
34 const TargetType& target)
36 PredictionType anchor =
37 prediction.submat(0, 0, prediction.n_rows / 2 - 1, prediction.n_cols - 1);
38 PredictionType positive =
39 prediction.submat(prediction.n_rows / 2, 0, prediction.n_rows - 1,
40 prediction.n_cols - 1);
41 return std::max(0.0, arma::accu(arma::pow(anchor - positive, 2)) -
42 arma::accu(arma::pow(anchor - target, 2)) + margin) / anchor.n_cols;
45 template<
typename InputDataType,
typename OutputDataType>
47 typename PredictionType,
52 const PredictionType& prediction,
53 const TargetType& target,
56 PredictionType positive =
57 prediction.submat(prediction.n_rows / 2, 0, prediction.n_rows - 1,
58 prediction.n_cols - 1);
59 loss = 2 * (target - positive) / target.n_cols;
62 template<
typename InputDataType,
typename OutputDataType>
63 template<
typename Archive>
68 ar(CEREAL_NVP(margin));
TripletMarginLoss(const double margin=1.0)
Create the TripletMarginLoss object.
Definition: triplet_margin_loss_impl.hpp:23
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the Triplet Margin Loss function.
Definition: triplet_margin_loss_impl.hpp:32
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network.
Definition: triplet_margin_loss_impl.hpp:51
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
Definition: triplet_margin_loss_impl.hpp:64