Skip to content

Commit a3a75f0

Browse files
committed
Make confidential_compute_enabled() more robust and complete
Signed-off-by: Dan Hansen <[email protected]>
1 parent 5abeac6 commit a3a75f0

File tree

1 file changed

+41
-6
lines changed

1 file changed

+41
-6
lines changed

tensorrt_llm/_utils.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,15 +1167,50 @@ def set_prometheus_multiproc_dir() -> object:
11671167
logger.info(
11681168
f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
11691169

1170+
11701171
def confidential_compute_enabled() -> bool:
1171-
import pynvml
1172+
"""
1173+
Query NVML for the confidential compute state
1174+
"""
1175+
1176+
cc_enabled = False
1177+
1178+
try:
1179+
# Init
1180+
import pynvml
1181+
pynvml.nvmlInit()
1182+
1183+
# Hopper and newer supports a more nuanced query of confidential
1184+
# compute settings
1185+
cc_settings = pynvml.c_nvmlSystemConfComputeSettings_v1_t()
1186+
if (pynvml.nvmlSystemGetConfComputeSettings(cc_settings) ==
1187+
pynvml.NVML_SUCCESS):
1188+
cc_enabled = (cc_settings.ccFeature
1189+
== pynvml.NVML_CC_SYSTEM_FEATURE_ENABLED
1190+
or cc_settings.multiGpuMode
1191+
== pynvml.NVML_CC_SYSTEM_MULTIGPU_PROTECTED_PCIE
1192+
or cc_settings.multiGpuMode
1193+
== pynvml.NVML_CC_SYSTEM_MULTIGPU_NVLE)
1194+
except pynvml.NVMLError_NotSupported:
1195+
# Simple query for older GPUs
1196+
try:
1197+
cc_state = pynvml.nvmlSystemGetConfComputeState()
1198+
cc_enabled = (
1199+
cc_state.ccFeature == pynvml.NVML_CC_SYSTEM_FEATURE_ENABLED)
1200+
except Exception as e:
1201+
logger.error(f"Error querying confidential compute state: {str(e)}")
1202+
except Exception as e:
1203+
logger.error(f"Error querying confidential compute state: {str(e)}")
1204+
finally:
1205+
# Shutdown
1206+
try:
1207+
pynvml.nvmlShutdown()
1208+
except:
1209+
# Ignore shutdown errors
1210+
pass
11721211

1173-
# Determine if Confidential Compute is enabled
1174-
pynvml.nvmlInit()
1175-
conf_compute_enabled = bool(pynvml.nvmlSystemGetConfComputeState().ccFeature)
1176-
pynvml.nvmlShutdown()
1212+
return cc_enabled
11771213

1178-
return conf_compute_enabled
11791214

11801215
P = ParamSpec("P")
11811216

0 commit comments

Comments
 (0)