12 #ifndef MLPACK_CORE_METRICS_NMS_IMPL_HPP 13 #define MLPACK_CORE_METRICS_NMS_IMPL_HPP 21 template<
bool UseCoordinates>
23 typename BoundingBoxesType,
24 typename ConfidenceScoreType,
28 const BoundingBoxesType& boundingBoxes,
29 const ConfidenceScoreType& confidenceScores,
30 OutputType& selectedIndices,
31 const double threshold)
33 Log::Assert(boundingBoxes.n_rows == 4,
"Bounding boxes must \ 34 contain only 4 rows determining coordinates of bounding \ 35 box either in {x1, y1, x2, y2} or {x1, y1, h, w} format.\ 36 Refer to the documentation for more information.");
38 Log::Assert(confidenceScores.n_cols != boundingBoxes.n_cols,
"Each \ 39 bounding box must correspond to atleast and only 1 bounding box. \ 40 Found " + std::to_string(confidenceScores.n_cols) +
" confidence \ 41 scores for " + std::to_string(boundingBoxes.n_cols) +
" bounding boxes.");
44 selectedIndices.clear();
48 arma::ucolvec sortedIndices = arma::sort_index(confidenceScores);
54 area = (boundingBoxes.row(2) - boundingBoxes.row(0)) %
55 (boundingBoxes.row(3) - boundingBoxes.row(1));
59 area = boundingBoxes.row(2) % boundingBoxes.row(3);
62 while (sortedIndices.n_elem > 0)
64 size_t selectedIndex = sortedIndices(sortedIndices.n_elem - 1);
67 selectedIndices.insert_rows(0, arma::uvec(1).fill(selectedIndex));
70 if (sortedIndices.n_elem == 1)
76 sortedIndices = sortedIndices(arma::span(0, sortedIndices.n_rows - 2),
80 BoundingBoxesType x2 = boundingBoxes.submat(arma::uvec(1).fill(2),
83 BoundingBoxesType x1 = boundingBoxes.submat(arma::uvec(1).fill(0),
86 BoundingBoxesType y2 = boundingBoxes.submat(arma::uvec(1).fill(3),
89 BoundingBoxesType y1 = boundingBoxes.submat(arma::uvec(1).fill(1),
92 size_t selectedX2 = boundingBoxes(2, selectedIndex);
93 size_t selectedY2 = boundingBoxes(3, selectedIndex);
94 size_t selectedX1 = boundingBoxes(0, selectedIndex);
95 size_t selectedY1 = boundingBoxes(1, selectedIndex);
102 selectedX2 = selectedX2 + selectedX1;
103 selectedY2 = selectedY2 + selectedY1;
108 x2 = arma::clamp(x2, DBL_MIN, selectedX2);
109 y2 = arma::clamp(y2, DBL_MIN, selectedY2);
110 x1 = arma::clamp(x1, selectedX1, DBL_MAX);
111 y1 = arma::clamp(y1, selectedY1, DBL_MAX);
113 BoundingBoxesType intersectionArea = arma::clamp(x2 - x1, 0.0, DBL_MAX) %
114 arma::clamp(y2 - y1, 0.0, DBL_MAX);
118 BoundingBoxesType calculateIoU = intersectionArea /
119 (area(sortedIndices).t() - intersectionArea + area(selectedIndex));
121 sortedIndices = sortedIndices(arma::find(calculateIoU <= threshold));
124 selectedIndices = arma::flipud(selectedIndices);
127 template<
bool UseCoordinates>
128 template<
typename Archive>
void serialize(Archive &ar, const uint32_t)
Serialize the metric.
Definition: non_maximal_supression_impl.hpp:129
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static void Evaluate(const BoundingBoxesType &boundingBoxes, const ConfidenceScoreType &confidenceScores, OutputType &selectedIndices, const double threshold=0.5)
Performs non-maximal suppression.
Definition: non_maximal_supression_impl.hpp:27
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38