Skip to content

Commit 54f79e0

Browse files
committed
refactor: clean up Device constructor
1 parent 2567b28 commit 54f79e0

File tree

1 file changed

+12
-23
lines changed

1 file changed

+12
-23
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import importlib
4444
import numpy as np
4545
from collections import namedtuple, deque
46+
from uuid import UUID
4647

4748

4849
from numba.cuda.cext import mviewbuf
@@ -536,11 +537,10 @@ def from_identity(self, identity):
536537
if d.get_device_identity() == identity:
537538
return d
538539
else:
539-
errmsg = (
540-
"No device of {} is found. "
540+
raise RuntimeError(
541+
f"No device of {identity} is found. "
541542
"Target device may not be visible in this process."
542-
).format(identity)
543-
raise RuntimeError(errmsg)
543+
)
544544

545545
def __init__(self, devnum):
546546
result = driver.cuDeviceGet(devnum)
@@ -551,8 +551,6 @@ def __init__(self, devnum):
551551
if devnum != got_devnum:
552552
raise RuntimeError(msg)
553553

554-
self.attributes = {}
555-
556554
# Read compute capability
557555
self.compute_capability = (
558556
self.COMPUTE_CAPABILITY_MAJOR,
@@ -562,20 +560,13 @@ def __init__(self, devnum):
562560
# Read name
563561
bufsz = 128
564562
buf = driver.cuDeviceGetName(bufsz, self.id)
565-
name = buf.split(b"\x00")[0]
563+
name = buf.split(b"\x00", 1)[0]
566564

567565
self.name = name
568566

569567
# Read UUID
570568
uuid = driver.cuDeviceGetUuid(self.id)
571-
uuid_vals = tuple(uuid.bytes)
572-
573-
b = "%02x"
574-
b2 = b * 2
575-
b4 = b * 4
576-
b6 = b * 6
577-
fmt = f"GPU-{b4}-{b2}-{b2}-{b2}-{b6}"
578-
self.uuid = fmt % uuid_vals
569+
self.uuid = f"GPU-{UUID(bytes=uuid.bytes)}"
579570

580571
self.primary_context = None
581572

@@ -587,7 +578,7 @@ def get_device_identity(self):
587578
}
588579

589580
def __repr__(self):
590-
return "<CUDA device %d '%s'>" % (self.id, self.name)
581+
return f"<CUDA device {self.id:d} '{self.name}'>"
591582

592583
def __getattr__(self, attr):
593584
"""Read attributes lazily"""
@@ -603,9 +594,7 @@ def __hash__(self):
603594
return hash(self.id)
604595

605596
def __eq__(self, other):
606-
if isinstance(other, Device):
607-
return self.id == other.id
608-
return False
597+
return isinstance(other, Device) and self.id == other.id
609598

610599
def __ne__(self, other):
611600
return not (self == other)
@@ -615,8 +604,8 @@ def get_primary_context(self):
615604
Returns the primary context for the device.
616605
Note: it is not pushed to the CPU thread.
617606
"""
618-
if self.primary_context is not None:
619-
return self.primary_context
607+
if (ctx := self.primary_context) is not None:
608+
return ctx
620609

621610
met_requirement_for_device(self)
622611
# create primary context
@@ -637,8 +626,8 @@ def release_primary_context(self):
637626

638627
def reset(self):
639628
try:
640-
if self.primary_context is not None:
641-
self.primary_context.reset()
629+
if (ctx := self.primary_context) is not None:
630+
ctx.reset()
642631
self.release_primary_context()
643632
finally:
644633
# reset at the driver level

0 commit comments

Comments
 (0)