Skip to content

Commit 1c04a80

Browse files
Vendor in typeconv for future CUDA-specific changes (#499)
This PR vendors in `numba.core.typeconv` for CUDA-specific customizations. The `_typeconv` C extension can be vendored in after PR #373 is merged; this PR already contains the necessary cpp and hpp files for `_typeconv`. --------- Co-authored-by: Graham Markall <[email protected]>
1 parent 5667cf8 commit 1c04a80

File tree

14 files changed

+1007
-12
lines changed

14 files changed

+1007
-12
lines changed

numba_cuda/numba/cuda/_internal/cuda_fp16.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(self):
124124
self.bitwidth = 2 * 8
125125

126126
def can_convert_from(self, typingctx, other):
127-
from numba.core.typeconv import Conversion
127+
from numba.cuda.typeconv import Conversion
128128

129129
if other in []:
130130
return Conversion.safe
@@ -174,7 +174,7 @@ def __init__(self):
174174
self.bitwidth = 4 * 8
175175

176176
def can_convert_from(self, typingctx, other):
177-
from numba.core.typeconv import Conversion
177+
from numba.cuda.typeconv import Conversion
178178

179179
if other in []:
180180
return Conversion.safe
@@ -7903,9 +7903,9 @@ def generic(self, args, kws):
79037903
# - Conversion.safe
79047904

79057905
if (
7906-
(convertible == numba.core.typeconv.Conversion.exact)
7907-
or (convertible == numba.core.typeconv.Conversion.promote)
7908-
or (convertible == numba.core.typeconv.Conversion.safe)
7906+
(convertible == numba.cuda.typeconv.Conversion.exact)
7907+
or (convertible == numba.cuda.typeconv.Conversion.promote)
7908+
or (convertible == numba.cuda.typeconv.Conversion.safe)
79097909
):
79107910
return signature(retty, types.float16, types.float16)
79117911

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: BSD-2-Clause
3+
4+
#include "_pymodule.h"
5+
#include "capsulethunk.h"
6+
#include "typeconv.hpp"
7+
8+
extern "C" {
9+
10+
11+
static PyObject*
12+
new_type_manager(PyObject* self, PyObject* args);
13+
14+
static void
15+
del_type_manager(PyObject *);
16+
17+
static PyObject*
18+
select_overload(PyObject* self, PyObject* args);
19+
20+
static PyObject*
21+
check_compatible(PyObject* self, PyObject* args);
22+
23+
static PyObject*
24+
set_compatible(PyObject* self, PyObject* args);
25+
26+
static PyObject*
27+
get_pointer(PyObject* self, PyObject* args);
28+
29+
30+
/* Method table of the extension module.  Each entry is registered under the
 * name of its C function, accepts positional arguments only (METH_VARARGS)
 * and has no docstring. */
#define declmethod(func) { #func , ( PyCFunction )func , METH_VARARGS , NULL }
static PyMethodDef ext_methods[] = {
    declmethod(new_type_manager),
    declmethod(select_overload),
    declmethod(check_compatible),
    declmethod(set_compatible),
    declmethod(get_pointer),
    { NULL, NULL, 0, NULL },  /* sentinel terminating the table */
};
#undef declmethod
40+
41+
42+
/* Module entry point for "_typeconv".
 * NOTE(review): MOD_INIT / MOD_DEF / MOD_ERROR_VAL / MOD_SUCCESS_VAL come
 * from _pymodule.h; MOD_DEF is presumably what stores the new module object
 * (or NULL on failure) into `m` -- confirm against that header. */
MOD_INIT(_typeconv) {
    PyObject *m;
    MOD_DEF(m, "_typeconv", "No docs", ext_methods)
    if (!m) {
        return MOD_ERROR_VAL;
    }

    return MOD_SUCCESS_VAL(m);
}
50+
51+
} // end extern C
52+
53+
///////////////////////////////////////////////////////////////////////////////
54+
55+
/* Name given to every TypeManager capsule; the same name is required when
 * the pointer is extracted from the capsule again. */
const char PY_CAPSULE_TM_NAME[] = "*tm";

/* Sets a TypeError reporting that the first argument is not a TypeManager.
 * The macro only sets the error; the caller still has to return NULL. */
#define BAD_TM_ARGUMENT \
    PyErr_SetString(PyExc_TypeError, "1st argument not TypeManager")
58+
59+
static
60+
TypeManager* unwrap_TypeManager(PyObject *tm) {
61+
void* p = PyCapsule_GetPointer(tm, PY_CAPSULE_TM_NAME);
62+
return reinterpret_cast<TypeManager*>(p);
63+
}
64+
65+
/* new_type_manager() -> capsule
 *
 * Allocates a fresh TypeManager and hands ownership to a capsule named
 * PY_CAPSULE_TM_NAME whose destructor is del_type_manager.
 * Returns the capsule, or NULL (with the exception set by PyCapsule_New)
 * when the capsule cannot be created. */
PyObject*
new_type_manager(PyObject* self, PyObject* args)
{
    TypeManager* tm = new TypeManager();
    PyObject* capsule = PyCapsule_New(tm, PY_CAPSULE_TM_NAME,
                                      &del_type_manager);
    if (capsule == NULL) {
        /* No capsule exists to run the destructor, so release the
         * TypeManager here instead of leaking it. */
        delete tm;
    }
    return capsule;
}
71+
72+
void
73+
del_type_manager(PyObject *tm)
74+
{
75+
delete unwrap_TypeManager(tm);
76+
}
77+
78+
/* select_overload(tm, sig, ovsigs, allow_unsafe, exact_match_required) -> int
 *
 * Arguments (parsed with "OOOii"):
 *   tm                    TypeManager capsule
 *   sig                   sequence of integer type codes (the call signature)
 *   ovsigs                sequence of candidate signatures, each itself a
 *                         sequence of integer type codes
 *   allow_unsafe          int flag forwarded as bool
 *   exact_match_required  int flag forwarded as bool
 *
 * Returns the index chosen by TypeManager::selectOverload as a Python int.
 * Raises TypeError when the first argument is not a TypeManager capsule,
 * when more than one overload matches ("Ambiguous overloading") or when
 * none does ("No compatible overload").
 *
 * NOTE(review): items are read with PySequence_Fast_GET_ITEM, which per the
 * CPython documentation requires list/tuple objects; every candidate is also
 * assumed to hold at least len(sig) items -- confirm callers guarantee this.
 */
PyObject*
select_overload(PyObject* self, PyObject* args)
{
    PyObject *tmcap, *sigtup, *ovsigstup;
    int allow_unsafe;
    int exact_match_required;

    if (!PyArg_ParseTuple(args, "OOOii", &tmcap, &sigtup, &ovsigstup,
                          &allow_unsafe, &exact_match_required)) {
        return NULL;
    }

    TypeManager *tm = unwrap_TypeManager(tmcap);
    if (!tm) {
        BAD_TM_ARGUMENT;
        /* Without this return the NULL manager would be dereferenced
         * further down. */
        return NULL;
    }

    /* PySequence_Size reports failure as -1 with an exception set; a
     * negative count must never reach the array allocations below. */
    Py_ssize_t sigsz = PySequence_Size(sigtup);
    if (sigsz < 0) {
        return NULL;
    }
    Py_ssize_t ovsz = PySequence_Size(ovsigstup);
    if (ovsz < 0) {
        return NULL;
    }

    Type *sig = new Type[sigsz];
    /* Candidates are stored row-major: candidate i occupies
     * ovsigs[i * sigsz .. i * sigsz + sigsz - 1]. */
    Type *ovsigs = new Type[ovsz * sigsz];

    for (Py_ssize_t i = 0; i < sigsz; ++i) {
        sig[i] = Type(PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(sigtup,
                                                                  i), NULL));
    }

    for (Py_ssize_t i = 0; i < ovsz; ++i) {
        PyObject *cursig = PySequence_Fast_GET_ITEM(ovsigstup, i);
        for (Py_ssize_t j = 0; j < sigsz; ++j) {
            long tid = PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(cursig,
                                                                   j), NULL);
            ovsigs[i * sigsz + j] = Type(tid);
        }
    }

    int selected = -42;
    int matches = tm->selectOverload(sig, ovsigs, selected, sigsz, ovsz,
                                     (bool) allow_unsafe,
                                     (bool) exact_match_required);

    delete [] sig;
    delete [] ovsigs;

    if (matches > 1) {
        PyErr_SetString(PyExc_TypeError, "Ambiguous overloading");
        return NULL;
    } else if (matches == 0) {
        PyErr_SetString(PyExc_TypeError, "No compatible overload");
        return NULL;
    }

    return PyLong_FromLong(selected);
}
133+
134+
/* check_compatible(tm, from, to) -> str or None
 *
 * Asks the TypeManager how type code `from` converts to type code `to` and
 * reports the answer as "exact", "promote", "safe" or "unsafe".  Any other
 * compatibility code yields None.  Raises TypeError when the first argument
 * is not a TypeManager capsule. */
