diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h
index c640aee7..7e437aef 100644
--- a/ml_dtypes/_src/custom_complex.h
+++ b/ml_dtypes/_src/custom_complex.h
@@ -904,6 +904,131 @@ bool RegisterComplexUFuncs(PyObject* numpy) {
   return ok;
 }
 
+
+// Descriptor resolution for the complex -> real "real"/"imag" loops.
+// Identical to the NumPy code, we could actually do without the loop.
+// (Supports full ufunc path, although this is currently unexposed in NumPy.)
+//
+// Passes the input descriptor through and uses the singleton of the real
+// dtype as output, byte-swapped whenever the input is byte-swapped.
+// `view_offset` is 0 for the real part and the real element size for the
+// imaginary part. Returns NPY_NO_CASTING, or _NPY_ERROR_OCCURRED_IN_CAST
+// (after releasing the input reference) if the swapped descriptor is NULL.
+template <bool real_part>
+static NPY_CASTING
+complex_to_real_resolve_descriptors(
+    PyObject *NPY_UNUSED(self),
+    PyArray_DTypeMeta *const dtypes[2],
+    PyArray_Descr *const given_descrs[2],
+    PyArray_Descr *loop_descrs[2],
+    npy_intp *view_offset)
+{
+  Py_INCREF(given_descrs[0]);
+  loop_descrs[0] = given_descrs[0];
+  Py_INCREF(dtypes[1]->singleton);
+  loop_descrs[1] = dtypes[1]->singleton;
+
+  if (PyDataType_ISBYTESWAPPED(loop_descrs[0])) {
+    Py_SETREF(
+        loop_descrs[1], PyArray_DescrNewByteorder(loop_descrs[1], NPY_SWAP));
+    if (loop_descrs[1] == NULL) {
+      Py_DECREF(loop_descrs[0]);
+      return _NPY_ERROR_OCCURRED_IN_CAST;
+    }
+  }
+  if constexpr (real_part) {
+    *view_offset = 0;
+  }
+  else {
+    *view_offset = PyDataType_ELSIZE(loop_descrs[1]);
+  }
+  return NPY_NO_CASTING;
+}
+
+
+/* We shouldn't normally use it, but define a simple loop anyway. */
+// Copies the real (real_part == true) or imaginary part of each complex
+// element from data[0] to data[1], for dimensions[0] elements. Returns 0.
+template <typename T, bool real_part>
+static int extract_complex_part_loop(
+    PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
+    npy_intp const dimensions[], npy_intp const strides[],
+    NpyAuxData *NPY_UNUSED(auxdata))
+{
+  using real_type = typename T::value_type;
+  npy_intp N = dimensions[0];
+  char *in = data[0];
+  char *out = data[1];
+  npy_intp istride = strides[0];
+  npy_intp ostride = strides[1];
+
+  // The imaginary part is read sizeof(real_type) bytes past the real part.
+  if constexpr (!real_part) {
+    in += sizeof(real_type);
+  }
+
+  while (N--) {
+    real_type value = *reinterpret_cast<real_type *>(in);
+    *reinterpret_cast<real_type *>(out) = value;
+    in += istride;
+    out += ostride;
+  }
+  return 0;
+}
+
+// Registers the "real" (real_part == true) or "imag" ufunc loop mapping
+// `complex_dtype` to the dtype of T::value_type. Returns -1 if the real
+// descriptor cannot be obtained, else the PyUFunc_AddLoopsFromSpecs result.
+template <typename T, bool real_part>
+int RegisterRealImag(PyArray_DTypeMeta* complex_dtype) {
+  using real_type = typename T::value_type;
+  Safe_PyObjectPtr real_descr = make_safe(
+      (PyObject*)PyArray_DescrFromType(TypeDescriptor<real_type>::Dtype()));
+  if (!real_descr) {
+    return -1;
+  }
+
+  static PyType_Slot meth_slots[] = {
+    {NPY_METH_resolve_descriptors, (void*)&complex_to_real_resolve_descriptors<real_part>},
+    {NPY_METH_strided_loop, (void*)&extract_complex_part_loop<T, real_part>},
+    {0, nullptr},
+  };
+  PyArray_DTypeMeta* dtypes[2] = {complex_dtype, NPY_DTYPE(real_descr.get())};
+  // Plain member assignment: designated initializers require C++20.
+  PyArrayMethod_Spec meth_spec = {};
+  meth_spec.name = "generic_real_imag_loop";
+  meth_spec.nin = 1;
+  meth_spec.nout = 1;
+  meth_spec.casting = NPY_NO_CASTING;
+  meth_spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
+  meth_spec.dtypes = dtypes;
+  meth_spec.slots = meth_slots;
+  constexpr const char* ufunc_name = real_part ? "real" : "imag";
+  PyUFunc_LoopSlot loop_slots[] = {
+    {ufunc_name, &meth_spec},
+    {nullptr, nullptr},
+  };
+  return PyUFunc_AddLoopsFromSpecs(loop_slots);
+}
+
+
+// Registers both the "real" and the "imag" loop for `complex_dtype`.
+// Returns 0 without registering anything when PyArray_RUNTIME_VERSION is
+// below 0x15, -1 if the "real" registration fails, and otherwise the result
+// of the "imag" registration.
+template <typename T>
+int RegisterRealAndImag(PyArray_DTypeMeta* complex_dtype) {
+  // TODO: FIXME, once NumPy main is bumped, this needs to be x16 or NPY_2_5_API_VERSION
+  if (PyArray_RUNTIME_VERSION < 0x15) {
+    return 0;
+  }
+  if (RegisterRealImag<T, true>(complex_dtype) < 0) {
+    return -1;
+  }
+  return RegisterRealImag<T, false>(complex_dtype);
+}
+
+
 template <typename T>
 bool RegisterComplexDtype(PyObject* numpy) {
   // bases must be a tuple for Python 3.9 and earlier. Change to just pass
@@ -963,6 +1088,10 @@ bool RegisterComplexDtype(PyObject* numpy) {
   CustomComplexType<T>::npy_descr =
       PyArray_DescrFromType(TypeDescriptor<T>::npy_type);
 
+  if (RegisterRealAndImag<T>(NPY_DTYPE(CustomComplexType<T>::npy_descr)) < 0) {
+    return false;
+  }
+
   Safe_PyObjectPtr typeDict_obj =
       make_safe(PyObject_GetAttrString(numpy, "sctypeDict"));
   if (!typeDict_obj) return false;
diff --git a/ml_dtypes/_src/numpy.h b/ml_dtypes/_src/numpy.h
index 8b55e4d9..e1f0ba7a 100644
--- a/ml_dtypes/_src/numpy.h
+++ b/ml_dtypes/_src/numpy.h
@@ -37,6 +37,30 @@ limitations under the License.
 #include "numpy/arrayobject.h"
 #include "numpy/arrayscalars.h"
 #include "numpy/ufuncobject.h"
+#include "numpy/dtype_api.h"
+
+
+#ifndef PyUFunc_AddLoopsFromSpecs
+// Backport `PyUFunc_AddLoopsFromSpecs` for conditional use if we are
+// on NumPy 2.5+ at runtime. (function available with 2.4, but not imag/real)
+#if NPY_API_VERSION < 0x15
+typedef struct {
+  const char *name;
+  PyArrayMethod_Spec *spec;
+} PyUFunc_LoopSlot;  // defined starting NumPy 2.4+
+#endif
+
+static inline int PyUFunc_AddLoopsFromSpecs(PyUFunc_LoopSlot *loop_specs) {
+  if (PyArray_RUNTIME_VERSION < 0x15) {
+    return 0;  // no-op as function is not available.
+  }
+  // Call through slot 47 of the ufunc C-API table, which this backport
+  // treats as the entry for PyUFunc_AddLoopsFromSpecs.
+  return (*(int (*)(PyUFunc_LoopSlot *))PyUFunc_API[47])(loop_specs);
+}
+
+#endif
+
 
 namespace ml_dtypes {
 
diff --git a/ml_dtypes/tests/custom_complex_test.py b/ml_dtypes/tests/custom_complex_test.py
index ddf76ff9..44783fbf 100644
--- a/ml_dtypes/tests/custom_complex_test.py
+++ b/ml_dtypes/tests/custom_complex_test.py
@@ -129,6 +129,26 @@ def test_real_imag_arrays(sctype):
   np.testing.assert_array_equal(imag_part, [2.0, 4.0])
 
 
+@pytest.mark.parametrize("sctype", COMPLEX_SCTYPES)
+@pytest.mark.xfail(
+    np.lib.NumpyVersion(np.__version__) < "2.5.0.dev0",
+    reason="2.5 introduced real and imag helpers."
+)
+def test_real_imag_helpers_numpy_2_5(sctype):
+  # Test ml_dtypes.real() and ml_dtypes.imag() helpers.
+  # Named differently from test_real_imag_arrays above so that this test
+  # does not replace it when the module is imported.
+  # NOTE(review): the body still calls the ml_dtypes helpers; confirm whether
+  # arr.real / arr.imag were meant to be exercised here instead.
+  arr = np.array([1 + 2j, 3 + 4j], dtype=sctype)
+  real_part = ml_dtypes.real(arr)
+  imag_part = ml_dtypes.imag(arr)
+  expected_dtype = ml_dtypes.finfo(sctype).dtype  # the real one
+  assert real_part.dtype == imag_part.dtype == expected_dtype
+  np.testing.assert_array_equal(real_part, [1.0, 3.0])
+  np.testing.assert_array_equal(imag_part, [2.0, 4.0])
+
+
 @pytest.mark.parametrize("sctype", COMPLEX_SCTYPES)
 @pytest.mark.parametrize("value", COMPLEX_VALUES)
 def test_str_repr(sctype, value):