Skip to content

Commit 2903172

Browse files
committed
Initial start of implementation
1 parent cea934b commit 2903172

File tree

4 files changed

+258
-27
lines changed

4 files changed

+258
-27
lines changed

numba_cuda/numba/cuda/api_util.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from numba import types
2+
from numba.core import cgutils
13
import numpy as np
24

35

@@ -28,3 +30,26 @@ def _fill_stride_by_order(shape, dtype, order):
2830
else:
2931
raise ValueError('must be either C/F order')
3032
return tuple(strides)
33+
34+
35+
def normalize_indices(context, builder, indty, inds, aryty, valty):
    """
    Normalize a scalar-integer or tuple index into a list of intp values.

    A scalar integer index is treated as a one-element tuple. Every index
    component is cast to ``types.intp``. After casting, the value type is
    checked against the array's dtype and the number of index components
    is checked against the array's dimensionality.

    Returns the (possibly re-wrapped) index type and the list of cast
    index values. Raises ``TypeError`` if ``valty`` differs from
    ``aryty.dtype`` or if the index rank does not match ``aryty.ndim``.
    """
    is_scalar_index = indty in types.integer_domain
    if is_scalar_index:
        # Wrap the lone integer so the rest of the code sees a tuple type.
        indty = types.UniTuple(dtype=indty, count=1)
        raw_indices = [inds]
    else:
        raw_indices = cgutils.unpack_tuple(builder, inds, count=len(indty))

    indices = []
    for component_type, component in zip(indty, raw_indices):
        indices.append(
            context.cast(builder, component, component_type, types.intp)
        )

    expected_dtype = aryty.dtype
    if expected_dtype != valty:
        raise TypeError("expect %s but got %s" % (expected_dtype, valty))

    index_rank = len(indty)
    if aryty.ndim != index_rank:
        raise TypeError("indexing %d-D array with %d-D index" %
                        (aryty.ndim, index_rank))

    return indty, indices
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
from llvmlite import ir
2+
from numba import types
3+
from numba.core import cgutils
4+
from numba.core.extending import intrinsic, overload
5+
from numba.core.errors import NumbaTypeError
6+
from numba.cuda.api_util import normalize_indices
7+
8+
9+
def ldca(array, i):
    """Load element `i` of `array` using a `ld.global.ca` instruction."""
11+
12+
13+
def ldcg(array, i):
    """Load element `i` of `array` using a `ld.global.cg` instruction."""
15+
16+
17+
def ldcs(array, i):
    """Load element `i` of `array` using a `ld.global.cs` instruction."""
19+
20+
21+
def ldlu(array, i):
    """Load element `i` of `array` using a `ld.global.lu` instruction."""
23+
24+
25+
def ldcv(array, i):
    """Load element `i` of `array` using a `ld.global.cv` instruction."""
27+
28+
29+
def stcg(array, i, value):
    """Store `value` to element `i` of `array` using `st.global.cg`."""
31+
32+
33+
def stcs(array, i, value):
    """Store `value` to element `i` of `array` using `st.global.cs`."""
35+
36+
37+
def stwb(array, i, value):
    """Store `value` to element `i` of `array` using `st.global.wb`."""
39+
40+
41+
def stwt(array, i, value):
    """Store `value` to element `i` of `array` using `st.global.wt`."""
