mlpack
cosine_embedding_loss.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_COSINE_EMBEDDING_HPP
13 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_COSINE_EMBEDDING_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
35 template <
36  typename InputDataType = arma::mat,
37  typename OutputDataType = arma::mat
38 >
40 {
41  public:
53  CosineEmbeddingLoss(const double margin = 0.0,
54  const bool similarity = true,
55  const bool takeMean = false);
56 
64  template <typename PredictionType, typename TargetType>
65  typename PredictionType::elem_type Forward(const PredictionType& prediction,
66  const TargetType& target);
67 
76  template<typename PredictionType, typename TargetType, typename LossType>
77  void Backward(const PredictionType& prediction,
78  const TargetType& target,
79  LossType& loss);
80 
82  InputDataType& InputParameter() const { return inputParameter; }
84  InputDataType& InputParameter() { return inputParameter; }
85 
87  OutputDataType& OutputParameter() const { return outputParameter; }
89  OutputDataType& OutputParameter() { return outputParameter; }
90 
92  OutputDataType& Delta() const { return delta; }
94  OutputDataType& Delta() { return delta; }
95 
97  bool TakeMean() const { return takeMean; }
99  bool& TakeMean() { return takeMean; }
100 
102  double Margin() const { return margin; }
104  double& Margin() { return margin; }
105 
107  bool Similarity() const { return similarity; }
109  bool& Similarity() { return similarity; }
110 
114  template<typename Archive>
115  void serialize(Archive& ar, const uint32_t /* version */);
116 
117  private:
119  OutputDataType delta;
120 
122  InputDataType inputParameter;
123 
125  OutputDataType outputParameter;
126 
128  double margin;
129 
131  bool similarity;
132 
134  bool takeMean;
135 }; // class CosineEmbeddingLoss
136 
137 } // namespace ann
138 } // namespace mlpack
139 
140 // Include implementation.
142 
143 #endif
CosineEmbeddingLoss(const double margin=0.0, const bool similarity=true, const bool takeMean=false)
Create the CosineEmbeddingLoss object.
Definition: cosine_embedding_loss_impl.hpp:22
double Margin() const
Get the value of margin.
Definition: cosine_embedding_loss.hpp:102
OutputDataType & Delta() const
Get the delta.
Definition: cosine_embedding_loss.hpp:92
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary backward pass of a neural network.
Definition: cosine_embedding_loss_impl.hpp:69
OutputDataType & OutputParameter() const
Get the output parameter.
Definition: cosine_embedding_loss.hpp:87
The core includes that mlpack expects; standard C++ includes and Armadillo.
InputDataType & InputParameter()
Modify the input parameter.
Definition: cosine_embedding_loss.hpp:84
bool & Similarity()
Modify the value of similarity.
Definition: cosine_embedding_loss.hpp:109
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: cosine_embedding_loss_impl.hpp:107
bool Similarity() const
Get the value of similarity hyperparameter.
Definition: cosine_embedding_loss.hpp:107
bool & TakeMean()
Modify the value of takeMean.
Definition: cosine_embedding_loss.hpp:99
InputDataType & InputParameter() const
Get the input parameter.
Definition: cosine_embedding_loss.hpp:82
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Ordinary feed forward pass of a neural network.
Definition: cosine_embedding_loss_impl.hpp:32
double & Margin()
Modify the value of margin.
Definition: cosine_embedding_loss.hpp:104
Cosine Embedding Loss function is used for measuring whether two inputs are similar or dissimilar...
Definition: cosine_embedding_loss.hpp:39
OutputDataType & Delta()
Modify the delta.
Definition: cosine_embedding_loss.hpp:94
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: cosine_embedding_loss.hpp:89
bool TakeMean() const
Get the value of takeMean.
Definition: cosine_embedding_loss.hpp:97