PyObject*
check_compatible(PyObject* self, PyObject* args)
{
    PyObject *tmcap;
    int from_code, to_code;
    if (!PyArg_ParseTuple(args, "Oii", &tmcap, &from_code, &to_code)) {
        return NULL;
    }

    TypeManager *tm = unwrap_TypeManager(tmcap);
    if (tm == NULL) {
        BAD_TM_ARGUMENT;
        return NULL;
    }

    /* Map the compatibility code to its textual label first, then build
     * the Python result in one place. */
    const char *label = NULL;
    switch (tm->isCompatible(Type(from_code), Type(to_code))) {
    case TCC_EXACT:
        label = "exact";
        break;
    case TCC_PROMOTE:
        label = "promote";
        break;
    case TCC_CONVERT_SAFE:
        label = "safe";
        break;
    case TCC_CONVERT_UNSAFE:
        label = "unsafe";
        break;
    default:
        break;
    }

    if (label == NULL) {
        Py_RETURN_NONE;
    }
    return PyString_FromString(label);
}
162+
163+
/* set_compatible(tm, from, to, by) -> None
 *
 * Registers a conversion from type code `from` to type code `to` with the
 * TypeManager.  `by` is an integer holding a character code: 'p' (promote),
 * 's' (safe convert) or 'u' (unsafe convert); anything else raises
 * ValueError("Unknown TCC").  Raises TypeError when the first argument is
 * not a TypeManager capsule. */
