Skip to content

Commit 6253827

Browse files
Merge branch 'main' of github.com:NVIDIA/numba-cuda into ajm/vendor-imputils
2 parents e6b6bd2 + c83f379 commit 6253827

File tree

21 files changed

+1180
-43
lines changed

21 files changed

+1180
-43
lines changed

numba_cuda/numba/cuda/_internal/cuda_fp16.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(self):
124124
self.bitwidth = 2 * 8
125125

126126
def can_convert_from(self, typingctx, other):
127-
from numba.core.typeconv import Conversion
127+
from numba.cuda.typeconv import Conversion
128128

129129
if other in []:
130130
return Conversion.safe
@@ -174,7 +174,7 @@ def __init__(self):
174174
self.bitwidth = 4 * 8
175175

176176
def can_convert_from(self, typingctx, other):
177-
from numba.core.typeconv import Conversion
177+
from numba.cuda.typeconv import Conversion
178178

179179
if other in []:
180180
return Conversion.safe
@@ -7903,9 +7903,9 @@ def generic(self, args, kws):
79037903
# - Conversion.safe
79047904

79057905
if (
7906-
(convertible == numba.core.typeconv.Conversion.exact)
7907-
or (convertible == numba.core.typeconv.Conversion.promote)
7908-
or (convertible == numba.core.typeconv.Conversion.safe)
7906+
(convertible == numba.cuda.typeconv.Conversion.exact)
7907+
or (convertible == numba.cuda.typeconv.Conversion.promote)
7908+
or (convertible == numba.cuda.typeconv.Conversion.safe)
79097909
):
79107910
return signature(retty, types.float16, types.float16)
79117911

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: BSD-2-Clause
3+
4+
/*
5+
* Helper functions used by Numba CUDA at runtime.
6+
* This C file is meant to be included after defining the
7+
* NUMBA_EXPORT_FUNC() and NUMBA_EXPORT_DATA() macros.
8+
*/
9+
10+
#include "_pymodule.h"
11+
#include <stddef.h>
12+
13+
/*
14+
* Unicode helpers
15+
*/
16+
17+
/* Developer note:
18+
*
19+
* The hash value of unicode objects is obtained via:
20+
* ((PyASCIIObject *)(obj))->hash;
21+
* The use comes from this definition:
22+
* https://github.com/python/cpython/blob/6d43f6f081023b680d9db4542d19b9e382149f0a/Objects/unicodeobject.c#L119-L120
23+
* and it's used extensively throughout the `cpython/Object/unicodeobject.c`
24+
* source, not least in `unicode_hash` itself:
25+
* https://github.com/python/cpython/blob/6d43f6f081023b680d9db4542d19b9e382149f0a/Objects/unicodeobject.c#L11662-L11679
26+
*
27+
* The Unicode string struct layouts are described here:
28+
* https://github.com/python/cpython/blob/6d43f6f081023b680d9db4542d19b9e382149f0a/Include/cpython/unicodeobject.h#L82-L161
29+
* essentially, all the unicode string layouts start with a `PyASCIIObject` at
30+
* offset 0 (as of commit 6d43f6f081023b680d9db4542d19b9e382149f0a, somewhere
31+
* in the 3.8 development cycle).
32+
*
33+
* For safety against future CPython internal changes, the code checks that the
34+
* _base members of the unicode structs are what is expected in 3.7, and that
35+
* their offset is 0. It then walks the struct to the hash location to make sure
36+
* the offset is indeed the same as PyASCIIObject->hash.
37+
* Note: The large condition in the if should evaluate to a compile time
38+
* constant.
39+
*/
40+
41+
#define MEMBER_SIZE(structure, member) sizeof(((structure *)0)->member)
42+
43+
NUMBA_EXPORT_FUNC(void *)
44+
numba_extract_unicode(PyObject *obj, Py_ssize_t *length, int *kind,
45+
unsigned int *ascii, Py_ssize_t *hash) {
46+
if (!PyUnicode_READY(obj)) {
47+
*length = PyUnicode_GET_LENGTH(obj);
48+
*kind = PyUnicode_KIND(obj);
49+
/* could also use PyUnicode_IS_ASCII but it is not publicly advertised in https://docs.python.org/3/c-api/unicode.html */
50+
*ascii = (unsigned int)(PyUnicode_MAX_CHAR_VALUE(obj) == (0x7f));
51+
/* this is here as a crude check for safe casting of all unicode string
52+
* structs to a PyASCIIObject */
53+
if (MEMBER_SIZE(PyCompactUnicodeObject, _base) == sizeof(PyASCIIObject) &&
54+
MEMBER_SIZE(PyUnicodeObject, _base) == sizeof(PyCompactUnicodeObject) &&
55+
offsetof(PyCompactUnicodeObject, _base) == 0 &&
56+
offsetof(PyUnicodeObject, _base) == 0 &&
57+
offsetof(PyCompactUnicodeObject, _base.hash) == offsetof(PyASCIIObject, hash) &&
58+
offsetof(PyUnicodeObject, _base._base.hash) == offsetof(PyASCIIObject, hash)
59+
) {
60+
/* Grab the hash from the type object cache, do not compute it. */
61+
*hash = ((PyASCIIObject *)(obj))->hash;
62+
}
63+
else {
64+
/* cast is not safe, fail */
65+
return NULL;
66+
}
67+
return PyUnicode_DATA(obj);
68+
} else {
69+
return NULL;
70+
}
71+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: BSD-2-Clause
3+
4+
/*
5+
* Expose all functions as pointers in a dedicated C extension.
6+
*/
7+
8+
/* Import _pymodule.h first, for a recent _POSIX_C_SOURCE */
9+
#include "_pymodule.h"
10+
11+
/* Visibility control macros */
12+
#if defined(_WIN32) || defined(_WIN64)
13+
#define VISIBILITY_HIDDEN
14+
#define VISIBILITY_GLOBAL __declspec(dllexport)
15+
#else
16+
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
17+
#define VISIBILITY_GLOBAL __attribute__((visibility("default")))
18+
#endif
19+
20+
/* Define all runtime-required symbols in this C module, but do not
21+
export them outside the shared library if possible. */
22+
#define NUMBA_EXPORT_FUNC(_rettype) VISIBILITY_HIDDEN _rettype
23+
#define NUMBA_EXPORT_DATA(_vartype) VISIBILITY_HIDDEN _vartype
24+
25+
/* Numba CUDA C helpers */
26+
#include "_helperlib.c"
27+
28+
static PyObject *
29+
build_c_helpers_dict(void)
30+
{
31+
PyObject *dct = PyDict_New();
32+
if (dct == NULL)
33+
goto error;
34+
35+
#define _declpointer(name, value) do { \
36+
PyObject *o = PyLong_FromVoidPtr(value); \
37+
if (o == NULL) goto error; \
38+
if (PyDict_SetItemString(dct, name, o)) { \
39+
Py_DECREF(o); \
40+
goto error; \
41+
} \
42+
Py_DECREF(o); \
43+
} while (0)
44+
45+
#define declmethod(func) _declpointer(#func, &numba_##func)
46+
47+
/* Unicode string support */
48+
declmethod(extract_unicode);
49+
50+
#undef declmethod
51+
return dct;
52+
error:
53+
Py_XDECREF(dct);
54+
return NULL;
55+
}
56+
57+
static PyMethodDef ext_methods[] = {
58+
{ NULL },
59+
};
60+
61+
MOD_INIT(_helperlib) {
62+
PyObject *m;
63+
MOD_DEF(m, "_helperlib", "No docs", ext_methods)
64+
if (m == NULL)
65+
return MOD_ERROR_VAL;
66+
67+
PyModule_AddObject(m, "c_helpers", build_c_helpers_dict());
68+
PyModule_AddIntConstant(m, "long_min", LONG_MIN);
69+
PyModule_AddIntConstant(m, "long_max", LONG_MAX);
70+
PyModule_AddIntConstant(m, "py_buffer_size", sizeof(Py_buffer));
71+
PyModule_AddIntConstant(m, "py_gil_state_size", sizeof(PyGILState_STATE));
72+
PyModule_AddIntConstant(m, "py_unicode_1byte_kind", PyUnicode_1BYTE_KIND);
73+
PyModule_AddIntConstant(m, "py_unicode_2byte_kind", PyUnicode_2BYTE_KIND);
74+
PyModule_AddIntConstant(m, "py_unicode_4byte_kind", PyUnicode_4BYTE_KIND);
75+
#if (PY_MAJOR_VERSION == 3)
76+
#if ((PY_MINOR_VERSION == 10) || (PY_MINOR_VERSION == 11))
77+
PyModule_AddIntConstant(m, "py_unicode_wchar_kind", PyUnicode_WCHAR_KIND);
78+
#endif
79+
#endif
80+
81+
return MOD_SUCCESS_VAL(m);
82+
}
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: BSD-2-Clause
3+
4+
#include "_pymodule.h"
5+
#include "capsulethunk.h"
6+
#include "typeconv.hpp"
7+
8+
extern "C" {
9+
10+
11+
static PyObject*
12+
new_type_manager(PyObject* self, PyObject* args);
13+
14+
static void
15+
del_type_manager(PyObject *);
16+
17+
static PyObject*
18+
select_overload(PyObject* self, PyObject* args);
19+
20+
static PyObject*
21+
check_compatible(PyObject* self, PyObject* args);
22+
23+
static PyObject*
24+
set_compatible(PyObject* self, PyObject* args);
25+
26+
static PyObject*
27+
get_pointer(PyObject* self, PyObject* args);
28+
29+
30+
static PyMethodDef ext_methods[] = {
31+
#define declmethod(func) { #func , ( PyCFunction )func , METH_VARARGS , NULL }
32+
declmethod(new_type_manager),
33+
declmethod(select_overload),
34+
declmethod(check_compatible),
35+
declmethod(set_compatible),
36+
declmethod(get_pointer),
37+
{ NULL },
38+
#undef declmethod
39+
};
40+
41+
42+
MOD_INIT(_typeconv) {
43+
PyObject *m;
44+
MOD_DEF(m, "_typeconv", "No docs", ext_methods)
45+
if (m == NULL)
46+
return MOD_ERROR_VAL;
47+
48+
return MOD_SUCCESS_VAL(m);
49+
}
50+
51+
} // end extern C
52+
53+
///////////////////////////////////////////////////////////////////////////////
54+
55+
const char PY_CAPSULE_TM_NAME[] = "*tm";
56+
#define BAD_TM_ARGUMENT PyErr_SetString(PyExc_TypeError, \
57+
"1st argument not TypeManager")
58+
59+
static
60+
TypeManager* unwrap_TypeManager(PyObject *tm) {
61+
void* p = PyCapsule_GetPointer(tm, PY_CAPSULE_TM_NAME);
62+
return reinterpret_cast<TypeManager*>(p);
63+
}
64+
65+
PyObject*
66+
new_type_manager(PyObject* self, PyObject* args)
67+
{
68+
TypeManager* tm = new TypeManager();
69+
return PyCapsule_New(tm, PY_CAPSULE_TM_NAME, &del_type_manager);
70+
}
71+
72+
void
73+
del_type_manager(PyObject *tm)
74+
{
75+
delete unwrap_TypeManager(tm);
76+
}
77+
78+
PyObject*
79+
select_overload(PyObject* self, PyObject* args)
80+
{
81+
PyObject *tmcap, *sigtup, *ovsigstup;
82+
int allow_unsafe;
83+
int exact_match_required;
84+
85+
if (!PyArg_ParseTuple(args, "OOOii", &tmcap, &sigtup, &ovsigstup,
86+
&allow_unsafe, &exact_match_required)) {
87+
return NULL;
88+
}
89+
90+
TypeManager *tm = unwrap_TypeManager(tmcap);
91+
if (!tm) {
92+
BAD_TM_ARGUMENT;
93+
}
94+
95+
Py_ssize_t sigsz = PySequence_Size(sigtup);
96+
Py_ssize_t ovsz = PySequence_Size(ovsigstup);
97+
98+
Type *sig = new Type[sigsz];
99+
Type *ovsigs = new Type[ovsz * sigsz];
100+
101+
for (int i = 0; i < sigsz; ++i) {
102+
sig[i] = Type(PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(sigtup,
103+
i), NULL));
104+
}
105+
106+
for (int i = 0; i < ovsz; ++i) {
107+
PyObject *cursig = PySequence_Fast_GET_ITEM(ovsigstup, i);
108+
for (int j = 0; j < sigsz; ++j) {
109+
long tid = PyNumber_AsSsize_t(PySequence_Fast_GET_ITEM(cursig,
110+
j), NULL);
111+
ovsigs[i * sigsz + j] = Type(tid);
112+
}
113+
}
114+
115+
int selected = -42;
116+
int matches = tm->selectOverload(sig, ovsigs, selected, sigsz, ovsz,
117+
(bool) allow_unsafe,
118+
(bool) exact_match_required);
119+
120+
delete [] sig;
121+
delete [] ovsigs;
122+
123+
if (matches > 1) {
124+
PyErr_SetString(PyExc_TypeError, "Ambiguous overloading");
125+
return NULL;
126+
} else if (matches == 0) {
127+
PyErr_SetString(PyExc_TypeError, "No compatible overload");
128+
return NULL;
129+
}
130+
131+
return PyLong_FromLong(selected);
132+
}
133+
134+
PyObject*
135+
check_compatible(PyObject* self, PyObject* args)
136+
{
137+
PyObject *tmcap;
138+
int from, to;
139+
if (!PyArg_ParseTuple(args, "Oii", &tmcap, &from, &to)) {
140+
return NULL;
141+
}
142+
143+
TypeManager *tm = unwrap_TypeManager(tmcap);
144+
if(!tm) {
145+
BAD_TM_ARGUMENT;
146+
return NULL;
147+
}
148+
149+
switch(tm->isCompatible(Type(from), Type(to))){
150+
case TCC_EXACT:
151+
return PyString_FromString("exact");
152+
case TCC_PROMOTE:
153+
return PyString_FromString("promote");
154+
case TCC_CONVERT_SAFE:
155+
return PyString_FromString("safe");
156+
case TCC_CONVERT_UNSAFE:
157+
return PyString_FromString("unsafe");
158+
default:
159+
Py_RETURN_NONE;
160+
}
161+
}
162+
163+
PyObject*
164+
set_compatible(PyObject* self, PyObject* args)
165+
{
166+
PyObject *tmcap;
167+
int from, to, by;
168+
if (!PyArg_ParseTuple(args, "Oiii", &tmcap, &from, &to, &by)) {
169+
return NULL;
170+
}
171+
172+
TypeManager *tm = unwrap_TypeManager(tmcap);
173+
if (!tm) {
174+
BAD_TM_ARGUMENT;
175+
return NULL;
176+
}
177+
TypeCompatibleCode tcc;
178+
switch (by) {
179+
case 'p': // promote
180+
tcc = TCC_PROMOTE;
181+
break;
182+
case 's': // safe convert
183+
tcc = TCC_CONVERT_SAFE;
184+
break;
185+
case 'u': // unsafe convert
186+
tcc = TCC_CONVERT_UNSAFE;
187+
break;
188+
default:
189+
PyErr_SetString(PyExc_ValueError, "Unknown TCC");
190+
return NULL;
191+
}
192+
193+
tm->addCompatibility(Type(from), Type(to), tcc);
194+
Py_RETURN_NONE;
195+
}
196+
197+
198+
PyObject*
199+
get_pointer(PyObject* self, PyObject* args)
200+
{
201+
PyObject *tmcap;
202+
if (!PyArg_ParseTuple(args, "O", &tmcap)) {
203+
return NULL;
204+
}
205+
return PyLong_FromVoidPtr(unwrap_TypeManager(tmcap));
206+
}

0 commit comments

Comments
 (0)