Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions ml_dtypes/_src/custom_complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,114 @@ bool RegisterComplexUFuncs(PyObject* numpy) {
return ok;
}


// Descriptor resolution for the complex -> real-component methods. Mirrors the
// NumPy implementation; the strided loop is strictly optional here.
// (Supports full ufunc path, although this is currently unexposed in NumPy.)
//
// loop_descrs[0] receives a new reference to given_descrs[0].
// loop_descrs[1] receives the output DType's singleton, replaced by a
// byte-order-swapped copy when the input descriptor is byte-swapped.
// *view_offset is 0 for the real part and the output element size for the
// imaginary part.
// Returns NPY_NO_CASTING, or _NPY_ERROR_OCCURRED_IN_CAST when the swapped
// descriptor could not be created (both references are released in that case).
template <bool real_part>
static NPY_CASTING
complex_to_real_resolve_descriptors(
    PyObject *NPY_UNUSED(self),
    PyArray_DTypeMeta *const dtypes[2],
    PyArray_Descr *const given_descrs[2],
    PyArray_Descr *loop_descrs[2],
    npy_intp *view_offset)
{
  PyArray_Descr *in_descr = given_descrs[0];
  Py_INCREF(in_descr);
  loop_descrs[0] = in_descr;

  PyArray_Descr *out_descr = dtypes[1]->singleton;
  Py_INCREF(out_descr);
  loop_descrs[1] = out_descr;

  if (PyDataType_ISBYTESWAPPED(in_descr)) {
    // Store the swapped descriptor first, then drop the singleton reference.
    PyArray_Descr *swapped = PyArray_DescrNewByteorder(out_descr, NPY_SWAP);
    loop_descrs[1] = swapped;
    Py_DECREF(out_descr);
    if (swapped == NULL) {
      Py_DECREF(in_descr);
      return _NPY_ERROR_OCCURRED_IN_CAST;
    }
  }

  *view_offset = real_part ? 0 : PyDataType_ELSIZE(loop_descrs[1]);
  return NPY_NO_CASTING;
}


/*
 * Fallback strided loop; we shouldn't normally use it, but define it anyway.
 *
 * Copies one `T::value_type` per element from data[0] to data[1], advancing by
 * strides[0] / strides[1] bytes, for dimensions[0] elements. For the imaginary
 * part the read position starts sizeof(real_type) bytes into each input
 * element. Always returns 0.
 */
template <typename T, bool real_part>
static int extract_complex_part_loop(
    PyArrayMethod_Context *context, char *const data[],
    npy_intp const dimensions[], npy_intp const strides[],
    NpyAuxData *NPY_UNUSED(auxdata))
{
  using real_type = typename T::value_type;

  const npy_intp src_step = strides[0];
  const npy_intp dst_step = strides[1];
  char *src = data[0] + (real_part ? 0 : sizeof(real_type));
  char *dst = data[1];

  for (npy_intp remaining = dimensions[0]; remaining != 0; --remaining) {
    *reinterpret_cast<real_type *>(dst) = *reinterpret_cast<real_type *>(src);
    src += src_step;
    dst += dst_step;
  }
  return 0;
}

// Registers a loop for the NumPy `real` (real_part == true) or `imag`
// (real_part == false) ufunc that maps `complex_dtype` to the DType of its
// component type `T::value_type`.
//
// Returns -1 if the descriptor for the component type cannot be obtained;
// otherwise returns the result of PyUFunc_AddLoopsFromSpecs.
template <typename T, bool real_part>
int RegisterRealImag(PyArray_DTypeMeta* complex_dtype) {
  using real_type = typename T::value_type;
  Safe_PyObjectPtr real_descr = make_safe(reinterpret_cast<PyObject*>(
      PyArray_DescrFromType(TypeDescriptor<real_type>::Dtype())));
  if (!real_descr) {
    return -1;
  }

  static PyType_Slot meth_slots[] = {
      {NPY_METH_resolve_descriptors,
       reinterpret_cast<void*>(&complex_to_real_resolve_descriptors<real_part>)},
      {NPY_METH_strided_loop,
       reinterpret_cast<void*>(&extract_complex_part_loop<T, real_part>)},
      {0, nullptr},
  };
  PyArray_DTypeMeta* dtypes[2] = {complex_dtype, NPY_DTYPE(real_descr.get())};

  // Filled in member by member: designated initializers are a C++20 feature
  // and are rejected by MSVC when building as C++17.
  PyArrayMethod_Spec meth_spec = {};
  meth_spec.name = "generic_real_imag_loop";
  meth_spec.nin = 1;
  meth_spec.nout = 1;
  meth_spec.casting = NPY_NO_CASTING;
  meth_spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
  meth_spec.dtypes = dtypes;
  meth_spec.slots = meth_slots;

  constexpr const char* ufunc_name = real_part ? "real" : "imag";
  PyUFunc_LoopSlot loop_slots[] = {
      {ufunc_name, &meth_spec},
      {nullptr, nullptr},  // sentinel entry terminating the list
  };
  return PyUFunc_AddLoopsFromSpecs(loop_slots);
}