43+
44+
45+
def ld_cache_operator(operator):
    """
    Build an intrinsic that loads one array element with a
    `ld.global.<operator>` inline PTX instruction.

    Parameters
    ----------
    operator : str
        Cache operator suffix spliced into the instruction name (this
        module instantiates "ca", "cg", "cs", "lu" and "cv").

    Returns
    -------
    The `@intrinsic`-decorated function, called from jitted code as
    `intrinsic(array, index)` and typed as returning `array.dtype`.

    Raises
    ------
    NumbaTypeError
        At typing time, when `array` is not an array type, or when its
        element type is not an integer/float type whose bit width has an
        inline-assembly register constraint.
    """
    # Name used in diagnostics, so each message names the actual operation
    # rather than a hard-coded one.
    instruction = f"ld{operator}"

    # Inline-assembly register constraint per operand width in bits.
    # See
    # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#restricted-use-of-sub-word-sizes
    # for background on the choice of "r" for 8-bit operands - there is
    # no constraint for 8-bit operands, but the operand for loads and
    # stores is permitted to be greater than 8 bits.
    constraint_map = {
        1: "b",
        8: "r",
        16: "h",
        32: "r",
        64: "l",
        128: "q",
    }

    @intrinsic
    def impl(typingctx, array, index):
        if not isinstance(array, types.Array):
            msg = f"{instruction} operates on arrays. Got type {array}"
            raise NumbaTypeError(msg)

        # Validate the element type up front so that an unsupported dtype
        # is reported as a typing error instead of surfacing as a
        # KeyError / AttributeError inside codegen.
        dtype = array.dtype
        if not isinstance(dtype, (types.Integer, types.Float)):
            msg = (f"{instruction} requires array of integer or float "
                   f"type. Got {dtype}")
            raise NumbaTypeError(msg)

        if dtype.bitwidth not in constraint_map:
            msg = (f"{instruction} requires array member type width of "
                   f"{sorted(constraint_map)} bits. Got {dtype.bitwidth}")
            raise NumbaTypeError(msg)

        # Need to validate indices

        signature = array.dtype(array, index)

        def codegen(context, builder, sig, args):
            array_type, index_type = sig.args
            loaded_type = context.get_value_type(array_type.dtype)
            ptr_type = loaded_type.as_pointer()
            # The inline asm is modelled as a function taking the element
            # pointer and returning the loaded element.
            ldcs_type = ir.FunctionType(loaded_type, [ptr_type])

            array, indices = args

            # There is no separate value operand for a load, so the
            # array's own dtype is passed as the value type.
            index_type, indices = normalize_indices(context, builder,
                                                    index_type, indices,
                                                    array_type,
                                                    array_type.dtype)
            array_struct = context.make_array(array_type)(context, builder,
                                                          value=array)
            ptr = cgutils.get_item_pointer(context, builder, array_type,
                                           array_struct, indices,
                                           wraparound=True)

            bitwidth = array_type.dtype.bitwidth
            inst = f"ld.global.{operator}.b{bitwidth}"
            # "=<c>" is the output (loaded value), "l" the address operand.
            constraints = f"={constraint_map[bitwidth]},l"
            ldcs = ir.InlineAsm(ldcs_type, f"{inst} $0, [$1];", constraints)
            return builder.call(ldcs, [ptr])

        return signature, codegen

    return impl
98+
99+
100+
# One load intrinsic per cache operator suffix, built in the same order as
# the names on the left-hand side.
(ldca_intrinsic,
 ldcg_intrinsic,
 ldcs_intrinsic,
 ldlu_intrinsic,
 ldcv_intrinsic) = [ld_cache_operator(suffix)
                    for suffix in ("ca", "cg", "cs", "lu", "cv")]
