Skip to content

Commit 7df62ce

Browse files
events
1 parent 324a48a commit 7df62ce

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2221,7 +2221,7 @@ def record(self, stream=0):
22212221
queued in the stream at the time of the call to ``record()`` has been
22222222
completed.
22232223
"""
2224-
hstream = stream.handle.value if stream else binding.CUstream(0)
2224+
hstream = _stream_handle(stream)
22252225
handle = self.handle.value
22262226
driver.cuEventRecord(handle, hstream)
22272227

@@ -2236,7 +2236,7 @@ def wait(self, stream=0):
22362236
"""
22372237
All future works submitted to stream will wait util the event completes.
22382238
"""
2239-
hstream = stream.handle.value if stream else binding.CUstream(0)
2239+
hstream = _stream_handle(stream)
22402240
handle = self.handle.value
22412241
flags = 0
22422242
driver.cuStreamWaitEvent(hstream, handle, flags)

numba_cuda/numba/cuda/tests/cudadrv/test_events.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from numba import cuda
66
from numba.cuda.testing import unittest, CUDATestCase
7+
from cuda.core.experimental import Device
78

89

910
class TestCudaEvent(CUDATestCase):
@@ -22,8 +23,17 @@ def test_event_elapsed(self):
2223
evtstart.elapsed_time(evtend)
2324

2425
def test_event_elapsed_stream(self):
25-
N = 32
2626
stream = cuda.stream()
27+
self.event_elapsed_inner(stream)
28+
29+
def test_event_elapsed_cuda_core_stream(self):
30+
dev = Device()
31+
dev.set_current()
32+
stream = dev.create_stream()
33+
self.event_elapsed_inner(stream)
34+
35+
def event_elapsed_inner(self, stream):
36+
N = 32
2737
dary = cuda.device_array(N, dtype=np.double)
2838
evtstart = cuda.event()
2939
evtend = cuda.event()

0 commit comments

Comments
 (0)