From 95548f39eaf09ff493fab1d3e67b9b7eea316415 Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 04:06:52 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20method=20`Mea?= =?UTF-8?q?nAveragePrecision.=5Fmatch=5Fdetection=5Fbatch`=20by=2016%=20o3?= =?UTF-8?q?-mini=20We=20replace=20the=20repeated=20use=20of=20np.where=20a?= =?UTF-8?q?nd=20np.stack/hstack=20with=20np.nonzero=20to=20obtain=20the=20?= =?UTF-8?q?target=20and=20prediction=20indices=20directly.=20We=20also=20i?= =?UTF-8?q?nline=20some=20intermediate=20arrays=20and=20avoid=20unnecessar?= =?UTF-8?q?y=20array=20stacking.=20This=20simplifies=20the=20loop=20(thoug?= =?UTF-8?q?h=20it=20still=20loops=20over=20thresholds)=20and=20removes=20s?= =?UTF-8?q?ome=20overhead=20in=20the=20inner=20loop.=20The=20logic=20remai?= =?UTF-8?q?ns=20the=20same=20so=20that=20each=20function=20return=20value?= =?UTF-8?q?=20is=20identical.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Below is the optimized version: Explanation of the changes made: • Instead of using np.where to obtain a tuple and then stacking the resulting arrays, we use np.nonzero to get target and prediction indices directly. • We compute the valid mask for each threshold once and then sort the matches by their iou values in descending order. • We then use np.unique on the prediction indices (and then on the target indices) to remove duplicate matches. • This avoids repeated array creation (via stacking) and improves the run‐time. --- supervision/metrics/mean_average_precision.py | 46 +++++++++++++------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py index 9e7a30d0e..800637702 100644 --- a/supervision/metrics/mean_average_precision.py +++ b/supervision/metrics/mean_average_precision.py @@ -283,27 +283,43 @@ def _match_detection_batch( iou: np.ndarray, iou_thresholds: np.ndarray, ) -> np.ndarray: - num_predictions, num_iou_levels = ( - predictions_classes.shape[0], - iou_thresholds.shape[0], - ) + num_predictions = predictions_classes.shape[0] + num_iou_levels = iou_thresholds.shape[0] correct = np.zeros((num_predictions, num_iou_levels), dtype=bool) + # Broadcast target_classes and predictions_classes for fast class matching correct_class = target_classes[:, None] == predictions_classes - for i, iou_level in enumerate(iou_thresholds): - matched_indices = np.where((iou >= iou_level) & correct_class) + # Loop over each IOU threshold and perform matching + for i, threshold in enumerate(iou_thresholds): + # Create a mask for matches that satisfy both the IOU threshold and the correct class + valid = ( + iou >= threshold + ) & correct_class # Shape: (num_targets, num_predictions) + + # Get indices of valid matches directly + target_idxs, pred_idxs = np.nonzero(valid) + if target_idxs.size == 0: + continue + + # Get corresponding IOU scores for valid matches + match_ious = iou[target_idxs, pred_idxs] + + # Sort matches by descending IOU values + order = np.argsort(match_ious)[::-1] + target_idxs = target_idxs[order] + pred_idxs = pred_idxs[order] - if matched_indices[0].shape[0]: - combined_indices = np.stack(matched_indices, axis=1) - iou_values = iou[matched_indices][:, None] - matches = np.hstack([combined_indices, iou_values]) + # Remove duplicate predictions while keeping the highest IOU match + _, unique_pred_idx = np.unique(pred_idxs, return_index=True) + target_idxs = target_idxs[unique_pred_idx] + pred_idxs = pred_idxs[unique_pred_idx] - if matched_indices[0].shape[0] > 1: - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 1], return_index=True)[1]] - matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + # Remove duplicate targets, ensuring one match per target + _, unique_target_idx = np.unique(target_idxs, return_index=True) + pred_idxs = pred_idxs[unique_target_idx] - correct[matches[:, 1].astype(int), i] = True + # Mark the successful matches in the correct array + correct[pred_idxs, i] = True return correct