Skip to content

Commit b8e5b64

Browse files
committed
set the numba cuda global compiler lock to also control the lock of numba
1 parent dc93dac commit b8e5b64

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

numba_cuda/numba/cuda/core/compiler_lock.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
import threading
55
import functools
66
import numba.cuda.core.event as ev
7+
from numba.cuda import HAS_NUMBA
8+
9+
if HAS_NUMBA:
10+
from numba.core.compiler_lock import (
11+
global_compiler_lock as _numba_compiler_lock,
12+
)
13+
else:
14+
_numba_compiler_lock = None
715

816

917
# Lock for the preventing multiple compiler execution
@@ -50,7 +58,55 @@ def _is_owned(self):
5058
return True
5159

5260

53-
global_compiler_lock = _CompilerLock()
61+
_numba_cuda_compiler_lock = _CompilerLock()
62+
63+
64+
# Wrapper that coordinates both numba and numba-cuda compiler locks
65+
class _DualCompilerLock(object):
66+
"""Wrapper that coordinates both the numba-cuda and upstream numba compiler locks."""
67+
68+
def __init__(self, cuda_lock, numba_lock):
69+
self._cuda_lock = cuda_lock
70+
self._numba_lock = numba_lock
71+
72+
def acquire(self):
73+
if self._numba_lock:
74+
self._numba_lock.acquire()
75+
self._cuda_lock.acquire()
76+
77+
def release(self):
78+
self._cuda_lock.release()
79+
if self._numba_lock:
80+
self._numba_lock.release()
81+
82+
def __enter__(self):
83+
self.acquire()
84+
85+
def __exit__(self, exc_val, exc_type, traceback):
86+
self.release()
87+
88+
def is_locked(self):
89+
cuda_locked = self._cuda_lock.is_locked()
90+
if self._numba_lock:
91+
return cuda_locked and self._numba_lock.is_locked()
92+
return cuda_locked
93+
94+
def __call__(self, func):
95+
@functools.wraps(func)
96+
def _acquire_compile_lock(*args, **kwargs):
97+
with self:
98+
return func(*args, **kwargs)
99+
100+
return _acquire_compile_lock
101+
102+
103+
# Create the global compiler lock, wrapping both locks if numba is available
104+
if HAS_NUMBA:
105+
global_compiler_lock = _DualCompilerLock(
106+
_numba_cuda_compiler_lock, _numba_compiler_lock
107+
)
108+
else:
109+
global_compiler_lock = _numba_cuda_compiler_lock
54110

55111

56112
def require_global_compiler_lock():

0 commit comments

Comments
 (0)