Skip to content

Commit 3e554a1

Browse files
committed
Import numba.core.entrypoints if available
1 parent 780d7ad commit 3e554a1

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

numba_cuda/numba/cuda/core/entrypoints.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,17 @@
1212

1313

1414
def init_all():
15-
"""Execute all `numba_extensions` entry points with the name `init`
15+
"""Execute all `numba_cuda_extensions` entry points with the name `init`
1616
1717
If extensions have already been initialized, this function does nothing.
1818
"""
19+
try:
20+
from numba.core import entrypoints
21+
22+
entrypoints.init_all()
23+
except ImportError:
24+
pass
25+
1926
global _already_initialized
2027
if _already_initialized:
2128
return
@@ -42,9 +49,11 @@ def load_ep(entry_point):
4249
# interface, versions prior to that do not. See "compatibility note" in:
4350
# https://docs.python.org/3.10/library/importlib.metadata.html#entry-points
4451
if hasattr(eps, "select"):
45-
for entry_point in eps.select(group="numba_extensions", name="init"):
52+
for entry_point in eps.select(
53+
group="numba_cuda_extensions", name="init"
54+
):
4655
load_ep(entry_point)
4756
else:
48-
for entry_point in eps.get("numba_extensions", ()):
57+
for entry_point in eps.get("numba_cuda_extensions", ()):
4958
if entry_point.name == "init":
5059
load_ep(entry_point)

0 commit comments

Comments
 (0)