diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi index c7b35b76..deb6c9f5 100644 --- a/python/tvm_ffi/core.pyi +++ b/python/tvm_ffi/core.pyi @@ -190,6 +190,7 @@ class DLTensorTestWrapper: def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() -> int: ... class Function(Object): + def __init__(self, func: Callable[..., Any] | None = None) -> None: ... @property def release_gil(self) -> bool: ... @release_gil.setter diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi index f66d7dfb..5e6257dc 100644 --- a/python/tvm_ffi/cython/function.pxi +++ b/python/tvm_ffi/cython/function.pxi @@ -889,6 +889,32 @@ cdef class Function(CObject): def __cinit__(self) -> None: self.c_release_gil = _RELEASE_GIL_BY_DEFAULT + def __init__(self, func: Optional[Callable[..., Any]] = None) -> None: + """Initialize a Function from a Python callable. + + This constructor allows creating a `tvm_ffi.Function` directly + from a Python function or another `tvm_ffi.Function` instance. + + Parameters + ---------- + func : Optional[Callable[..., Any]] + The Python callable to wrap. When ``None`` (the default), + the object is left in its default state with a null handle. + """ + if func is None: + return + cdef TVMFFIObjectHandle chandle = NULL + if isinstance(func, Function): + chandle = (func).chandle + if chandle == NULL: + raise ValueError("Cannot initialize from a moved-from Function object.") + TVMFFIObjectIncRef(chandle) + elif callable(func): + _convert_to_ffi_func_handle(func, &chandle) + else: + raise TypeError(f"func must be callable, got {type(func)}") + self.chandle = chandle + property release_gil: """Whether calls release the Python GIL while executing.""" diff --git a/tests/python/test_function.py b/tests/python/test_function.py index 686fc08e..4bd7b8a8 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -153,6 +153,38 @@ def fapply(f: Any, *args: Any) -> Any: assert fapply(add, 1, 3.3) == 4.3 +def test_pyfunc_init() -> None: + def add(a: int, b: int) -> int: + return a + b + + # Test creating from a Python callable + fadd = tvm_ffi.Function(add) + assert isinstance(fadd, tvm_ffi.Function) + assert fadd(1, 2) == 3 + + # Test creating from an existing tvm_ffi.Function + fadd2 = tvm_ffi.Function(fadd) + assert isinstance(fadd2, tvm_ffi.Function) + assert fadd2(3, 4) == 7 + assert fadd.same_as(fadd2) + + # Test creating from a moved-from function raises ValueError + f_source = tvm_ffi.Function(add) + f_dest = tvm_ffi.Function.__new__(tvm_ffi.Function) + f_dest.__move_handle_from__(f_source) + with pytest.raises(ValueError, match="Cannot initialize from a moved-from Function object"): + tvm_ffi.Function(f_source) + + # Test creating without arguments (backward compat: null handle) + f_empty = tvm_ffi.Function() + assert isinstance(f_empty, tvm_ffi.Function) + assert f_empty.__chandle__() == 0 + + # Test creating from a non-callable raises TypeError + with pytest.raises(TypeError): + tvm_ffi.Function(123) # ty: ignore[invalid-argument-type] + + def test_global_func() -> None: @tvm_ffi.register_global_func("mytest.echo") def echo(x: Any) -> Any: