diff --git a/supervision/metrics/utils/object_size.py b/supervision/metrics/utils/object_size.py index 162879224..6c1e94089 100644 --- a/supervision/metrics/utils/object_size.py +++ b/supervision/metrics/utils/object_size.py @@ -98,7 +98,7 @@ def get_mask_size_category(mask: npt.NDArray[np.bool_]) -> npt.NDArray[np.int_]: def get_obb_size_category(xyxyxyxy: npt.NDArray[np.float32]) -> npt.NDArray[np.int_]: """ - Get the size category of a oriented bounding boxes array. + Get the size category of an oriented bounding boxes array. Args: xyxyxyxy (np.ndarray): The bounding boxes array shaped (N, 4, 2). @@ -107,25 +107,34 @@ def get_obb_size_category(xyxyxyxy: npt.NDArray[np.float32]) -> npt.NDArray[np.i (np.ndarray) The size category of each bounding box, matching the enum values of ObjectSizeCategory. Shaped (N,). """ - if len(xyxyxyxy.shape) != 3 or xyxyxyxy.shape[1] != 4 or xyxyxyxy.shape[2] != 2: + if xyxyxyxy.shape != (len(xyxyxyxy), 4, 2): raise ValueError("Oriented bounding boxes must be shaped (N, 4, 2)") # Shoelace formula x = xyxyxyxy[:, :, 0] y = xyxyxyxy[:, :, 1] - x1, x2, x3, x4 = x.T - y1, y2, y3, y4 = y.T areas = 0.5 * np.abs( - (x1 * y2 + x2 * y3 + x3 * y4 + x4 * y1) - - (x2 * y1 + x3 * y2 + x4 * y3 + x1 * y4) + x[:, 0] * y[:, 1] + + x[:, 1] * y[:, 2] + + x[:, 2] * y[:, 3] + + x[:, 3] * y[:, 0] + - ( + x[:, 1] * y[:, 0] + + x[:, 2] * y[:, 1] + + x[:, 3] * y[:, 2] + + x[:, 0] * y[:, 3] + ) ) - result = np.full(areas.shape, ObjectSizeCategory.ANY.value) SM, LG = SIZE_THRESHOLDS - result[areas < SM] = ObjectSizeCategory.SMALL.value - result[(areas >= SM) & (areas < LG)] = ObjectSizeCategory.MEDIUM.value - result[areas >= LG] = ObjectSizeCategory.LARGE.value - return result + categories = np.where( + areas < SM, + ObjectSizeCategory.SMALL.value, + np.where( + areas < LG, ObjectSizeCategory.MEDIUM.value, ObjectSizeCategory.LARGE.value + ), + ) + return categories def get_detection_size_category(