|
1 | | -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. |
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. |
2 | 2 | # SPDX-License-Identifier: BSD-3-Clause |
3 | 3 |
|
4 | 4 | """UCXX: Python bindings for the Unified Communication X library (UCX <www.openucx.org>)""" |
|
37 | 37 | except ImportError: |
38 | 38 | pynvml = None |
39 | 39 |
|
| 40 | +_ucx_version = get_ucx_version() |
| 41 | +__ucx_min_version__ = "1.15.0" |
| 42 | +__ucx_version__ = "%d.%d.%d" % _ucx_version |
| 43 | + |
| 44 | +if _ucx_version < tuple(int(i) for i in __ucx_min_version__.split(".")): |
| 45 | + raise ImportError( |
| 46 | + f"Support for UCX {__ucx_version__} has ended. Please upgrade to " |
| 47 | + f"{__ucx_min_version__} or newer. If you believe the wrong version " |
| 48 | + "is being loaded, please check the path from where UCX is loaded " |
| 49 | + "by rerunning with the environment variable `UCX_LOG_LEVEL=debug`." |
| 50 | + ) |
| 51 | + |
40 | 52 | # Setup UCX-Py logger |
41 | 53 | logger = get_ucxpy_logger() |
42 | 54 |
|
|
51 | 63 | if ( |
52 | 64 | pynvml is not None |
53 | 65 | and "UCX_CUDA_COPY_MAX_REG_RATIO" not in os.environ |
54 | | - and get_ucx_version() >= (1, 12, 0) |
| 66 | + and _ucx_version >= (1, 12, 0) |
55 | 67 | ): |
56 | 68 | try: |
57 | 69 | pynvml.nvmlInit() |
@@ -91,25 +103,14 @@ def _is_mig_device(handle): |
91 | 103 | ): |
92 | 104 | pass |
93 | 105 |
|
94 | | -if "UCX_MAX_RNDV_RAILS" not in os.environ and get_ucx_version() >= (1, 12, 0): |
| 106 | +if "UCX_MAX_RNDV_RAILS" not in os.environ and _ucx_version >= (1, 12, 0): |
95 | 107 | logger.info("Setting UCX_MAX_RNDV_RAILS=1") |
96 | 108 | os.environ["UCX_MAX_RNDV_RAILS"] = "1" |
97 | 109 |
|
98 | | -if "UCX_PROTO_ENABLE" not in os.environ: |
| 110 | +if "UCX_PROTO_ENABLE" not in os.environ and (1, 12, 0) <= _ucx_version < (1, 18, 0): |
99 | 111 | # UCX protov2 still doesn't support CUDA async/managed memory |
100 | 112 | logger.info("Setting UCX_PROTO_ENABLE=n") |
101 | 113 | os.environ["UCX_PROTO_ENABLE"] = "n" |
102 | 114 |
|
103 | 115 |
|
104 | 116 | from ._version import __git_commit__, __version__ |
105 | | - |
106 | | -__ucx_min_version__ = "1.15.0" |
107 | | -__ucx_version__ = "%d.%d.%d" % get_ucx_version() |
108 | | - |
109 | | -if get_ucx_version() < tuple(int(i) for i in __ucx_min_version__.split(".")): |
110 | | - raise ImportError( |
111 | | - f"Support for UCX {__ucx_version__} has ended. Please upgrade to " |
112 | | - f"{__ucx_min_version__} or newer. If you believe the wrong version " |
113 | | - "is being loaded, please check the path from where UCX is loaded " |
114 | | - "by rerunning with the environment variable `UCX_LOG_LEVEL=debug`." |
115 | | - ) |
|
0 commit comments