// Registers the `real` loop and then the `imag` loop for `complex_dtype`.
// Does nothing and returns 0 when the runtime NumPy version is below 0x15;
// returns -1 as soon as either registration fails.
template <typename T>
int RegisterRealAndImag(PyArray_DTypeMeta* complex_dtype) {
  // TODO: FIXME, once NumPy main is bumped, this needs to be x16 or NPY_2_5_API_VERSION
  const bool runtime_supported = PyArray_RUNTIME_VERSION >= 0x15;
  if (!runtime_supported) {
    return 0;
  }
  const int real_status = RegisterRealImag<T, true>(complex_dtype);
  return real_status < 0 ? -1 : RegisterRealImag<T, false>(complex_dtype);
}


template <typename T>
bool RegisterComplexDtype(PyObject* numpy) {
// bases must be a tuple for Python 3.9 and earlier. Change to just pass
Expand Down Expand Up @@ -963,6 +1071,10 @@ bool RegisterComplexDtype(PyObject* numpy) {
CustomComplexType<T>::npy_descr =
PyArray_DescrFromType(TypeDescriptor<T>::npy_type);

if (RegisterRealAndImag<T>(NPY_DTYPE(CustomComplexType<T>::npy_descr)) < 0) {
return false;
}

Safe_PyObjectPtr typeDict_obj =
make_safe(PyObject_GetAttrString(numpy, "sctypeDict"));
if (!typeDict_obj) return false;
Expand Down
22 changes: 22 additions & 0 deletions ml_dtypes/_src/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ limitations under the License.
#include "numpy/arrayobject.h"
#include "numpy/arrayscalars.h"
#include "numpy/ufuncobject.h"
#include "numpy/dtype_api.h"


// Only provide the fallback when the name is not already a preprocessor macro.
// NOTE(review): presumably the NumPy headers define it as a macro when the
// targeted API version includes it -- confirm against the NumPy headers.
#ifndef PyUFunc_AddLoopsFromSpecs
// Backport `PyUFunc_AddLoopsFromSpecs` for conditional use if we are
// on NumPy 2.5+ at runtime. (function available with 2.4, but not imag/real)
#if NPY_API_VERSION < 0x15
// Headers older than API version 0x15 do not declare the slot struct, so
// declare it here: a ufunc name paired with the ArrayMethod spec to add.
typedef struct {
const char *name;
PyArrayMethod_Spec *spec;
} PyUFunc_LoopSlot; // defined starting NumPy 2.4+
#endif

// Calls entry 47 of the PyUFunc_API table as `int (*)(PyUFunc_LoopSlot *)`.
// Returns 0 without registering anything when the runtime NumPy version is
// below 0x15; otherwise returns whatever the table entry returns.
// NOTE(review): index 47 is hard-coded; verify it against NumPy's ufunc API
// table.
static inline int PyUFunc_AddLoopsFromSpecs(PyUFunc_LoopSlot *loop_specs) {
if (PyArray_RUNTIME_VERSION < 0x15) {
return 0; // no-op as function is not available.
}
return (*(int (*)(PyUFunc_LoopSlot *))PyUFunc_API[47])(loop_specs);
}

#endif


namespace ml_dtypes {

Expand Down
15 changes: 15 additions & 0 deletions ml_dtypes/tests/custom_complex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,21 @@ def test_real_imag_arrays(sctype):
np.testing.assert_array_equal(imag_part, [2.0, 4.0])


@pytest.mark.parametrize("sctype", COMPLEX_SCTYPES)
@pytest.mark.xfail(
    np.lib.NumpyVersion(np.__version__) < "2.5.0.dev0",
    reason="2.5 introduced real and imag helpers."
)
def test_real_imag_helpers(sctype):
    """Check ml_dtypes.real() / ml_dtypes.imag() on a small complex array.

    Named distinctly from the `test_real_imag_arrays` test defined just above
    in this module: reusing that name would rebind it and the earlier test
    would silently never be collected.
    """
    arr = np.array([1 + 2j, 3 + 4j], dtype=sctype)
    real_part = ml_dtypes.real(arr)
    imag_part = ml_dtypes.imag(arr)
    expected_dtype = ml_dtypes.finfo(sctype).dtype  # the real one
    assert real_part.dtype == imag_part.dtype == expected_dtype
    np.testing.assert_array_equal(real_part, [1.0, 3.0])
    np.testing.assert_array_equal(imag_part, [2.0, 4.0])

@pytest.mark.parametrize("sctype", COMPLEX_SCTYPES)
@pytest.mark.parametrize("value", COMPLEX_VALUES)
def test_str_repr(sctype, value):
Expand Down