Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions rapid_layout/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
iou_thres: float = 0.5,
use_cuda: bool = False,
use_dml: bool = False,
use_cann: bool = False,
Copy link
Preview

Copilot AI May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The new use_cann parameter is not documented in the function docstring or CLI help; consider updating documentation to describe this option.

Copilot uses AI. Check for mistakes.

):
if not self.check_of(conf_thres):
raise ValueError(f"conf_thres {conf_thres} is outside of range [0, 1]")
Expand All @@ -63,6 +64,7 @@ def __init__(
"model_path": self.get_model_path(model_type, model_path),
"use_cuda": use_cuda,
"use_dml": use_dml,
"use_cann": use_cann,
}
self.session = OrtInferSession(config)
labels = self.session.get_character_list()
Expand Down
35 changes: 35 additions & 0 deletions rapid_layout/utils/infer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class EP(Enum):
CPU_EP = "CPUExecutionProvider"
CUDA_EP = "CUDAExecutionProvider"
DIRECTML_EP = "DmlExecutionProvider"
CANN_EP = "CANNExecutionProvider"


class OrtInferSession:
Expand All @@ -35,6 +36,7 @@ def __init__(self, config: Dict[str, Any]):

self.cfg_use_cuda = config.get("use_cuda", None)
self.cfg_use_dml = config.get("use_dml", None)
self.cfg_use_cann = config.get("use_cann", None)

self.had_providers: List[str] = get_available_providers()
EP_list = self._get_ep_list()
Expand Down Expand Up @@ -71,6 +73,16 @@ def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
}
EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

self.use_cann = self._check_cann()
Copy link
Preview

Copilot AI May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Provider initialization follows a similar pattern across _check_* methods and EP_list insertion; consider refactoring common logic into a helper to reduce duplication.

Copilot uses AI. Check for mistakes.

if self.use_cann:
cann_provider_opts = {
"device_id": 0,
"arena_extend_strategy": "kNextPowerOfTwo",
"enable_cann_graph": True,
"precision_mode": "must_keep_origin_dtype",
}
EP_list.insert(0, (EP.CANN_EP.value, cann_provider_opts))

cuda_provider_opts = {
"device_id": 0,
"arena_extend_strategy": "kNextPowerOfTwo",
Expand All @@ -92,6 +104,23 @@ def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
return EP_list

def _check_cann(self) -> bool:
if not self.cfg_use_cann:
return False

if EP.CANN_EP.value not in self.had_providers:
self.logger.warning(
"%s is not in available providers (%s). Use %s inference by default.",
EP.CANN_EP.value,
self.had_providers,
self.had_providers[0],
)
self.logger.info("To use CANNExecutionProvider, you must:")
self.logger.info("1. Install Ascend CANN Toolkit")
self.logger.info("2. Install onnxruntime包 with CANN support")
return False
return True

def _check_cuda(self) -> bool:
if not self.cfg_use_cuda:
return False
Expand Down Expand Up @@ -176,6 +205,12 @@ def _verify_providers(self):
session_providers = self.session.get_providers()
first_provider = session_providers[0]

if self.use_cann and first_provider != EP.CANN_EP.value:
self.logger.warning(
"%s is not available, fallback to %s",
EP.CANN_EP.value,
first_provider,
)
if self.use_cuda and first_provider != EP.CUDA_EP.value:
self.logger.warning(
"%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.",
Expand Down