Skip to content

Commit cb1978f

Browse files
authored
refactor: replace device functionality with cuda.core APIs (#581)
1 parent d193a64 commit cb1978f

File tree

1 file changed

+17
-41
lines changed

1 file changed

+17
-41
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import importlib
4444
import numpy as np
4545
from collections import namedtuple, deque
46-
from uuid import UUID
4746

4847

4948
from numba.cuda.cext import mviewbuf
@@ -67,6 +66,7 @@
6766
from cuda.bindings.utils import get_cuda_native_handle
6867
from cuda.core.experimental import (
6968
Stream as ExperimentalStream,
69+
Device as ExperimentalDevice,
7070
)
7171

7272

@@ -527,7 +527,7 @@ def _build_reverse_device_attrs():
527527
DEVICE_ATTRIBUTES = _build_reverse_device_attrs()
528528

529529

530-
class Device(object):
530+
class Device:
531531
"""
532532
The device object owns the CUDA contexts. This is owned by the driver
533533
object. User should not construct devices directly.
@@ -548,32 +548,12 @@ def from_identity(self, identity):
548548
"Target device may not be visible in this process."
549549
)
550550

551-
def __init__(self, devnum):
552-
result = driver.cuDeviceGet(devnum)
553-
self.id = result
554-
got_devnum = int(result)
555-
556-
msg = f"Driver returned device {got_devnum} instead of {devnum}"
557-
if devnum != got_devnum:
558-
raise RuntimeError(msg)
559-
560-
# Read compute capability
561-
self.compute_capability = (
562-
self.COMPUTE_CAPABILITY_MAJOR,
563-
self.COMPUTE_CAPABILITY_MINOR,
564-
)
565-
566-
# Read name
567-
bufsz = 128
568-
buf = driver.cuDeviceGetName(bufsz, self.id)
569-
name = buf.split(b"\x00", 1)[0]
570-
571-
self.name = name
572-
573-
# Read UUID
574-
uuid = driver.cuDeviceGetUuid(self.id)
575-
self.uuid = f"GPU-{UUID(bytes=uuid.bytes)}"
576-
551+
def __init__(self, devnum: int) -> None:
552+
self._dev = ExperimentalDevice(devnum)
553+
self.id = self._dev.device_id
554+
self.compute_capability = self._dev.compute_capability
555+
self.name = self._dev.name
556+
self.uuid = f"GPU-{self._dev.uuid}"
577557
self.primary_context = None
578558

579559
def get_device_identity(self):
@@ -613,13 +593,16 @@ def get_primary_context(self):
613593
if (ctx := self.primary_context) is not None:
614594
return ctx
615595

616-
met_requirement_for_device(self)
617-
# create primary context
618-
hctx = driver.cuDevicePrimaryCtxRetain(self.id)
619-
hctx = drvapi.cu_context(int(hctx))
596+
if self.compute_capability < MIN_REQUIRED_CC:
597+
raise CudaSupportError(
598+
f"{self} has compute capability < {MIN_REQUIRED_CC}"
599+
)
620600

621-
ctx = Context(weakref.proxy(self), hctx)
622-
self.primary_context = ctx
601+
self._dev.set_current()
602+
self.primary_context = ctx = Context(
603+
weakref.proxy(self),
604+
ctypes.c_void_p(int(self._dev.context._handle)),
605+
)
623606
return ctx
624607

625608
def release_primary_context(self):
@@ -648,13 +631,6 @@ def supports_bfloat16(self):
648631
return self.compute_capability >= (8, 0)
649632

650633

651-
def met_requirement_for_device(device):
652-
if device.compute_capability < MIN_REQUIRED_CC:
653-
raise CudaSupportError(
654-
"%s has compute capability < %s" % (device, MIN_REQUIRED_CC)
655-
)
656-
657-
658634
class BaseCUDAMemoryManager(object, metaclass=ABCMeta):
659635
"""Abstract base class for External Memory Management (EMM) Plugins."""
660636

0 commit comments

Comments
 (0)