mlpack
non_maximal_supression_impl.hpp
1 
12 #ifndef MLPACK_CORE_METRICS_NMS_IMPL_HPP
13 #define MLPACK_CORE_METRICS_NMS_IMPL_HPP
14 
15 // In case it hasn't been included.
17 
18 namespace mlpack {
19 namespace metric {
20 
21 template<bool UseCoordinates>
22 template<
23  typename BoundingBoxesType,
24  typename ConfidenceScoreType,
25  typename OutputType
26 >
28  const BoundingBoxesType& boundingBoxes,
29  const ConfidenceScoreType& confidenceScores,
30  OutputType& selectedIndices,
31  const double threshold)
32 {
33  Log::Assert(boundingBoxes.n_rows == 4, "Bounding boxes must \
34  contain only 4 rows determining coordinates of bounding \
35  box either in {x1, y1, x2, y2} or {x1, y1, h, w} format.\
36  Refer to the documentation for more information.");
37 
38  Log::Assert(confidenceScores.n_cols != boundingBoxes.n_cols, "Each \
39  bounding box must correspond to atleast and only 1 bounding box. \
40  Found " + std::to_string(confidenceScores.n_cols) + " confidence \
41  scores for " + std::to_string(boundingBoxes.n_cols) + " bounding boxes.");
42 
43  // Clear selected bounding boxes.
44  selectedIndices.clear();
45 
46  // Obtain Sorted indices for bounding boxes according to
47  // their confidence scores.
48  arma::ucolvec sortedIndices = arma::sort_index(confidenceScores);
49 
50  // Pre-Compute area of each bounding box.
51  arma::mat area;
52  if (UseCoordinates)
53  {
54  area = (boundingBoxes.row(2) - boundingBoxes.row(0)) %
55  (boundingBoxes.row(3) - boundingBoxes.row(1));
56  }
57  else
58  {
59  area = boundingBoxes.row(2) % boundingBoxes.row(3);
60  }
61 
62  while (sortedIndices.n_elem > 0)
63  {
64  size_t selectedIndex = sortedIndices(sortedIndices.n_elem - 1);
65 
66  // Choose the box with the largest probability.
67  selectedIndices.insert_rows(0, arma::uvec(1).fill(selectedIndex));
68 
69  // Check if there are other bounding boxes to compare with.
70  if (sortedIndices.n_elem == 1)
71  {
72  break;
73  }
74 
75  // Remove the last index.
76  sortedIndices = sortedIndices(arma::span(0, sortedIndices.n_rows - 2),
77  arma::span());
78 
79  // Get x and y coordinates for remaining bounding boxes.
80  BoundingBoxesType x2 = boundingBoxes.submat(arma::uvec(1).fill(2),
81  sortedIndices);
82 
83  BoundingBoxesType x1 = boundingBoxes.submat(arma::uvec(1).fill(0),
84  sortedIndices);;
85 
86  BoundingBoxesType y2 = boundingBoxes.submat(arma::uvec(1).fill(3),
87  sortedIndices);
88 
89  BoundingBoxesType y1 = boundingBoxes.submat(arma::uvec(1).fill(1),
90  sortedIndices);
91 
92  size_t selectedX2 = boundingBoxes(2, selectedIndex);
93  size_t selectedY2 = boundingBoxes(3, selectedIndex);
94  size_t selectedX1 = boundingBoxes(0, selectedIndex);
95  size_t selectedY1 = boundingBoxes(1, selectedIndex);
96 
97  if (!UseCoordinates)
98  {
99  // Change height - width representation to coordinate represention.
100  x2 = x2 + x1;
101  y2 = y2 + y1;
102  selectedX2 = selectedX2 + selectedX1;
103  selectedY2 = selectedY2 + selectedY1;
104  }
105 
106  // Calculate points of intersection between the bounding box with
107  // highest confidence score and remaining bounding boxes.
108  x2 = arma::clamp(x2, DBL_MIN, selectedX2);
109  y2 = arma::clamp(y2, DBL_MIN, selectedY2);
110  x1 = arma::clamp(x1, selectedX1, DBL_MAX);
111  y1 = arma::clamp(y1, selectedY1, DBL_MAX);
112 
113  BoundingBoxesType intersectionArea = arma::clamp(x2 - x1, 0.0, DBL_MAX) %
114  arma::clamp(y2 - y1, 0.0, DBL_MAX);
115 
116  // Calculate IoU of remaining boxes with the last bounding box with
117  // the highest confidence score.
118  BoundingBoxesType calculateIoU = intersectionArea /
119  (area(sortedIndices).t() - intersectionArea + area(selectedIndex));
120 
121  sortedIndices = sortedIndices(arma::find(calculateIoU <= threshold));
122  }
123 
124  selectedIndices = arma::flipud(selectedIndices);
125 }
126 
127 template<bool UseCoordinates>
128 template<typename Archive>
130  Archive& /* ar */,
131  const uint32_t /* version */)
132 {
133  // Nothing to do here.
134 }
135 
136 } // namespace metric
137 } // namespace mlpack
138 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the metric.
Definition: non_maximal_supression_impl.hpp:129
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static void Evaluate(const BoundingBoxesType &boundingBoxes, const ConfidenceScoreType &confidenceScores, OutputType &selectedIndices, const double threshold=0.5)
Performs non-maximal suppression.
Definition: non_maximal_supression_impl.hpp:27
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38