PyObject*
set_compatible(PyObject* self, PyObject* args)
{
    PyObject *tmcap;
    int from_code, to_code, kind;
    if (!PyArg_ParseTuple(args, "Oiii", &tmcap, &from_code, &to_code,
                          &kind)) {
        return NULL;
    }

    TypeManager *tm = unwrap_TypeManager(tmcap);
    if (tm == NULL) {
        BAD_TM_ARGUMENT;
        return NULL;
    }

    TypeCompatibleCode tcc;
    if (kind == 'p') {          // promote
        tcc = TCC_PROMOTE;
    } else if (kind == 's') {   // safe convert
        tcc = TCC_CONVERT_SAFE;
    } else if (kind == 'u') {   // unsafe convert
        tcc = TCC_CONVERT_UNSAFE;
    } else {
        PyErr_SetString(PyExc_ValueError, "Unknown TCC");
        return NULL;
    }

    tm->addCompatibility(Type(from_code), Type(to_code), tcc);
    Py_RETURN_NONE;
}
196+
197+
198+
/* get_pointer(tm) -> int
 *
 * Returns the address of the TypeManager held by capsule `tm` as a Python
 * integer.  Raises TypeError when the argument is not a TypeManager capsule,
 * matching the other entry points of this module. */
PyObject*
get_pointer(PyObject* self, PyObject* args)
{
    PyObject *tmcap;
    if (!PyArg_ParseTuple(args, "O", &tmcap)) {
        return NULL;
    }

    TypeManager *tm = unwrap_TypeManager(tmcap);
    if (!tm) {
        /* A failed unwrap leaves an exception pending; returning a value
         * in that state would be an error, so report the failure. */
        BAD_TM_ARGUMENT;
        return NULL;
    }
    return PyLong_FromVoidPtr(tm);
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: BSD-2-Clause
3+
4+
/**
5+
6+
This is a modified version of capsulethunk.h for use in llvmpy
7+
8+
**/
9+
10+
#ifndef __CAPSULETHUNK_H
11+
#define __CAPSULETHUNK_H
12+
13+
#if ( (PY_VERSION_HEX < 0x02070000) \
14+
|| ((PY_VERSION_HEX >= 0x03000000) \
15+
&& (PY_VERSION_HEX < 0x03010000)) )
16+
17+
/* Assert(X) currently expands to nothing, so the checks written with it in
 * this header are compiled out.  The commented-out definition routes them
 * through do_assert instead. */
//#define Assert(X) do_assert(!!(X), #X, __FILE__, __LINE__)
#define Assert(X)

