diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py index 9e7a30d0e..800637702 100644 --- a/supervision/metrics/mean_average_precision.py +++ b/supervision/metrics/mean_average_precision.py @@ -283,27 +283,43 @@ def _match_detection_batch( iou: np.ndarray, iou_thresholds: np.ndarray, ) -> np.ndarray: - num_predictions, num_iou_levels = ( - predictions_classes.shape[0], - iou_thresholds.shape[0], - ) + num_predictions = predictions_classes.shape[0] + num_iou_levels = iou_thresholds.shape[0] correct = np.zeros((num_predictions, num_iou_levels), dtype=bool) + # Broadcast target_classes and predictions_classes for fast class matching correct_class = target_classes[:, None] == predictions_classes - for i, iou_level in enumerate(iou_thresholds): - matched_indices = np.where((iou >= iou_level) & correct_class) + # Loop over each IOU threshold and perform matching + for i, threshold in enumerate(iou_thresholds): + # Create a mask for matches that satisfy both the IOU threshold and the correct class + valid = ( + iou >= threshold + ) & correct_class # Shape: (num_targets, num_predictions) + + # Get indices of valid matches directly + target_idxs, pred_idxs = np.nonzero(valid) + if target_idxs.size == 0: + continue + + # Get corresponding IOU scores for valid matches + match_ious = iou[target_idxs, pred_idxs] + + # Sort matches by descending IOU values + order = np.argsort(match_ious)[::-1] + target_idxs = target_idxs[order] + pred_idxs = pred_idxs[order] - if matched_indices[0].shape[0]: - combined_indices = np.stack(matched_indices, axis=1) - iou_values = iou[matched_indices][:, None] - matches = np.hstack([combined_indices, iou_values]) + # Remove duplicate predictions while keeping the highest IOU match + _, unique_pred_idx = np.unique(pred_idxs, return_index=True) + target_idxs = target_idxs[unique_pred_idx] + pred_idxs = pred_idxs[unique_pred_idx] - if matched_indices[0].shape[0] > 1: - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 1], return_index=True)[1]] - matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + # Remove duplicate targets, ensuring one match per target + _, unique_target_idx = np.unique(target_idxs, return_index=True) + pred_idxs = pred_idxs[unique_target_idx] - correct[matches[:, 1].astype(int), i] = True + # Mark the successful matches in the correct array + correct[pred_idxs, i] = True return correct