4343import importlib
4444import numpy as np
4545from collections import namedtuple , deque
46- from uuid import UUID
4746
4847
4948from numba .cuda .cext import mviewbuf
6766from cuda .bindings .utils import get_cuda_native_handle
6867from cuda .core .experimental import (
6968 Stream as ExperimentalStream ,
69+ Device as ExperimentalDevice ,
7070)
7171
7272
@@ -527,7 +527,7 @@ def _build_reverse_device_attrs():
527527DEVICE_ATTRIBUTES = _build_reverse_device_attrs ()
528528
529529
530- class Device ( object ) :
530+ class Device :
531531 """
532532 The device object owns the CUDA contexts. This is owned by the driver
533533 object. User should not construct devices directly.
@@ -548,32 +548,12 @@ def from_identity(self, identity):
548548 "Target device may not be visible in this process."
549549 )
550550
551- def __init__ (self , devnum ):
552- result = driver .cuDeviceGet (devnum )
553- self .id = result
554- got_devnum = int (result )
555-
556- msg = f"Driver returned device { got_devnum } instead of { devnum } "
557- if devnum != got_devnum :
558- raise RuntimeError (msg )
559-
560- # Read compute capability
561- self .compute_capability = (
562- self .COMPUTE_CAPABILITY_MAJOR ,
563- self .COMPUTE_CAPABILITY_MINOR ,
564- )
565-
566- # Read name
567- bufsz = 128
568- buf = driver .cuDeviceGetName (bufsz , self .id )
569- name = buf .split (b"\x00 " , 1 )[0 ]
570-
571- self .name = name
572-
573- # Read UUID
574- uuid = driver .cuDeviceGetUuid (self .id )
575- self .uuid = f"GPU-{ UUID (bytes = uuid .bytes )} "
576-
551+ def __init__ (self , devnum : int ) -> None :
552+ self ._dev = ExperimentalDevice (devnum )
553+ self .id = self ._dev .device_id
554+ self .compute_capability = self ._dev .compute_capability
555+ self .name = self ._dev .name
556+ self .uuid = f"GPU-{ self ._dev .uuid } "
577557 self .primary_context = None
578558
579559 def get_device_identity (self ):
@@ -613,13 +593,16 @@ def get_primary_context(self):
613593 if (ctx := self .primary_context ) is not None :
614594 return ctx
615595
616- met_requirement_for_device ( self )
617- # create primary context
618- hctx = driver . cuDevicePrimaryCtxRetain ( self . id )
619- hctx = drvapi . cu_context ( int ( hctx ) )
596+ if self . compute_capability < MIN_REQUIRED_CC :
597+ raise CudaSupportError (
598+ f" { self } has compute capability < { MIN_REQUIRED_CC } "
599+ )
620600
621- ctx = Context (weakref .proxy (self ), hctx )
622- self .primary_context = ctx
601+ self ._dev .set_current ()
602+ self .primary_context = ctx = Context (
603+ weakref .proxy (self ),
604+ ctypes .c_void_p (int (self ._dev .context ._handle )),
605+ )
623606 return ctx
624607
625608 def release_primary_context (self ):
@@ -648,13 +631,6 @@ def supports_bfloat16(self):
648631 return self .compute_capability >= (8 , 0 )
649632
650633
651- def met_requirement_for_device (device ):
652- if device .compute_capability < MIN_REQUIRED_CC :
653- raise CudaSupportError (
654- "%s has compute capability < %s" % (device , MIN_REQUIRED_CC )
655- )
656-
657-
658634class BaseCUDAMemoryManager (object , metaclass = ABCMeta ):
659635 """Abstract base class for External Memory Management (EMM) Plugins."""
660636
0 commit comments