6464 ObjectCode ,
6565)
6666
67+ from cuda .bindings .utils import get_cuda_native_handle
68+ from cuda .core .experimental import (
69+ Stream as ExperimentalStream ,
70+ )
71+
72+
6773# There is no definition of the default stream in the Nvidia bindings (nor
6874# is there at the C/C++ level), so we define it here so we don't need to
6975# use a magic number 0 in places where we want the default stream.
@@ -2064,6 +2070,11 @@ def __int__(self):
20642070 # The default stream's handle.value is 0, which gives `None`
20652071 return self .handle .value or drvapi .CU_STREAM_DEFAULT
20662072
2073+ def __cuda_stream__ (self ):
2074+ if not self .handle .value :
2075+ return (0 , drvapi .CU_STREAM_DEFAULT )
2076+ return (0 , self .handle .value )
2077+
20672078 def __repr__ (self ):
20682079 default_streams = {
20692080 drvapi .CU_STREAM_DEFAULT : "<Default CUDA stream on %s>" ,
@@ -2210,7 +2221,7 @@ def record(self, stream=0):
22102221 queued in the stream at the time of the call to ``record()`` has been
22112222 completed.
22122223 """
2213- hstream = stream . handle . value if stream else binding . CUstream ( 0 )
2224+ hstream = _stream_handle ( stream )
22142225 handle = self .handle .value
22152226 driver .cuEventRecord (handle , hstream )
22162227
@@ -2225,7 +2236,7 @@ def wait(self, stream=0):
22252236 """
22262237 All future work submitted to stream will wait until the event completes.
22272238 """
2228- hstream = stream . handle . value if stream else binding . CUstream ( 0 )
2239+ hstream = _stream_handle ( stream )
22292240 handle = self .handle .value
22302241 flags = 0
22312242 driver .cuStreamWaitEvent (hstream , handle , flags )
@@ -3080,17 +3091,14 @@ def host_to_device(dst, src, size, stream=0):
30803091 it should not be changed until the operation which can be asynchronous
30813092 completes.
30823093 """
3083- varargs = []
3094+ fn = driver .cuMemcpyHtoD
3095+ args = (device_pointer (dst ), host_pointer (src , readonly = True ), size )
30843096
30853097 if stream :
3086- assert isinstance (stream , Stream )
30873098 fn = driver .cuMemcpyHtoDAsync
3088- handle = stream .handle .value
3089- varargs .append (handle )
3090- else :
3091- fn = driver .cuMemcpyHtoD
3099+ args += (_stream_handle (stream ),)
30923100
3093- fn (device_pointer ( dst ), host_pointer ( src , readonly = True ), size , * varargs )
3101+ fn (* args )
30943102
30953103
30963104def device_to_host (dst , src , size , stream = 0 ):
@@ -3099,61 +3107,52 @@ def device_to_host(dst, src, size, stream=0):
30993107 it should not be changed until the operation which can be asynchronous
31003108 completes.
31013109 """
3102- varargs = []
3110+ fn = driver .cuMemcpyDtoH
3111+ args = (host_pointer (dst ), device_pointer (src ), size )
31033112
31043113 if stream :
3105- assert isinstance (stream , Stream )
31063114 fn = driver .cuMemcpyDtoHAsync
3107- handle = stream .handle .value
3108- varargs .append (handle )
3109- else :
3110- fn = driver .cuMemcpyDtoH
3115+ args += (_stream_handle (stream ),)
31113116
3112- fn (host_pointer ( dst ), device_pointer ( src ), size , * varargs )
3117+ fn (* args )
31133118
31143119
def device_to_device(dst, src, size, stream=0):
    """
    Copy ``size`` bytes from one device buffer to another.

    NOTE: The underlying data pointer from the device buffer is used and
    it should not be changed until the operation which can be asynchronous
    completes.

    dst: destination device memory
    src: source device memory
    size: number of bytes to copy
    stream: a falsy value (such as the default 0) selects
            ``cuMemcpyDtoD``; any truthy value selects
            ``cuMemcpyDtoDAsync`` and is converted to a handle with
            ``_stream_handle``
    """
    fn = driver.cuMemcpyDtoD
    args = (device_pointer(dst), device_pointer(src), size)

    if stream:
        # The async variant takes the stream handle as an extra trailing
        # argument.
        fn = driver.cuMemcpyDtoDAsync
        args += (_stream_handle(stream),)

    fn(*args)
31323134
31333135
31343136def device_memset (dst , val , size , stream = 0 ):
3135- """Memset on the device.
3136- If stream is not zero, asynchronous mode is used.
3137+ """
3138+ Memset on the device.
3139+ If stream is 0, the call is synchronous.
3140+ If stream is a Stream object, asynchronous mode is used.
31373141
31383142 dst: device memory
31393143 val: byte value to be written
3140- size: number of byte to be written
3141- stream: a CUDA stream
3144+ size: number of bytes to be written
3145+ stream: 0 (synchronous) or a CUDA stream
31423146 """
3143- ptr = device_pointer (dst )
3144-
3145- varargs = []
3147+ fn = driver .cuMemsetD8
3148+ args = (device_pointer (dst ), val , size )
31463149
31473150 if stream :
3148- assert isinstance (stream , Stream )
31493151 fn = driver .cuMemsetD8Async
3150- handle = stream .handle .value
3151- varargs .append (handle )
3152- else :
3153- fn = driver .cuMemsetD8
3152+ args += (_stream_handle (stream ),)
31543153
31553154 try :
3156- fn (ptr , val , size , * varargs )
3155+ fn (* args )
31573156 except CudaAPIError as e :
31583157 invalid = binding .CUresult .CUDA_ERROR_INVALID_VALUE
31593158 if (
@@ -3226,3 +3225,28 @@ def inspect_obj_content(objpath: str):
32263225 code_types .add (match .group (1 ))
32273226
32283227 return code_types
3228+
3229+
3230+ def _stream_handle (stream ):
3231+ """
3232+ Obtain the appropriate handle for various types of
3233+ acceptable stream objects. Acceptable types are
3234+ int (0 for default stream), Stream, ExperimentalStream
3235+ """
3236+
3237+ if stream == 0 :
3238+ return stream
3239+ allowed = (Stream , ExperimentalStream )
3240+ if not isinstance (stream , allowed ):
3241+ raise TypeError (
3242+ "Expected a Stream object or 0, got %s" % type (stream ).__name__
3243+ )
3244+ elif hasattr (stream , "__cuda_stream__" ):
3245+ ver , ptr = stream .__cuda_stream__ ()
3246+ assert ver == 0
3247+ if isinstance (ptr , binding .CUstream ):
3248+ return get_cuda_native_handle (ptr )
3249+ else :
3250+ return ptr
3251+ else :
3252+ raise TypeError ("Invalid Stream" )
0 commit comments