Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm_ffi/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class DLTensorTestWrapper:
def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() -> int: ...

class Function(Object):
def __init__(self, func: Callable[..., Any] | None = None) -> None: ...
@property
def release_gil(self) -> bool: ...
@release_gil.setter
Expand Down
26 changes: 26 additions & 0 deletions python/tvm_ffi/cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,32 @@ cdef class Function(CObject):
def __cinit__(self) -> None:
self.c_release_gil = _RELEASE_GIL_BY_DEFAULT

def __init__(self, func: Optional[Callable[..., Any]] = None) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us wait until cutedsl release cycles fixes the compact issue, the land this one

"""Initialize a Function from a Python callable.

This constructor allows creating a `tvm_ffi.Function` directly
from a Python function or another `tvm_ffi.Function` instance.

Parameters
----------
func : Optional[Callable[..., Any]]
The Python callable to wrap. When ``None`` (the default),
the object is left in its default state with a null handle.
"""
if func is None:
return
cdef TVMFFIObjectHandle chandle = NULL
if isinstance(func, Function):
chandle = (<CObject>func).chandle
if chandle == NULL:
raise ValueError("Cannot initialize from a moved-from Function object.")
TVMFFIObjectIncRef(chandle)
elif callable(func):
_convert_to_ffi_func_handle(func, &chandle)
else:
raise TypeError(f"func must be callable, got {type(func)}")
self.chandle = chandle

property release_gil:
"""Whether calls release the Python GIL while executing."""

Expand Down
32 changes: 32 additions & 0 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,38 @@ def fapply(f: Any, *args: Any) -> Any:
assert fapply(add, 1, 3.3) == 4.3


def test_pyfunc_init() -> None:
def add(a: int, b: int) -> int:
return a + b

# Test creating from a Python callable
fadd = tvm_ffi.Function(add)
assert isinstance(fadd, tvm_ffi.Function)
assert fadd(1, 2) == 3

# Test creating from an existing tvm_ffi.Function
fadd2 = tvm_ffi.Function(fadd)
assert isinstance(fadd2, tvm_ffi.Function)
assert fadd2(3, 4) == 7
assert fadd.same_as(fadd2)

# Test creating from a moved-from function raises ValueError
f_source = tvm_ffi.Function(add)
f_dest = tvm_ffi.Function.__new__(tvm_ffi.Function)
f_dest.__move_handle_from__(f_source)
with pytest.raises(ValueError, match="Cannot initialize from a moved-from Function object"):
tvm_ffi.Function(f_source)

# Test creating without arguments (backward compat: null handle)
f_empty = tvm_ffi.Function()
assert isinstance(f_empty, tvm_ffi.Function)
assert f_empty.__chandle__() == 0

# Test creating from a non-callable raises TypeError
with pytest.raises(TypeError):
tvm_ffi.Function(123) # ty: ignore[invalid-argument-type]


def test_global_func() -> None:
@tvm_ffi.register_global_func("mytest.echo")
def echo(x: Any) -> Any:
Expand Down