12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_COSINE_EMBEDDING_IMPL_HPP 13 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_COSINE_EMBEDDING_IMPL_HPP 21 template<
typename InputDataType,
typename OutputDataType>
// CosineEmbeddingLoss constructor: stores the loss configuration.
//   margin     - hinge threshold: Forward accumulates max(0, cosDist - margin)
//                and Backward zeroes the gradient of dissimilar pairs whose
//                cosDist is below it (declared default: 0.0).
//   similarity - selects similar-pair vs. dissimilar-pair handling
//                (declared default: true).
//   takeMean   - presumably enables averaging the loss over the batch in
//                Forward; the guard is not visible here (declared default:
//                false).
// NOTE(review): the declarator line (source line 22) and the constructor
// body are absent from this listing.
23 const double margin,
const bool similarity,
const bool takeMean):
24 margin(margin), similarity(similarity), takeMean(takeMean)
// Forward: computes the cosine embedding loss between `prediction` and
// `target`, which must have identical dimensions.
// NOTE(review): the declarator line (source line 32) and several body lines
// (loss declaration, cosDist computation, branch conditions, return) are
// absent from this listing; the comments below describe only visible code.
29 template<
typename InputDataType,
typename OutputDataType>
30 template<
typename PredictionType,
typename TargetType>
31 typename PredictionType::elem_type
33 const PredictionType& prediction,
34 const TargetType& target)
36 typedef typename PredictionType::elem_type ElemType;
// Chunk length used by the loop below.
// NOTE(review): arma::vectorise flattens column by column, so a chunk of
// `cols` (= n_cols) consecutive elements coincides with one column only
// when n_rows == n_cols. If each sample is stored as one column (confirm
// the caller's layout), the chunk length presumably should be n_rows.
38 const size_t cols = prediction.n_cols;
// Number of chunks (n_elem / n_cols == n_rows).
// NOTE(review): integer division by zero when prediction has no columns;
// empty input is not guarded here.
39 const size_t batchSize = prediction.n_elem / cols;
// Log::Fatal prints the message and terminates the program.
40 if (arma::size(prediction) != arma::size(target))
41 Log::Fatal <<
"Input Tensors must have same dimensions." << std::endl;
// Flatten both inputs into column vectors so chunks can be taken by span.
43 arma::colvec inputTemp1 = arma::vectorise(prediction);
44 arma::colvec inputTemp2 = arma::vectorise(target);
// Walk the flattened inputs in consecutive chunks of `cols` elements,
// scoring each prediction chunk against the matching target chunk.
47 for (
size_t i = 0; i < inputTemp1.n_elem; i += cols)
50 inputTemp1(arma::span(i, i + cols - 1)), inputTemp2(arma::span(i,
// Hinge term: only the part of cosDist above `margin` contributes. The
// enclosing branch condition is not visible (presumably the
// dissimilar-pair case, i.e. !similarity -- confirm).
56 const ElemType currentLoss = cosDist - margin;
57 loss += currentLoss > 0 ? currentLoss : 0;
// Average over the number of chunks; presumably guarded by takeMean (the
// guard is not visible). The C-style cast applies to `loss`, not to the
// quotient.
62 loss = (ElemType) loss / batchSize;
// Backward: writes into `loss` the gradient of the cosine embedding loss
// with respect to `prediction`, one chunk at a time.
// NOTE(review): the declarator line (source line 69), the cosDist
// computation and the tail of the gradient expression are absent from
// this listing; the comments below describe only visible code.
67 template<
typename InputDataType,
typename OutputDataType>
68 template<
typename PredictionType,
typename TargetType,
typename LossType>
70 const PredictionType& prediction,
71 const TargetType& target,
74 typedef typename PredictionType::elem_type ElemType;
// Chunk length used by the loop below (see the layout caveat there).
76 const size_t cols = prediction.n_cols;
// Log::Fatal prints the message and terminates the program.
77 if (arma::size(prediction) != arma::size(target))
78 Log::Fatal <<
"Input Tensors must have same dimensions." << std::endl;
80 arma::colvec inputTemp1 = arma::vectorise(prediction);
81 arma::colvec inputTemp2 = arma::vectorise(target);
// The gradient has one entry per flattened input element.
82 loss.set_size(arma::size(inputTemp1));
// Column-vector view built on loss's storage, so writes to outputTemp
// presumably fill `loss` in place (the remaining constructor arguments are
// not visible -- confirm they disable copying).
84 arma::colvec outputTemp(loss.memptr(), inputTemp1.n_elem,
// NOTE(review): chunks are `cols` (= n_cols) consecutive elements of the
// column-major flattening; they coincide with columns only when
// n_rows == n_cols. Verify against the caller's sample layout.
86 for (
size_t i = 0; i < inputTemp1.n_elem; i += cols)
89 arma::span(i, i + cols -1)), inputTemp2(arma::span(i, i + cols -1)));
// Dissimilar pair whose cosDist is below the margin: the forward hinge
// max(0, cosDist - margin) is inactive, so this chunk's gradient is zero.
91 if (cosDist < margin && !similarity)
92 outputTemp(arma::span(i, i + cols - 1)).zeros();
// Otherwise the visible expression has the form
//   -s * (x2/||x2|| - cosDist * x1/||x1||) / sqrt(sum(x1^p)),
// with s = +1 for similar pairs and -1 otherwise, and x1/x2 the
// prediction/target chunks. The exponent p lies past the visible lines
// (presumably 2, giving ||x1|| -- confirm).
// NOTE(review): an all-zero prediction chunk makes the denominator zero.
95 const int multiplier = similarity ? 1 : -1;
96 outputTemp(arma::span(i, i + cols -1)) = -1 * multiplier *
97 (arma::normalise(inputTemp2(arma::span(i, i + cols - 1))) -
98 cosDist * arma::normalise(inputTemp1(arma::span(i, i + cols -
99 1)))) / std::sqrt(arma::accu(arma::pow(inputTemp1(arma::span(i, i +
// serialize: passes the loss configuration (margin, similarity, takeMean)
// to the archive as cereal name-value pairs. The version argument is
// unnamed and therefore unused.
// NOTE(review): the declarator line (source line 107) is absent from this
// listing.
105 template<
typename InputDataType,
typename OutputDataType>
106 template<
typename Archive>
108 Archive& ar,
const uint32_t )
110 ar(CEREAL_NVP(margin));
111 ar(CEREAL_NVP(similarity));
112 ar(CEREAL_NVP(takeMean));
CosineEmbeddingLoss(const double margin=0.0, const bool similarity=true, const bool takeMean=false)
Create the CosineEmbeddingLoss object.
Definition: cosine_embedding_loss_impl.hpp:22
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary backward pass of a neural network.
Definition: cosine_embedding_loss_impl.hpp:69
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: cosine_embedding_loss_impl.hpp:107
static double Evaluate(const VecTypeA &a, const VecTypeB &b)
Computes the cosine distance between two points.
Definition: cosine_distance_impl.hpp:21
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Ordinary feed forward pass of a neural network.
Definition: cosine_embedding_loss_impl.hpp:32