mlpack
triplet_margin_loss_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP
14 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP
15 
16 // In case it hasn't been included.
#include "triplet_margin_loss.hpp"

#include <stdexcept>
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename InputDataType, typename OutputDataType>
24  const double margin) : margin(margin)
25 {
26  // Nothing to do here.
27 }
28 
29 template<typename InputDataType, typename OutputDataType>
30 template<typename PredictionType, typename TargetType>
31 typename PredictionType::elem_type
33  const PredictionType& prediction,
34  const TargetType& target)
35 {
36  PredictionType anchor =
37  prediction.submat(0, 0, prediction.n_rows / 2 - 1, prediction.n_cols - 1);
38  PredictionType positive =
39  prediction.submat(prediction.n_rows / 2, 0, prediction.n_rows - 1,
40  prediction.n_cols - 1);
41  return std::max(0.0, arma::accu(arma::pow(anchor - positive, 2)) -
42  arma::accu(arma::pow(anchor - target, 2)) + margin) / anchor.n_cols;
43 }
44 
45 template<typename InputDataType, typename OutputDataType>
46 template <
47  typename PredictionType,
48  typename TargetType,
49  typename LossType
50 >
52  const PredictionType& prediction,
53  const TargetType& target,
54  LossType& loss)
55 {
56  PredictionType positive =
57  prediction.submat(prediction.n_rows / 2, 0, prediction.n_rows - 1,
58  prediction.n_cols - 1);
59  loss = 2 * (target - positive) / target.n_cols;
60 }
61 
62 template<typename InputDataType, typename OutputDataType>
63 template<typename Archive>
65  Archive& ar,
66  const unsigned int /* version */)
67 {
68  ar(CEREAL_NVP(margin));
69 }
70 
71 } // namespace ann
72 } // namespace mlpack
73 
74 #endif
TripletMarginLoss(const double margin=1.0)
Create the TripletMarginLoss object.
Definition: triplet_margin_loss_impl.hpp:23
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the Triplet Margin Loss function.
Definition: triplet_margin_loss_impl.hpp:32
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network.
Definition: triplet_margin_loss_impl.hpp:51
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
Definition: triplet_margin_loss_impl.hpp:64