diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py index 9e7a30d0e..49af70078 100644 --- a/supervision/metrics/mean_average_precision.py +++ b/supervision/metrics/mean_average_precision.py @@ -366,29 +366,30 @@ def _average_precisions_per_class( def _detections_content(self, detections: Detections) -> np.ndarray: """Return boxes, masks or oriented bounding boxes from detections.""" - if self._metric_target == MetricTarget.BOXES: + mt = self._metric_target # cache for speed + if mt is MetricTarget.BOXES: return detections.xyxy - if self._metric_target == MetricTarget.MASKS: + elif mt is MetricTarget.MASKS: + mask_content = detections.mask return ( - detections.mask - if detections.mask is not None - else self._make_empty_content() + mask_content if mask_content is not None else self._make_empty_content() ) - if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: + elif mt is MetricTarget.ORIENTED_BOUNDING_BOXES: obb = detections.data.get(ORIENTED_BOX_COORDINATES) - if obb is not None and len(obb) > 0: - return np.array(obb, dtype=np.float32) + if obb: # check if obb is not None and non-empty + return np.asarray(obb, dtype=np.float32) return self._make_empty_content() - raise ValueError(f"Invalid metric target: {self._metric_target}") + raise ValueError(f"Invalid metric target: {mt}") def _make_empty_content(self) -> np.ndarray: - if self._metric_target == MetricTarget.BOXES: + mt = self._metric_target # use local cached attribute + if mt is MetricTarget.BOXES: return np.empty((0, 4), dtype=np.float32) - if self._metric_target == MetricTarget.MASKS: + elif mt is MetricTarget.MASKS: return np.empty((0, 0, 0), dtype=bool) - if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: + elif mt is MetricTarget.ORIENTED_BOUNDING_BOXES: return np.empty((0, 4, 2), dtype=np.float32) - raise ValueError(f"Invalid metric target: {self._metric_target}") + raise ValueError(f"Invalid metric target: {mt}") def _filter_detections_by_size( self, detections: Detections, size_category: ObjectSizeCategory