/* Abort the process with a diagnostic when `cond` is zero.
 *
 *   cond  condition value; non-zero means the assertion holds
 *   msg   text printed on failure
 *   file  source file name printed on failure
 *   line  source line number printed on failure
 *
 * On failure the message is written to stderr and exit(1) is called;
 * otherwise the function returns normally. */
static
void do_assert(int cond, const char * msg, const char *file, unsigned line){
    if (!cond) {
        /* `line` is unsigned, so it must be printed with %u. */
        fprintf(stderr, "Assertion failed %s:%u\n%s\n", file, line, msg);
        exit(1);
    }
}
27+
28+
typedef void (*PyCapsule_Destructor)(PyObject *);
29+
30+
struct FakePyCapsule_Desc {
31+
const char *name;
32+
void *context;
33+
PyCapsule_Destructor dtor;
34+
PyObject *parent;
35+
36+
FakePyCapsule_Desc() : name(0), context(0), dtor(0) {}
37+
};
38+
39+
static
40+
FakePyCapsule_Desc* get_pycobj_desc(PyObject *p){
41+
void *desc = ((PyCObject*)p)->desc;
42+
Assert(desc && "No desc in PyCObject");
43+
return static_cast<FakePyCapsule_Desc*>(desc);
44+
}
45+
46+
static
47+
void pycobject_pycapsule_dtor(void *p, void *desc){
48+
Assert(desc);
49+
Assert(p);
50+
FakePyCapsule_Desc *fpc_desc = static_cast<FakePyCapsule_Desc*>(desc);
51+
Assert(fpc_desc->parent);
52+
Assert(PyCObject_Check(fpc_desc->parent));
53+
fpc_desc->dtor(static_cast<PyObject*>(fpc_desc->parent));
54+
delete fpc_desc;
55+
}
56+
57+
static
58+
PyObject* PyCapsule_New(void* ptr, const char *name, PyCapsule_Destructor dtor)
59+
{
60+
FakePyCapsule_Desc *desc = new FakePyCapsule_Desc;
61+
desc->name = name;
62+
desc->context = NULL;
63+
desc->dtor = dtor;
64+
PyObject *p = PyCObject_FromVoidPtrAndDesc(ptr, desc,
65+
pycobject_pycapsule_dtor);
66+
desc->parent = p;
67+
return p;
68+
}
69+
70+
static
71+
int PyCapsule_CheckExact(PyObject *p)
72+
{
73+
return PyCObject_Check(p);
74+
}
75+
76+
static
77+
void* PyCapsule_GetPointer(PyObject *p, const char *name)
78+
{
79+
Assert(PyCapsule_CheckExact(p));
80+
if (strcmp(get_pycobj_desc(p)->name, name) != 0) {
81+
PyErr_SetString(PyExc_ValueError, "Invalid PyCapsule object");
82+
}
83+
return PyCObject_AsVoidPtr(p);
84+
}
85+
86+
static
87+
void* PyCapsule_GetContext(PyObject *p)
88+
{
89+
Assert(p);
90+
Assert(PyCapsule_CheckExact(p));
91+
return get_pycobj_desc(p)->context;
92+
}
93+
94+
static
95+
int PyCapsule_SetContext(PyObject *p, void *context)
96+
{
97+
Assert(PyCapsule_CheckExact(p));
98+
get_pycobj_desc(p)->context = context;
99+
return 0;
100+
}
101+
102+
static
103+
const char * PyCapsule_GetName(PyObject *p)
104+
{
105+
// Assert(PyCapsule_CheckExact(p));
106+
return get_pycobj_desc(p)->name;
107+
}
108+
109+
#endif /* PY_VERSION_HEX < 0x02070000, or 0x03000000 <= PY_VERSION_HEX < 0x03010000 */
110+
111+
#endif /* __CAPSULETHUNK_H */

numba_cuda/numba/cuda/core/typeinfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
NumbaValueError,
4949
)
5050
from numba.cuda.core.funcdesc import qualifying_prefix
51-
from numba.core.typeconv import Conversion
51+
from numba.cuda.typeconv import Conversion
5252

5353
_logger = logging.getLogger(__name__)
5454

numba_cuda/numba/cuda/target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def resolve_value_type(self, val):
8888
def can_convert(self, fromty, toty):
8989
"""
9090
Check whether conversion is possible from *fromty* to *toty*.
91-
If successful, return a numba.typeconv.Conversion instance;
91+
If successful, return a numba.cuda.typeconv.Conversion instance;
9292
otherwise None is returned.
9393
"""
9494

0 commit comments

Comments
 (0)