Skip to content

Commit 4b90a71

Browse files
committed
add DWARF address class for cuda array in shared memory
1 parent 5389798 commit 4b90a71

File tree

3 files changed

+130
-0
lines changed

3 files changed

+130
-0
lines changed

numba_cuda/numba/cuda/cudadecl.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,20 @@ def typer(shape, dtype, alignment=None):
6262
class Cuda_shared_array(Cuda_array_decl):
6363
key = cuda.shared.array
6464

65+
def generic(self):
66+
typer = super().generic()
67+
68+
# Wrap it to add _addrspace attribute for debug info only
69+
from numba.cuda.cudadrv import nvvm
70+
71+
def shared_array_typer(shape, dtype, alignment=None):
72+
result = typer(shape, dtype, alignment)
73+
if result is not None:
74+
result._addrspace = nvvm.ADDRSPACE_SHARED
75+
return result
76+
77+
return shared_array_typer
78+
6579

6680
@register
6781
class Cuda_local_array(Cuda_array_decl):

numba_cuda/numba/cuda/debuginfo.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,27 @@
44
import abc
55
import os
66
from contextlib import contextmanager
7+
from enum import IntEnum
78

89
from llvmlite import ir
910
from numba.cuda import types
1011
from numba.cuda.core import config
1112
from numba.cuda import cgutils
13+
from numba.cuda.cudadrv import nvvm
1214
from numba.cuda.datamodel.models import ComplexModel, UnionModel, UniTupleModel
1315
from numba.cuda.types.ext_types import GridGroup
1416

1517

18+
class DwarfAddressClass(IntEnum):
19+
GENERIC = 0x00
20+
GLOBAL = 0x01
21+
REGISTER = 0x02
22+
CONSTANT = 0x05
23+
LOCAL = 0x06
24+
PARAMETER = 0x07
25+
SHARED = 0x08
26+
27+
1628
@contextmanager
1729
def suspend_emission(builder):
1830
"""Suspends the emission of debug_metadata for the duration of the context
@@ -121,6 +133,18 @@ def initialize(self):
121133
# constructing subprograms
122134
self.dicompileunit = self._di_compile_unit()
123135

136+
def get_dwarf_address_class(self, addrspace):
137+
# Map NVVM address space to DWARF address class.
138+
139+
addrspace_to_addrclass_dict = {
140+
nvvm.ADDRSPACE_GENERIC: None,
141+
nvvm.ADDRSPACE_GLOBAL: DwarfAddressClass.GLOBAL,
142+
nvvm.ADDRSPACE_SHARED: DwarfAddressClass.SHARED,
143+
nvvm.ADDRSPACE_CONSTANT: DwarfAddressClass.CONSTANT,
144+
nvvm.ADDRSPACE_LOCAL: DwarfAddressClass.LOCAL,
145+
}
146+
return addrspace_to_addrclass_dict.get(addrspace)
147+
124148
def _var_type(self, lltype, size, datamodel=None):
125149
if self._DEBUG:
126150
print(
@@ -659,6 +683,67 @@ def _var_type(self, lltype, size, datamodel=None):
659683
},
660684
is_distinct=True,
661685
)
686+
687+
# Check if there's actually address space info to handle
688+
if (
689+
isinstance(lltype, ir.LiteralStructType)
690+
and datamodel is not None
691+
and datamodel.inner_models()
692+
and hasattr(datamodel, "fe_type")
693+
and getattr(datamodel.fe_type, "_addrspace", None) not in (None, 0)
694+
):
695+
# Process struct with datamodel that has _addrspace available
696+
meta = []
697+
offset = 0
698+
for element, field, model in zip(
699+
lltype.elements, datamodel._fields, datamodel.inner_models()
700+
):
701+
size_field = self.cgctx.get_abi_sizeof(element)
702+
if isinstance(element, ir.PointerType) and field == "data":
703+
# Create pointer type with correct address space
704+
pointee_size = self.cgctx.get_abi_sizeof(element.pointee)
705+
pointee_model = getattr(model, "_pointee_model", None)
706+
pointee_type = self._var_type(
707+
element.pointee, pointee_size, datamodel=pointee_model
708+
)
709+
meta_ptr = {
710+
"tag": ir.DIToken("DW_TAG_pointer_type"),
711+
"baseType": pointee_type,
712+
"size": _BYTE_SIZE * size_field,
713+
}
714+
addrspace = getattr(datamodel.fe_type, "_addrspace", None)
715+
dwarf_addrclass = self.get_dwarf_address_class(addrspace)
716+
if dwarf_addrclass is not None:
717+
meta_ptr["dwarfAddressSpace"] = dwarf_addrclass
718+
basetype = m.add_debug_info("DIDerivedType", meta_ptr)
719+
else:
720+
basetype = self._var_type(
721+
element, size_field, datamodel=model
722+
)
723+
derived_type = m.add_debug_info(
724+
"DIDerivedType",
725+
{
726+
"tag": ir.DIToken("DW_TAG_member"),
727+
"name": field,
728+
"baseType": basetype,
729+
"size": _BYTE_SIZE * size_field,
730+
"offset": offset,
731+
},
732+
)
733+
meta.append(derived_type)
734+
offset += _BYTE_SIZE * size_field
735+
736+
return m.add_debug_info(
737+
"DICompositeType",
738+
{
739+
"tag": ir.DIToken("DW_TAG_structure_type"),
740+
"name": f"{datamodel.fe_type}",
741+
"elements": m.add_metadata(meta),
742+
"size": offset,
743+
},
744+
is_distinct=True,
745+
)
746+
662747
# For other cases, use upstream Numba implementation
663748
return super()._var_type(lltype, size, datamodel=datamodel)
664749

numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,37 @@ def foo():
840840
""",
841841
)
842842

843+
def test_shared_memory_address_class(self):
844+
"""Test that shared memory arrays have correct DWARF address class.
845+
846+
Shared memory pointers should have addressClass: 8 (DW_AT_address_class
847+
for CUDA shared memory) in their debug metadata.
848+
"""
849+
sig = (types.int32,)
850+
851+
@cuda.jit(sig, debug=True, opt=False)
852+
def kernel_with_shared(data):
853+
shared_arr = cuda.shared.array(32, dtype=np.int32)
854+
idx = cuda.grid(1)
855+
if idx < 32:
856+
shared_arr[idx] = data + idx
857+
cuda.syncthreads()
858+
if idx == 0:
859+
result = np.int32(0)
860+
for i in range(32):
861+
result += shared_arr[i]
862+
863+
llvm_ir = kernel_with_shared.inspect_llvm(sig)
864+
865+
# Find the DIDerivedType for the pointer to shared memory (int32 addrspace(3)*)
866+
# The pointer should have dwarfAddressSpace: 8 for shared memory
867+
pat = r"!DIDerivedType\([^)]*dwarfAddressSpace:\s*8[^)]*tag:\s*DW_TAG_pointer_type[^)]*\)"
868+
match = re.compile(pat).search(llvm_ir)
869+
self.assertIsNotNone(
870+
match,
871+
msg=f"Shared memory pointer should have dwarfAddressSpace: 8 in LLVM IR.\n{llvm_ir}",
872+
)
873+
843874

844875
if __name__ == "__main__":
845876
unittest.main()

0 commit comments

Comments
 (0)