
Commit 8427ae1

Merge branch 'main' into atmn/vendor-in-errors
2 parents b5819be + 39066c7 commit 8427ae1

File tree: 4 files changed (+138, -39 lines)

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 61 additions & 37 deletions
@@ -64,6 +64,12 @@
     ObjectCode,
 )
 
+from cuda.bindings.utils import get_cuda_native_handle
+from cuda.core.experimental import (
+    Stream as ExperimentalStream,
+)
+
+
 # There is no definition of the default stream in the Nvidia bindings (nor
 # is there at the C/C++ level), so we define it here so we don't need to
 # use a magic number 0 in places where we want the default stream.
@@ -2064,6 +2070,11 @@ def __int__(self):
         # The default stream's handle.value is 0, which gives `None`
         return self.handle.value or drvapi.CU_STREAM_DEFAULT
 
+    def __cuda_stream__(self):
+        if not self.handle.value:
+            return (0, drvapi.CU_STREAM_DEFAULT)
+        return (0, self.handle.value)
+
     def __repr__(self):
         default_streams = {
             drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
@@ -2210,7 +2221,7 @@ def record(self, stream=0):
         queued in the stream at the time of the call to ``record()`` has been
         completed.
         """
-        hstream = stream.handle.value if stream else binding.CUstream(0)
+        hstream = _stream_handle(stream)
         handle = self.handle.value
         driver.cuEventRecord(handle, hstream)

@@ -2225,7 +2236,7 @@ def wait(self, stream=0):
         """
        All future works submitted to stream will wait util the event completes.
        """
-        hstream = stream.handle.value if stream else binding.CUstream(0)
+        hstream = _stream_handle(stream)
         handle = self.handle.value
         flags = 0
         driver.cuStreamWaitEvent(hstream, handle, flags)
@@ -3080,17 +3091,14 @@ def host_to_device(dst, src, size, stream=0):
     it should not be changed until the operation which can be asynchronous
     completes.
     """
-    varargs = []
+    fn = driver.cuMemcpyHtoD
+    args = (device_pointer(dst), host_pointer(src, readonly=True), size)
 
     if stream:
-        assert isinstance(stream, Stream)
         fn = driver.cuMemcpyHtoDAsync
-        handle = stream.handle.value
-        varargs.append(handle)
-    else:
-        fn = driver.cuMemcpyHtoD
+        args += (_stream_handle(stream),)
 
-    fn(device_pointer(dst), host_pointer(src, readonly=True), size, *varargs)
+    fn(*args)
 
 
 def device_to_host(dst, src, size, stream=0):
@@ -3099,61 +3107,52 @@ def device_to_host(dst, src, size, stream=0):
     it should not be changed until the operation which can be asynchronous
     completes.
     """
-    varargs = []
+    fn = driver.cuMemcpyDtoH
+    args = (host_pointer(dst), device_pointer(src), size)
 
     if stream:
-        assert isinstance(stream, Stream)
         fn = driver.cuMemcpyDtoHAsync
-        handle = stream.handle.value
-        varargs.append(handle)
-    else:
-        fn = driver.cuMemcpyDtoH
+        args += (_stream_handle(stream),)
 
-    fn(host_pointer(dst), device_pointer(src), size, *varargs)
+    fn(*args)
 
 
 def device_to_device(dst, src, size, stream=0):
     """
-    NOTE: The underlying data pointer from the host data buffer is used and
+    NOTE: The underlying data pointer from the device buffer is used and
     it should not be changed until the operation which can be asynchronous
     completes.
     """
-    varargs = []
+    fn = driver.cuMemcpyDtoD
+    args = (device_pointer(dst), device_pointer(src), size)
 
     if stream:
-        assert isinstance(stream, Stream)
         fn = driver.cuMemcpyDtoDAsync
-        handle = stream.handle.value
-        varargs.append(handle)
-    else:
-        fn = driver.cuMemcpyDtoD
+        args += (_stream_handle(stream),)
 
-    fn(device_pointer(dst), device_pointer(src), size, *varargs)
+    fn(*args)
 
 
 def device_memset(dst, val, size, stream=0):
-    """Memset on the device.
-    If stream is not zero, asynchronous mode is used.
+    """
+    Memset on the device.
+    If stream is 0, the call is synchronous.
+    If stream is a Stream object, asynchronous mode is used.
 
     dst: device memory
     val: byte value to be written
-    size: number of byte to be written
-    stream: a CUDA stream
+    size: number of bytes to be written
+    stream: 0 (synchronous) or a CUDA stream
     """
-    ptr = device_pointer(dst)
-
-    varargs = []
+    fn = driver.cuMemsetD8
+    args = (device_pointer(dst), val, size)
 
     if stream:
-        assert isinstance(stream, Stream)
         fn = driver.cuMemsetD8Async
-        handle = stream.handle.value
-        varargs.append(handle)
-    else:
-        fn = driver.cuMemsetD8
+        args += (_stream_handle(stream),)
 
     try:
-        fn(ptr, val, size, *varargs)
+        fn(*args)
     except CudaAPIError as e:
         invalid = binding.CUresult.CUDA_ERROR_INVALID_VALUE
         if (
@@ -3226,3 +3225,28 @@ def inspect_obj_content(objpath: str):
             code_types.add(match.group(1))
 
     return code_types
+
+
+def _stream_handle(stream):
+    """
+    Obtain the appropriate handle for various types of
+    acceptable stream objects. Acceptable types are
+    int (0 for default stream), Stream, ExperimentalStream
+    """
+
+    if stream == 0:
+        return stream
+    allowed = (Stream, ExperimentalStream)
+    if not isinstance(stream, allowed):
+        raise TypeError(
+            "Expected a Stream object or 0, got %s" % type(stream).__name__
+        )
+    elif hasattr(stream, "__cuda_stream__"):
+        ver, ptr = stream.__cuda_stream__()
+        assert ver == 0
+        if isinstance(ptr, binding.CUstream):
+            return get_cuda_native_handle(ptr)
+        else:
+            return ptr
+    else:
+        raise TypeError("Invalid Stream")
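
For context, a minimal usage sketch (not part of this commit) of the _stream_handle helper added above. It assumes a CUDA-capable machine with numba-cuda and cuda.core installed; the printed values are illustrative only.

# Hypothetical sketch: how _stream_handle normalizes the accepted stream types.
from numba import cuda
from numba.cuda.cudadrv import driver
from cuda.core.experimental import Device

# 0 passes through untouched and selects the default stream.
assert driver._stream_handle(0) == 0

# A Numba Stream exposes __cuda_stream__ and yields its raw handle value.
numba_stream = cuda.stream()
print(driver._stream_handle(numba_stream))

# A cuda.core experimental Stream is unwrapped via get_cuda_native_handle.
dev = Device()
dev.set_current()
core_stream = dev.create_stream()
print(driver._stream_handle(core_stream))

# Anything else is rejected with a TypeError.
try:
    driver._stream_handle("not a stream")
except TypeError as exc:
    print(exc)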

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 1 addition & 1 deletion
@@ -474,7 +474,7 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0):
        for t, v in zip(self.argument_types, args):
            self._prepare_args(t, v, stream, retr, kernelargs)

-        stream_handle = stream and stream.handle.value or 0
+        stream_handle = driver._stream_handle(stream)

        # Invoke kernel
        driver.launch_kernel(
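
With this dispatcher change, a launch configuration can carry a cuda.core experimental stream as well as a classic Numba stream. A hedged sketch of the user-facing pattern, essentially what the new test test_cuda_core_stream_launch_user_facing exercises; the kernel name fill_index and the array size are illustrative assumptions, and a GPU is required.

import numpy as np
from numba import cuda
from cuda.core.experimental import Device

@cuda.jit
def fill_index(a):
    # Write each thread's global index into the array.
    i = cuda.grid(1)
    if i < a.size:
        a[i] = i

dev = Device()
dev.set_current()
stream = dev.create_stream()

# The cuda.core stream is accepted by to_device and by the launch config.
arr = cuda.to_device(np.zeros(128, dtype=np.int32), stream=stream)
stream.sync()

fill_index[1, 128, stream](arr)
stream.sync()

result = arr.copy_to_host(stream=stream)
stream.sync()
print(result[:8])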

numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py

Lines changed: 63 additions & 0 deletions
@@ -9,10 +9,14 @@
     driver,
     launch_kernel,
 )
+
+from numba import cuda
 from numba.cuda.cudadrv import devices, driver as _driver
 from numba.cuda.testing import unittest, CUDATestCase
 from numba.cuda.testing import skip_on_cudasim
+import contextlib
 
+from cuda.core.experimental import Device
 
 ptx1 = """
 .version 1.4
@@ -152,6 +156,65 @@ def test_cuda_driver_stream_operations(self):
         for i, v in enumerate(array):
             self.assertEqual(i, v)
 
+    def test_cuda_core_stream_operations(self):
+        module = self.context.create_module_ptx(self.ptx)
+        function = module.get_function("_Z10helloworldPi")
+        array = (c_int * 100)()
+        dev = Device()
+        dev.set_current()
+        stream = dev.create_stream()
+
+        @contextlib.contextmanager
+        def auto_synchronize(stream):
+            try:
+                yield stream
+            finally:
+                stream.sync()
+
+        with auto_synchronize(stream):
+            memory = self.context.memalloc(sizeof(array))
+            host_to_device(memory, array, sizeof(array), stream=stream)
+
+            ptr = memory.device_ctypes_pointer
+
+            launch_kernel(
+                function.handle,  # Kernel
+                1,
+                1,
+                1,  # gx, gy, gz
+                100,
+                1,
+                1,  # bx, by, bz
+                0,  # dynamic shared mem
+                stream.handle,  # stream
+                [ptr],
+            )
+
+            device_to_host(array, memory, sizeof(array), stream=stream)
+        for i, v in enumerate(array):
+            self.assertEqual(i, v)
+
+    def test_cuda_core_stream_launch_user_facing(self):
+        @cuda.jit
+        def kernel(a):
+            idx = cuda.grid(1)
+            if idx < len(a):
+                a[idx] = idx
+
+        dev = Device()
+        dev.set_current()
+        stream = dev.create_stream()
+
+        ary = cuda.to_device([0] * 100, stream=stream)
+        stream.sync()
+
+        kernel[1, 100, stream](ary)
+        stream.sync()
+
+        result = ary.copy_to_host(stream=stream)
+        for i, v in enumerate(result):
+            self.assertEqual(i, v)
+
     def test_cuda_driver_default_stream(self):
         # Test properties of the default stream
         ds = self.context.get_default_stream()

numba_cuda/numba/cuda/tests/cudadrv/test_events.py

Lines changed: 13 additions & 1 deletion
@@ -4,6 +4,8 @@
 import numpy as np
 from numba import cuda
 from numba.cuda.testing import unittest, CUDATestCase
+from cuda.core.experimental import Device
+from numba.cuda.testing import skip_on_cudasim
 
 
 class TestCudaEvent(CUDATestCase):
@@ -22,8 +24,18 @@ def test_event_elapsed(self):
         evtstart.elapsed_time(evtend)
 
     def test_event_elapsed_stream(self):
-        N = 32
         stream = cuda.stream()
+        self.event_elapsed_inner(stream)
+
+    @skip_on_cudasim("Testing cuda.core events requires driver")
+    def test_event_elapsed_cuda_core_stream(self):
+        dev = Device()
+        dev.set_current()
+        stream = dev.create_stream()
+        self.event_elapsed_inner(stream)
+
+    def event_elapsed_inner(self, stream):
+        N = 32
         dary = cuda.device_array(N, dtype=np.double)
         evtstart = cuda.event()
         evtend = cuda.event()
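
A hedged sketch of the pattern the refactored event test exercises: recording and waiting on events against a cuda.core stream works because Event.record() and Event.wait() now route the stream through driver._stream_handle(). The array size and contents below are illustrative assumptions; a GPU is required.

import numpy as np
from numba import cuda
from cuda.core.experimental import Device

dev = Device()
dev.set_current()
stream = dev.create_stream()

dary = cuda.device_array(32, dtype=np.double)
evtstart = cuda.event()
evtend = cuda.event()

# Record, enqueue a copy, and time it, all on the cuda.core stream.
evtstart.record(stream=stream)
cuda.to_device(np.arange(32, dtype=np.double), to=dary, stream=stream)
evtend.record(stream=stream)
evtend.wait(stream=stream)
evtend.synchronize()
print(evtstart.elapsed_time(evtend))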
