mlpack
dice_loss_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_DICE_LOSS_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_DICE_LOSS_IMPL_HPP
14 
15 // In case it hasn't yet been included.
16 #include "dice_loss.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 template<typename InputDataType, typename OutputDataType>
23  const double smooth) : smooth(smooth)
24 {
25  // Nothing to do here.
26 }
27 
28 template<typename InputDataType, typename OutputDataType>
29 template<typename PredictionType, typename TargetType>
30 typename PredictionType::elem_type DiceLoss<InputDataType, OutputDataType>
31  ::Forward(const PredictionType& prediction,
32  const TargetType& target)
33 {
34  return 1 - ((2 * arma::accu(target % prediction) + smooth) /
35  (arma::accu(target % target) + arma::accu(
36  prediction % prediction) + smooth));
37 }
38 
39 template<typename InputDataType, typename OutputDataType>
40 template<typename PredictionType, typename TargetType, typename LossType>
42  const PredictionType& prediction,
43  const TargetType& target,
44  LossType& loss)
45 {
46  loss = -2 * (target * (arma::accu(prediction % prediction) +
47  arma::accu(target % target) + smooth) - prediction *
48  (2 * arma::accu(target % prediction) + smooth)) / std::pow(
49  arma::accu(target % target) + arma::accu(prediction % prediction)
50  + smooth, 2.0);
51 }
52 
53 template<typename InputDataType, typename OutputDataType>
54 template<typename Archive>
56  Archive& ar,
57  const uint32_t /* version */)
58 {
59  ar(CEREAL_NVP(smooth));
60 }
61 
62 } // namespace ann
63 } // namespace mlpack
64 
65 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
DiceLoss(const double smooth=1)
Create the DiceLoss object.
Definition: dice_loss_impl.hpp:22
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the dice loss function.
Definition: dice_loss_impl.hpp:31
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network: computes the gradient of the dice loss with respect to the prediction.
Definition: dice_loss_impl.hpp:41
void serialize(Archive &ar, const uint32_t)
Serialize the loss function (its smoothing parameter).
Definition: dice_loss_impl.hpp:55