105+
106+
107+
def st_cache_operator(operator):
    """
    Build an intrinsic that stores one array element with a
    `st.global.<operator>` inline PTX instruction.

    Parameters
    ----------
    operator : str
        Cache operator suffix spliced into the instruction name (this
        module instantiates "cg", "cs", "wb" and "wt").

    Returns
    -------
    The `@intrinsic`-decorated function, called from jitted code as
    `intrinsic(array, index, value)` and typed as returning void.

    Raises
    ------
    NumbaTypeError
        At typing time, when `array` is not an array type, or when its
        element type is not an integer/float type whose bit width has an
        inline-assembly register constraint.
    """
    # Name used in diagnostics; the original message incorrectly said
    # "ldcs" for every store intrinsic.
    instruction = f"st{operator}"

    # Inline-assembly register constraint per operand width in bits.
    # See
    # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#restricted-use-of-sub-word-sizes
    # for background on the choice of "r" for 8-bit operands - there is
    # no constraint for 8-bit operands, but the operand for loads and
    # stores is permitted to be greater than 8 bits.
    constraint_map = {
        1: "b",
        8: "r",
        16: "h",
        32: "r",
        64: "l",
        128: "q",
    }

    @intrinsic
    def impl(typingctx, array, index, value):
        if not isinstance(array, types.Array):
            msg = f"{instruction} operates on arrays. Got type {array}"
            raise NumbaTypeError(msg)

        # Validate the element type up front so that an unsupported dtype
        # is reported as a typing error instead of surfacing as a
        # KeyError / AttributeError inside codegen.
        dtype = array.dtype
        if not isinstance(dtype, (types.Integer, types.Float)):
            msg = (f"{instruction} requires array of integer or float "
                   f"type. Got {dtype}")
            raise NumbaTypeError(msg)

        if dtype.bitwidth not in constraint_map:
            msg = (f"{instruction} requires array member type width of "
                   f"{sorted(constraint_map)} bits. Got {dtype.bitwidth}")
            raise NumbaTypeError(msg)

        # Need to validate indices

        signature = types.void(array, index, value)

        def codegen(context, builder, sig, args):
            array_type, index_type, value_type = sig.args
            stored_type = context.get_value_type(array_type.dtype)
            ptr_type = stored_type.as_pointer()
            # The inline asm is modelled as a void function taking the
            # element pointer and the value to store.
            stcs_type = ir.FunctionType(ir.VoidType(),
                                        [ptr_type, stored_type])

            array, indices, value = args

            # The array's own dtype is passed as the value type here; the
            # incoming value is cast to that dtype below.
            index_type, indices = normalize_indices(context, builder,
                                                    index_type, indices,
                                                    array_type,
                                                    array_type.dtype)
            array_struct = context.make_array(array_type)(context, builder,
                                                          value=array)
            ptr = cgutils.get_item_pointer(context, builder, array_type,
                                           array_struct, indices,
                                           wraparound=True)

            casted_value = context.cast(builder, value, value_type,
                                        array_type.dtype)

            bitwidth = array_type.dtype.bitwidth
            inst = f"st.global.{operator}.b{bitwidth}"
            # "l" is the address operand, "<c>" the stored value, and the
            # asm is declared to clobber memory.
            constraints = f"l,{constraint_map[bitwidth]},~{{memory}}"
            stcs = ir.InlineAsm(stcs_type, f"{inst} [$0], $1;", constraints)
            builder.call(stcs, [ptr, casted_value])

        return signature, codegen

    return impl
158+
159+
160+
# One store intrinsic per cache operator suffix, built in the same order as
# the names on the left-hand side.
(stcg_intrinsic,
 stcs_intrinsic,
 stwb_intrinsic,
 stwt_intrinsic) = [st_cache_operator(suffix)
                    for suffix in ("cg", "cs", "wb", "wt")]
164+
165+
166+
@overload(ldca, target='cuda')
def ol_ldca(array, i):
    """CUDA-target overload of `ldca`; forwards to `ldca_intrinsic`."""
    def ldca_impl(array, i):
        return ldca_intrinsic(array, i)

    return ldca_impl
171+
172+
173+
@overload(ldcg, target='cuda')
def ol_ldcg(array, i):
    """CUDA-target overload of `ldcg`; forwards to `ldcg_intrinsic`."""
    def ldcg_impl(array, i):
        return ldcg_intrinsic(array, i)

    return ldcg_impl
178+
179+
180+
@overload(ldcs, target='cuda')
def ol_ldcs(array, i):
    """CUDA-target overload of `ldcs`; forwards to `ldcs_intrinsic`."""
    def ldcs_impl(array, i):
        return ldcs_intrinsic(array, i)

    return ldcs_impl
185+
186+
187+
@overload(ldlu, target='cuda')
def ol_ldlu(array, i):
    """CUDA-target overload of `ldlu`; forwards to `ldlu_intrinsic`."""
    def ldlu_impl(array, i):
        return ldlu_intrinsic(array, i)

    return ldlu_impl
192+
193+
194+
@overload(ldcv, target='cuda')
def ol_ldcv(array, i):
    """CUDA-target overload of `ldcv`; forwards to `ldcv_intrinsic`."""
    def ldcv_impl(array, i):
        return ldcv_intrinsic(array, i)

    return ldcv_impl
