12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_DICE_LOSS_IMPL_HPP 13 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_DICE_LOSS_IMPL_HPP 21 template<
typename InputDataType,
typename OutputDataType>
// DiceLoss constructor (excerpt; the signature line itself is not shown here).
// Stores `smooth`, the smoothing term that Forward() and Backward() add to
// both the numerator and the denominator of the dice ratio. The class
// declaration gives `smooth` a default value of 1.
23 const double smooth) : smooth(smooth)
// Forward(): computes the dice loss between `prediction` and `target`:
//
//   1 - (2 * sum(target % prediction) + smooth) /
//       (sum(target % target) + sum(prediction % prediction) + smooth)
//
// `%` is Armadillo's element-wise product and arma::accu() sums every
// element, so the result is a single scalar (declared as
// PredictionType::elem_type). For real-valued inputs with smooth > 0 the
// denominator is a sum of squares plus a positive term, so the division is
// always defined.
28 template<
typename InputDataType,
typename OutputDataType>
29 template<
typename PredictionType,
typename TargetType>
32 const TargetType& target)
34 return 1 - ((2 * arma::accu(target % prediction) + smooth) /
35 (arma::accu(target % target) + arma::accu(
36 prediction % prediction) + smooth));
// Backward(): writes into `loss` the element-wise gradient of the Forward()
// value with respect to `prediction`. Writing
//   N = 2 * sum(target % prediction) + smooth                      (numerator)
//   D = sum(target % target) + sum(prediction % prediction) + smooth (denominator)
// the derivative of 1 - N / D with respect to prediction is
//   -2 * (target * D - prediction * N) / D^2,
// and the visible part of the expression below has exactly this form.
// NOTE(review): the tail of the std::pow(...) call (the rest of D and the
// exponent) is cut off in this excerpt; D^2 is inferred from the math -- confirm.
// NOTE(review): sum(target % target) and sum(prediction % prediction) are each
// computed twice; hoisting N and D into locals (and using D * D instead of
// std::pow) would avoid the repeated passes over the data.
39 template<
typename InputDataType,
typename OutputDataType>
40 template<
typename PredictionType,
typename TargetType,
typename LossType>
42 const PredictionType& prediction,
43 const TargetType& target,
46 loss = -2 * (target * (arma::accu(prediction % prediction) +
47 arma::accu(target % target) + smooth) - prediction *
48 (2 * arma::accu(target % prediction) + smooth)) / std::pow(
49 arma::accu(target % target) + arma::accu(prediction % prediction)
// serialize(): archives the smoothing term `smooth` through cereal's
// CEREAL_NVP macro. The unnamed uint32_t version parameter is not used.
53 template<
typename InputDataType,
typename OutputDataType>
54 template<
typename Archive>
59 ar(CEREAL_NVP(smooth));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
DiceLoss(const double smooth=1)
Create the DiceLoss object; `smooth` (default 1) is added to both the numerator and the denominator of the dice ratio.
Definition: dice_loss_impl.hpp:22
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
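Computes the dice loss, 1 - (2 * sum(target % prediction) + smooth) / (sum(target % target) + sum(prediction % prediction) + smooth).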
Definition: dice_loss_impl.hpp:31
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed-backward pass of a neural network: computes the gradient of the dice loss with respect to the prediction and stores it in `loss`.
Definition: dice_loss_impl.hpp:41
void serialize(Archive &ar, const uint32_t)
Serialize the loss function (its smoothing term `smooth`).
Definition: dice_loss_impl.hpp:55