From 781c056d51feafdc313a0c932473a74afdcf5e78 Mon Sep 17 00:00:00 2001 From: narcissus <765188431@qq.com> Date: Thu, 29 May 2025 09:24:52 +0800 Subject: [PATCH] add ascend npu support --- rapid_layout/main.py | 2 ++ rapid_layout/utils/infer_engine.py | 35 ++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/rapid_layout/main.py b/rapid_layout/main.py index 959867f..9e6665c 100644 --- a/rapid_layout/main.py +++ b/rapid_layout/main.py @@ -51,6 +51,7 @@ def __init__( iou_thres: float = 0.5, use_cuda: bool = False, use_dml: bool = False, + use_cann: bool = False, ): if not self.check_of(conf_thres): raise ValueError(f"conf_thres {conf_thres} is outside of range [0, 1]") @@ -63,6 +64,7 @@ def __init__( "model_path": self.get_model_path(model_type, model_path), "use_cuda": use_cuda, "use_dml": use_dml, + "use_cann": use_cann, } self.session = OrtInferSession(config) labels = self.session.get_character_list() diff --git a/rapid_layout/utils/infer_engine.py b/rapid_layout/utils/infer_engine.py index 6354bbb..8da61b1 100644 --- a/rapid_layout/utils/infer_engine.py +++ b/rapid_layout/utils/infer_engine.py @@ -24,6 +24,7 @@ class EP(Enum): CPU_EP = "CPUExecutionProvider" CUDA_EP = "CUDAExecutionProvider" DIRECTML_EP = "DmlExecutionProvider" + CANN_EP = "CANNExecutionProvider" class OrtInferSession: @@ -35,6 +36,7 @@ def __init__(self, config: Dict[str, Any]): self.cfg_use_cuda = config.get("use_cuda", None) self.cfg_use_dml = config.get("use_dml", None) + self.cfg_use_cann = config.get("use_cann", None) self.had_providers: List[str] = get_available_providers() EP_list = self._get_ep_list() @@ -71,6 +73,16 @@ def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: } EP_list = [(EP.CPU_EP.value, cpu_provider_opts)] + self.use_cann = self._check_cann() + if self.use_cann: + cann_provider_opts = { + "device_id": 0, + "arena_extend_strategy": "kNextPowerOfTwo", + "enable_cann_graph": True, + "precision_mode": "must_keep_origin_dtype", + } + EP_list.insert(0, (EP.CANN_EP.value, cann_provider_opts)) + cuda_provider_opts = { "device_id": 0, "arena_extend_strategy": "kNextPowerOfTwo", @@ -92,6 +104,23 @@ def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options)) return EP_list + def _check_cann(self) -> bool: + if not self.cfg_use_cann: + return False + + if EP.CANN_EP.value not in self.had_providers: + self.logger.warning( + "%s is not in available providers (%s). Use %s inference by default.", + EP.CANN_EP.value, + self.had_providers, + self.had_providers[0], + ) + self.logger.info("To use CANNExecutionProvider, you must:") + self.logger.info("1. Install Ascend CANN Toolkit") + self.logger.info("2. Install onnxruntimeċŒ… with CANN support") + return False + return True + def _check_cuda(self) -> bool: if not self.cfg_use_cuda: return False @@ -176,6 +205,12 @@ def _verify_providers(self): session_providers = self.session.get_providers() first_provider = session_providers[0] + if self.use_cann and first_provider != EP.CANN_EP.value: + self.logger.warning( + "%s is not available, fallback to %s", + EP.CANN_EP.value, + first_provider, + ) if self.use_cuda and first_provider != EP.CUDA_EP.value: self.logger.warning( "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.",