199+
200+
201+
@overload(stcg, target='cuda')
def ol_stcg(array, i, value):
    """CUDA-target overload of `stcg`; forwards to `stcg_intrinsic`."""
    def stcg_impl(array, i, value):
        return stcg_intrinsic(array, i, value)

    return stcg_impl
206+
207+
208+
@overload(stcs, target='cuda')
def ol_stcs(array, i, value):
    """CUDA-target overload of `stcs`; forwards to `stcs_intrinsic`."""
    def stcs_impl(array, i, value):
        return stcs_intrinsic(array, i, value)

    return stcs_impl
213+
214+
215+
@overload(stwb, target='cuda')
def ol_stwb(array, i, value):
    """CUDA-target overload of `stwb`; forwards to `stwb_intrinsic`."""
    def stwb_impl(array, i, value):
        return stwb_intrinsic(array, i, value)

    return stwb_impl
220+
221+
222+
@overload(stwt, target='cuda')
def ol_stwt(array, i, value):
    """CUDA-target overload of `stwt`; forwards to `stwt_intrinsic`."""
    def stwt_impl(array, i, value):
        return stwt_intrinsic(array, i, value)

    return stwt_impl

numba_cuda/numba/cuda/cudaimpl.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from numba.np.npyimpl import register_ufuncs
1414
from .cudadrv import nvvm
1515
from numba import cuda
16+
from numba.cuda.api_util import normalize_indices
1617
from numba.cuda import nvvmutils, stubs, errors
1718
from numba.cuda.types import dim3, CUDADispatcher
1819

@@ -692,38 +693,15 @@ def impl(context, builder, sig, args):
692693
lower(math.degrees, types.f8)(gen_deg_rad(_rad2deg))
693694

694695

695-
def _normalize_indices(context, builder, indty, inds, aryty, valty):
696-
"""
697-
Convert integer indices into tuple of intp
698-
"""
699-
if indty in types.integer_domain:
700-
indty = types.UniTuple(dtype=indty, count=1)
701-
indices = [inds]
702-
else:
703-
indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
704-
indices = [context.cast(builder, i, t, types.intp)
705-
for t, i in zip(indty, indices)]
706-
707-
dtype = aryty.dtype
708-
if dtype != valty:
709-
raise TypeError("expect %s but got %s" % (dtype, valty))
710-
711-
if aryty.ndim != len(indty):
712-
raise TypeError("indexing %d-D array with %d-D index" %
713-
(aryty.ndim, len(indty)))
714-
715-
return indty, indices
716-
717-
718696
def _atomic_dispatcher(dispatch_fn):
719697
def imp(context, builder, sig, args):
720698
# The common argument handling code
721699
aryty, indty, valty = sig.args
722700
ary, inds, val = args
723701
dtype = aryty.dtype
724702

725-
indty, indices = _normalize_indices(context, builder, indty, inds,
726-
aryty, valty)
703+
indty, indices = normalize_indices(context, builder, indty, inds,
704+
aryty, valty)
727705

728706
lary = context.make_array(aryty)(context, builder, ary)
729707
ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
@@ -917,8 +895,8 @@ def ptx_atomic_cas(context, builder, sig, args):
917895
aryty, indty, oldty, valty = sig.args
918896
ary, inds, old, val = args
919897

920-
indty, indices = _normalize_indices(context, builder, indty, inds, aryty,
921-
valty)
898+
indty, indices = normalize_indices(context, builder, indty, inds, aryty,
899+
valty)
922900

923901
lary = context.make_array(aryty)(context, builder, ary)
924902
ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,

numba_cuda/numba/cuda/device_init.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Re export
22
import sys
33
from numba.cuda import cg
4+
from numba.cuda.cache_hints import (ldca, ldcg, ldcs, ldlu, ldcv, stcg, stcs,
5+
stwb, stwt)
46
from .stubs import (threadIdx, blockIdx, blockDim, gridDim, laneid, warpsize,
57
syncwarp, shared, local, const, atomic,
68
shfl_sync_intrinsic, vote_sync_intrinsic, match_any_sync,

0 commit comments

Comments
 (0)