Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions modelopt/onnx/quantization/ort_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
"""Provides basic ORT inference utils, should be replaced by modelopt.torch.ort_client."""

import glob
import io
import os
import platform
import sys
from collections.abc import Sequence
from contextlib import redirect_stderr, redirect_stdout

import onnxruntime as ort
from onnxruntime.quantization.operators.qdq_base_operator import QDQOperatorBase
Expand Down Expand Up @@ -70,11 +73,54 @@ def _check_for_libcudnn():
f" for your ORT version at https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements."
)
else:
logger.error(f"cuDNN library not found in {env_variable}")
# Fallback: ORT >=1.20 ships a preload_dlls() helper that loads CUDA/cuDNN
# DLLs bundled inside pip packages (e.g. nvidia-cudnn-cu12) so they don't
# need to be on the system PATH / LD_LIBRARY_PATH.
# However, preload_dlls() is broken on Python 3.10 (missing os.add_dll_directory
# behaviour), so we skip it for that version.
if hasattr(ort, "preload_dlls") and sys.version_info[:2] != (3, 10):
logger.warning(
f"cuDNN not found in {env_variable}. "
"Attempting onnxruntime.preload_dlls() to load from site-packages..."
)
# preload_dlls() does not raise on failure — it only prints
# "Failed to load ..." messages. Capture its output and check
# whether the key cuDNN DLL actually loaded.
cudnn_dll = "cudnn" if platform.system() == "Windows" else "libcudnn_adv"
captured = io.StringIO()
try:
with redirect_stdout(captured), redirect_stderr(captured):
ort.preload_dlls()
except Exception as e:
logger.warning(f"onnxruntime.preload_dlls() raised an exception: {e}")

preload_output = captured.getvalue()
if preload_output:
logger.debug(f"preload_dlls() output:\n{preload_output}")

if f"Failed to load {cudnn_dll}" in preload_output:
logger.error(
f"onnxruntime.preload_dlls() was called but {cudnn_dll} failed to load. "
"cuDNN DLLs were NOT successfully loaded from site-packages."
)
else:
logger.info(
"onnxruntime.preload_dlls() succeeded — CUDA/cuDNN DLLs loaded"
" from site-packages. Verify version compatibility at"
" https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements."
)
return True

raise FileNotFoundError(
f"{lib_pattern} is not accessible in {env_variable}! Please make sure that the path to that library"
f" is in the env var to use the CUDA or TensorRT EP and ensure that the correct version is available."
f" Versioning compatibility can be checked at https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements."
f"{lib_pattern} is not accessible via {env_variable} or site-packages.\n"
f"To fix this, either:\n"
f" 1. Add the directory containing {lib_pattern} to your"
f" {env_variable} env var, or\n"
f" 2. Install the cuDNN pip package (Python>=3.11 only):"
f" pip install nvidia-cudnn-cu12 (or nvidia-cudnn-cu13)\n"
f"This is required for the CUDA / TensorRT execution provider.\n"
f"Check version compatibility at"
f" https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements."
)
return found

Expand Down
Loading