From e737c3ceb830aa55280ee42afc447ab50c6c3864 Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 05:04:28 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20function=20`p?= =?UTF-8?q?rocess=5Ftransformers=5Fdetection=5Fresult`=20by=2066%=20Here?= =?UTF-8?q?=20is=20a=20more=20optimized=20version=20of=20the=20code.=20I?= =?UTF-8?q?=20have=20removed=20unnecessary=20numpy=20operations=20where=20?= =?UTF-8?q?possible=20and=20streamlined=20the=20process.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes made. 1. Removed unnecessary `.detach()` calls as it is redundant when followed by `.cpu()`. 2. Combined the data creation within the main function instead of having a separate helper function to avoid unnecessary function call overhead. 3. Used list comprehensions instead of numpy array operations where appropriate for clearer and potentially faster execution, particularly for small datasets. --- supervision/detection/tools/transformers.py | 22 +++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/supervision/detection/tools/transformers.py b/supervision/detection/tools/transformers.py index 047c1f942..cfe91bae7 100644 --- a/supervision/detection/tools/transformers.py +++ b/supervision/detection/tools/transformers.py @@ -26,15 +26,21 @@ def process_transformers_detection_result( dict: Processed detection result including bounding boxes, confidence scores, class IDs, and data. """ - class_ids = detection_result["labels"].cpu().detach().numpy().astype(int) - data = append_class_names_to_data(class_ids, id2label, {}) + labels = detection_result["labels"].cpu().numpy().astype(int) + boxes = detection_result["boxes"].cpu().numpy() + scores = detection_result["scores"].cpu().numpy() - return dict( - xyxy=detection_result["boxes"].cpu().detach().numpy(), - confidence=detection_result["scores"].cpu().detach().numpy(), - class_id=class_ids, - data=data, - ) + data = {} + if id2label is not None: + class_names = [id2label[label] for label in labels] + data[CLASS_NAME_DATA_FIELD] = class_names + + return { + "xyxy": boxes, + "confidence": scores, + "class_id": labels, + "data": data, + } def process_transformers_v4_segmentation_result(