diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h index 3f40bc59..0aa01579 100644 --- a/include/tvm/ffi/c_api.h +++ b/include/tvm/ffi/c_api.h @@ -913,6 +913,15 @@ typedef enum { * By default this flag is off (meaning the field accepts positional arguments). */ kTVMFFIFieldFlagBitMaskKwOnly = 1 << 10, + /*! + * \brief The setter field is a TVMFFIObjectHandle pointing to a FunctionObj. + * + * When this flag is set, the ``setter`` member of TVMFFIFieldInfo is not a + * TVMFFIFieldSetter function pointer but instead a TVMFFIObjectHandle + * pointing to a FunctionObj. The FunctionObj is called with two arguments: + * ``(field_addr_as_OpaquePtr, value_as_AnyView)``. + */ + kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11, #ifdef __cplusplus }; #else @@ -1008,9 +1017,15 @@ typedef struct { TVMFFIFieldGetter getter; /*! * \brief The setter to access the field. + * + * When kTVMFFIFieldFlagBitSetterIsFunctionObj is NOT set (default), + * this is a TVMFFIFieldSetter function pointer (cast to void*). + * When kTVMFFIFieldFlagBitSetterIsFunctionObj IS set, + * this is a TVMFFIObjectHandle pointing to a FunctionObj. + * * \note The setter is set even if the field is readonly for serialization. */ - TVMFFIFieldSetter setter; + void* setter; /*! * \brief The default value or default factory of the field. * diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h index 2ee1a0df..48c9d5b2 100644 --- a/include/tvm/ffi/function.h +++ b/include/tvm/ffi/function.h @@ -1066,6 +1066,66 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) { #else #define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString) #endif + +/*! + * \brief Create an empty object via the type's native creator or ``__ffi_new__`` type attr. + * + * Falls back to the ``__ffi_new__`` type attribute (used by Python-defined types) + * when the native ``metadata->creator`` is NULL. + * + * \param type_info The type info for the object to create. + * \return An owned ObjectPtr to the newly allocated (zero-initialized) object. + * \throws RuntimeError if neither creator nor __ffi_new__ is available. + */ +inline ObjectPtr CreateEmptyObject(const TVMFFITypeInfo* type_info) { + // Fast path: native C++ creator + if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) { + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); + return details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + } + // Fallback: __ffi_new__ type attr (Python-defined types) + constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11}; + const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName); + if (column != nullptr) { + int32_t offset = type_info->type_index - column->begin_index; + if (offset >= 0 && offset < column->size) { + AnyView attr_view = AnyView::CopyFromTVMFFIAny(column->data[offset]); + if (auto opt_func = attr_view.try_cast()) { + ObjectRef obj_ref = (*opt_func)().cast(); + return details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(obj_ref)); + } + } + } + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index) + << "` does not support reflection creation" + << " (no native creator or __ffi_new__ type attr)"; +} + +/*! + * \brief Check whether a type supports reflection creation. + * + * Returns true if the type has a native creator or a ``__ffi_new__`` type attr. + * + * \param type_info The type info to check. + * \return true if CreateEmptyObject would succeed. + */ +inline bool HasCreator(const TVMFFITypeInfo* type_info) { + if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) { + return true; + } + constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11}; + const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName); + if (column != nullptr) { + int32_t offset = type_info->type_index - column->begin_index; + if (offset >= 0 && offset < column->size && + column->data[offset].type_index >= kTVMFFIStaticObjectBegin) { + return true; + } + } + return false; +} + } // namespace ffi } // namespace tvm #endif // TVM_FFI_FUNCTION_H_ diff --git a/include/tvm/ffi/reflection/accessor.h b/include/tvm/ffi/reflection/accessor.h index 700f7b8c..b08cd6aa 100644 --- a/include/tvm/ffi/reflection/accessor.h +++ b/include/tvm/ffi/reflection/accessor.h @@ -52,6 +52,37 @@ inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const char TVM_FFI_UNREACHABLE(); } +/*! + * \brief Call the field setter, dispatching between function pointer and FunctionObj. + * + * When kTVMFFIFieldFlagBitSetterIsFunctionObj is off, invokes the setter as a + * TVMFFIFieldSetter function pointer. When on, calls via TVMFFIFunctionCall + * with arguments (field_addr as OpaquePtr, value). + * + * \param field_info The field info containing the setter. + * \param field_addr The address of the field in the object. + * \param value The value to set (as a TVMFFIAny pointer). + * \return 0 on success, nonzero on failure. + */ +inline int CallFieldSetter(const TVMFFIFieldInfo* field_info, void* field_addr, + const TVMFFIAny* value) { + if (!(field_info->flags & kTVMFFIFieldFlagBitSetterIsFunctionObj)) { + auto setter = reinterpret_cast(field_info->setter); + return setter(field_addr, value); + } else { + TVMFFIAny args[2]; + args[0].type_index = kTVMFFIOpaquePtr; + args[0].zero_padding = 0; + args[0].v_ptr = field_addr; + args[1] = *value; + TVMFFIAny result; + result.type_index = kTVMFFINone; + result.v_int64 = 0; + return TVMFFIFunctionCall(static_cast(field_info->setter), args, 2, + &result); + } +} + /*! * \brief helper wrapper class to obtain a getter. */ @@ -118,8 +149,8 @@ class FieldSetter { */ void operator()(const Object* obj_ptr, AnyView value) const { const void* addr = reinterpret_cast(obj_ptr) + field_info_->offset; - TVM_FFI_CHECK_SAFE_CALL( - field_info_->setter(const_cast(addr), reinterpret_cast(&value))); + TVM_FFI_CHECK_SAFE_CALL(CallFieldSetter(field_info_, const_cast(addr), + reinterpret_cast(&value))); } void operator()(const ObjectPtr& obj_ptr, AnyView value) const { @@ -215,9 +246,9 @@ inline void SetFieldToDefault(const TVMFFIFieldInfo* field_info, void* field_add Function factory = AnyView::CopyFromTVMFFIAny(field_info->default_value_or_factory).cast(); Any default_val = factory(); - field_info->setter(field_addr, reinterpret_cast(&default_val)); + CallFieldSetter(field_info, field_addr, reinterpret_cast(&default_val)); } else { - field_info->setter(field_addr, &(field_info->default_value_or_factory)); + CallFieldSetter(field_info, field_addr, &(field_info->default_value_or_factory)); } } diff --git a/include/tvm/ffi/reflection/creator.h b/include/tvm/ffi/reflection/creator.h index a7e860c1..6df175b6 100644 --- a/include/tvm/ffi/reflection/creator.h +++ b/include/tvm/ffi/reflection/creator.h @@ -48,13 +48,8 @@ class ObjectCreator { * \param type_info The type info. */ explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) { - int32_t type_index = type_info->type_index; - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not have reflection registered"; - } - if (type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + if (!HasCreator(type_info)) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index) << "` does not support default constructor, " << "as a result cannot be created via reflection"; } @@ -66,17 +61,14 @@ class ObjectCreator { * \return The created object. */ Any operator()(const Map& fields) const { - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + ObjectPtr ptr = CreateEmptyObject(type_info_); size_t match_field_count = 0; ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) { String field_name(field_info->name); void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; if (fields.count(field_name) != 0) { Any field_value = fields[field_name]; - field_info->setter(field_addr, reinterpret_cast(&field_value)); + CallFieldSetter(field_info, field_addr, reinterpret_cast(&field_value)); ++match_field_count; } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { SetFieldToDefault(field_info, field_addr); diff --git a/include/tvm/ffi/reflection/init.h b/include/tvm/ffi/reflection/init.h index 1a1aa21f..272d8071 100644 --- a/include/tvm/ffi/reflection/init.h +++ b/include/tvm/ffi/reflection/init.h @@ -25,6 +25,7 @@ #include #include +#include #include #include #include @@ -42,6 +43,20 @@ namespace tvm { namespace ffi { namespace reflection { +namespace details { + +template +TObjectRef CastFromAny(AnyView input) { + TVMFFIAny input_pod = input.CopyToTVMFFIAny(); + if (auto opt = TypeTraits::TryCastFromAnyView(&input_pod)) { + return *std::move(opt); + } + TVM_FFI_THROW(TypeError) << "Cannot cast from `" << TypeIndexToTypeKey(input_pod.type_index) + << "` to `" << TypeTraits::TypeStr() << "`"; +} + +} // namespace details + /*! * \brief Create a packed ``__ffi_init__`` constructor for the given type. * @@ -69,10 +84,8 @@ inline Function MakeInit(int32_t type_index) { }; // ---- Pre-compute field analysis (once per type) ------------------------- const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - TVM_FFI_ICHECK(type_info->metadata != nullptr) - << "Type `" << TypeIndexToTypeKey(type_index) << "` has no reflection metadata"; - TVM_FFI_ICHECK(type_info->metadata->creator != nullptr) - << "Type `" << TypeIndexToTypeKey(type_index) << "` has no creator"; + TVM_FFI_ICHECK(HasCreator(type_info)) << "Type `" << TypeIndexToTypeKey(type_index) + << "` has no creator or __ffi_new__ for __ffi_init__"; auto info = std::make_shared(); info->type_key = std::string_view(type_info->type_key.data, type_info->type_key.size); @@ -101,16 +114,11 @@ inline Function MakeInit(int32_t type_index) { // Eagerly resolve the KWARGS sentinel via global function registry. ObjectRef kwargs_sentinel = Function::GetGlobalRequired("ffi.GetKwargsObject")().cast(); - // Cache pointers for the lambda (avoid repeated lookups). - TVMFFIObjectCreator creator = type_info->metadata->creator; return Function::FromPacked( - [info, kwargs_sentinel, creator](PackedArgs args, Any* rv) { - // ---- 1. Create object via creator ------------------------------------ - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(creator(&handle)); - ObjectPtr obj_ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + [info, kwargs_sentinel, type_info](PackedArgs args, Any* rv) { + // ---- 1. Create object via CreateEmptyObject -------------------------- + ObjectPtr obj_ptr = CreateEmptyObject(type_info); // ---- 2. Find KWARGS sentinel position -------------------------------- int kwargs_pos = -1; @@ -128,7 +136,7 @@ inline Function MakeInit(int32_t type_index) { auto set_field = [&](size_t fi, const TVMFFIAny* value) { void* addr = reinterpret_cast(obj_ptr.get()) + info->all_fields[fi].info->offset; - TVM_FFI_CHECK_SAFE_CALL(info->all_fields[fi].info->setter(addr, value)); + TVM_FFI_CHECK_SAFE_CALL(CallFieldSetter(info->all_fields[fi].info, addr, value)); field_set[fi] = true; }; @@ -219,7 +227,7 @@ inline void RegisterAutoInit(int32_t type_index) { info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod; info.method = AnyView(auto_init_fn).CopyToTVMFFIAny(); static const std::string kMetadata = - "{\"type_schema\":" + std::string(details::TypeSchemaImpl::v()) + + "{\"type_schema\":" + std::string(::tvm::ffi::details::TypeSchemaImpl::v()) + ",\"auto_init\":true}"; info.metadata = TVMFFIByteArray{kMetadata.c_str(), kMetadata.size()}; TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index, &info)); diff --git a/include/tvm/ffi/reflection/overload.h b/include/tvm/ffi/reflection/overload.h index 6d0e783d..80e9bce9 100644 --- a/include/tvm/ffi/reflection/overload.h +++ b/include/tvm/ffi/reflection/overload.h @@ -404,14 +404,15 @@ class OverloadObjectDef : private ObjectDef { template static auto GetOverloadMethod(std::string name, Func&& func) { using WrapFn = decltype(WrapFunction(std::forward(func))); - using OverloadFn = details::OverloadedFunction>; + using OverloadFn = ::tvm::ffi::details::OverloadedFunction>; return ffi::Function::FromPackedInplace(WrapFunction(std::forward(func)), std::move(name)); } template static auto NewOverload(std::string name, Func&& func) { - return details::CreateNewOverload(WrapFunction(std::forward(func)), std::move(name)); + return ::tvm::ffi::details::CreateNewOverload(WrapFunction(std::forward(func)), + std::move(name)); } template @@ -448,11 +449,11 @@ class OverloadObjectDef : private ObjectDef { info.flags |= kTVMFFIFieldFlagBitMaskWritable; } info.getter = ReflectionDefBase::FieldGetter; - info.setter = ReflectionDefBase::FieldSetter; + info.setter = reinterpret_cast(ReflectionDefBase::FieldSetter); // initialize default value to nullptr info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny(); info.doc = TVMFFIByteArray{nullptr, 0}; - info.metadata_.emplace_back("type_schema", details::TypeSchema::v()); + info.metadata_.emplace_back("type_schema", ::tvm::ffi::details::TypeSchema::v()); // apply field info traits ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); // call register @@ -464,7 +465,7 @@ class OverloadObjectDef : private ObjectDef { // register a method template void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; + using FuncInfo = ::tvm::ffi::details::FunctionInfo>; MethodInfoBuilder info; info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; info.doc = TVMFFIByteArray{nullptr, 0}; @@ -478,7 +479,7 @@ class OverloadObjectDef : private ObjectDef { // if an overload method exists, register to existing overload function if (const auto overload_it = registered_fields_.find(name); overload_it != registered_fields_.end()) { - details::OverloadBase* overload_ptr = overload_it->second; + ::tvm::ffi::details::OverloadBase* overload_ptr = overload_it->second; return overload_ptr->Register(NewOverload(std::move(method_name), std::forward(func))); } @@ -496,7 +497,7 @@ class OverloadObjectDef : private ObjectDef { TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); } - std::unordered_map registered_fields_; + std::unordered_map registered_fields_; }; } // namespace reflection diff --git a/include/tvm/ffi/reflection/registry.h b/include/tvm/ffi/reflection/registry.h index 20760155..8fc40dc2 100644 --- a/include/tvm/ffi/reflection/registry.h +++ b/include/tvm/ffi/reflection/registry.h @@ -362,7 +362,8 @@ template TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) { int64_t field_offset_to_class = reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); - return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); + return field_offset_to_class - + ::tvm::ffi::details::ObjectUnsafe::GetObjectOffsetToSubclass(); } /// \cond Doxygen_Suppress @@ -371,7 +372,7 @@ class ReflectionDefBase { template static int FieldGetter(void* field, TVMFFIAny* result) { TVM_FFI_SAFE_CALL_BEGIN(); - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); + *result = ::tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); TVM_FFI_SAFE_CALL_END(); } @@ -390,7 +391,7 @@ class ReflectionDefBase { static int ObjectCreatorDefault(TVMFFIObjectHandle* result) { TVM_FFI_SAFE_CALL_BEGIN(); ObjectPtr obj = make_object(); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + *result = ::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); TVM_FFI_SAFE_CALL_END(); } @@ -398,7 +399,7 @@ class ReflectionDefBase { static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) { TVM_FFI_SAFE_CALL_BEGIN(); ObjectPtr obj = make_object(UnsafeInit{}); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + *result = ::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); TVM_FFI_SAFE_CALL_END(); } @@ -499,7 +500,7 @@ class GlobalDef : public ReflectionDefBase { */ template GlobalDef& def(const char* name, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; + using FuncInfo = ::tvm::ffi::details::FunctionInfo>; RegisterFunc(name, ffi::Function::FromTyped(std::forward(func), std::string(name)), FuncInfo::TypeSchema(), std::forward(extra)...); return *this; @@ -519,8 +520,8 @@ class GlobalDef : public ReflectionDefBase { */ template GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) { - RegisterFunc(name, ffi::Function::FromPacked(func), details::TypeSchemaImpl::v(), - std::forward(extra)...); + RegisterFunc(name, ffi::Function::FromPacked(func), + ::tvm::ffi::details::TypeSchemaImpl::v(), std::forward(extra)...); return *this; } @@ -540,7 +541,7 @@ class GlobalDef : public ReflectionDefBase { */ template GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; + using FuncInfo = ::tvm::ffi::details::FunctionInfo>; RegisterFunc(name, GetMethod(std::string(name), std::forward(func)), FuncInfo::TypeSchema(), std::forward(extra)...); return *this; @@ -692,8 +693,26 @@ inline constexpr const char* kHash = "__ffi_hash__"; inline constexpr const char* kEq = "__ffi_eq__"; /*! \brief Type attribute for custom recursive three-way comparison. */ inline constexpr const char* kCompare = "__ffi_compare__"; +/*! \brief Type attribute for converting AnyView to a specific reflected object type. */ +inline constexpr const char* kConvert = "__ffi_convert__"; } // namespace type_attr +/*! + * \brief Register the __ffi_convert__ type attribute for a reflected object type. + * \tparam TObjectRef The object reference type to register conversion for. + * \param type_index The runtime type index of the object. + * \param type_key The type key string of the object. + */ +template +inline void RegisterConvertTypeAttr(int32_t type_index, const char* type_key) { + TVMFFIByteArray attr_name = {type_attr::kConvert, + std::char_traits::length(type_attr::kConvert)}; + Function convert_func = ffi::Function::FromTyped( + &details::CastFromAny, std::string(type_key) + "." + type_attr::kConvert); + TVMFFIAny attr_value = AnyView(convert_func).CopyToTVMFFIAny(); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index, &attr_name, &attr_value)); +} + /*! * \brief Helper to register Object's reflection metadata. * \tparam Class The class type. @@ -814,6 +833,21 @@ class ObjectDef : public ReflectionDefBase { return *this; } + /*! + * \brief Register ``__ffi_convert__`` for the object reference wrapper. + * + * This enables typed ``AnyView -> TObjectRef`` conversion through + * ``TypeTraits::TryCastFromAnyView``. + */ + template + TVM_FFI_INLINE ObjectDef& ref() { + static_assert( + std::is_same_v, + "ObjectDef::ref() requires TObjectRef::ContainerType == Class"); + RegisterConvertTypeAttr(type_index_, type_key_); + return *this; + } + /*! * \brief Register a constructor for this object type. * @@ -911,11 +945,11 @@ class ObjectDef : public ReflectionDefBase { info.flags |= kTVMFFIFieldFlagBitMaskWritable; } info.getter = FieldGetter; - info.setter = FieldSetter; + info.setter = reinterpret_cast(FieldSetter); // initialize default value to nullptr info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny(); info.doc = TVMFFIByteArray{nullptr, 0}; - info.metadata_.emplace_back("type_schema", details::TypeSchema::v()); + info.metadata_.emplace_back("type_schema", ::tvm::ffi::details::TypeSchema::v()); // apply field info traits ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); // call register @@ -927,7 +961,7 @@ class ObjectDef : public ReflectionDefBase { // register a method template void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; + using FuncInfo = ::tvm::ffi::details::FunctionInfo>; MethodInfoBuilder info; info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; info.doc = TVMFFIByteArray{nullptr, 0}; diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py index 70221f9e..aa17624d 100644 --- a/python/tvm_ffi/__init__.py +++ b/python/tvm_ffi/__init__.py @@ -66,7 +66,7 @@ def _is_config_mode() -> bool: init_ffi_api, ) from ._dtype import dtype - from .core import Object, ObjectConvertible, Function + from .core import Object, ObjectConvertible, Function, CAny from ._convert import convert from .error import register_error from ._tensor import Device, device, DLDeviceType diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi index d39d83c1..5836f24a 100644 --- a/python/tvm_ffi/core.pyi +++ b/python/tvm_ffi/core.pyi @@ -209,6 +209,8 @@ def _get_global_func(name: str, allow_missing: bool) -> Function | None: ... def _convert_to_ffi_func(pyfunc: Callable[..., Any]) -> Function: ... def _convert_to_opaque_object(pyobject: Any) -> OpaquePyObject: ... def _print_debug_info() -> None: ... +def _register_py_class(parent_type_info: TypeInfo, type_key: str, type_cls: type) -> TypeInfo: ... +def _register_fields(type_info: TypeInfo, fields: list[Any]) -> list[TypeField]: ... class String(str, PyNativeObject): __slots__ = ["_tvm_ffi_cached_object"] @@ -228,6 +230,16 @@ class Bytes(bytes, PyNativeObject): # pylint: disable=no-self-argument def __from_tvm_ffi_object__(cls, obj: Any) -> Bytes: ... +# --------------------------------------------------------------------------- +# Owned FFI value container (from cython/object.pxi) +# --------------------------------------------------------------------------- + +class CAny: + def __init__(self, value: Any = None) -> None: ... + @property + def type_index(self) -> int: ... + def to_py(self) -> Any: ... + # --------------------------------------------------------------------------- # Type reflection metadata (from cython/type_info.pxi) # --------------------------------------------------------------------------- @@ -235,13 +247,26 @@ class Bytes(bytes, PyNativeObject): class TypeSchema: origin: str args: tuple[TypeSchema, ...] = () + origin_type_index: int - def __init__(self, origin: str, args: tuple[TypeSchema, ...] = ()) -> None: ... + def __init__( + self, + origin: str, + args: tuple[TypeSchema, ...] = (), + origin_type_index: int = ..., + ) -> None: ... @staticmethod def from_json_obj(obj: dict[str, Any]) -> TypeSchema: ... @staticmethod def from_json_str(s: str) -> TypeSchema: ... + @staticmethod + def from_type_index(type_index: int, args: tuple[TypeSchema, ...] = ()) -> TypeSchema: ... + @staticmethod + def from_annotation(annotation: object) -> TypeSchema: ... def repr(self, ty_map: Callable[[str], str] | None = None) -> str: ... + def check_value(self, value: object) -> None: ... + def convert(self, value: object) -> CAny: ... + def to_json(self) -> dict[str, Any]: ... class TypeField: name: str @@ -252,6 +277,7 @@ class TypeField: metadata: dict[str, Any] getter: Any setter: Any + ty: TypeSchema | None c_init: bool c_kw_only: bool c_has_default: bool @@ -277,4 +303,5 @@ class TypeInfo: methods: list[TypeMethod] parent_type_info: TypeInfo | None + def _register_fields(self, fields: list[Any]) -> None: ... def prototype_py(self) -> str: ... diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi index 15fc053e..007aa0a9 100644 --- a/python/tvm_ffi/cython/base.pxi +++ b/python/tvm_ffi/cython/base.pxi @@ -195,6 +195,8 @@ cdef extern from "tvm/ffi/c_api.h": void (*update_backtrace)( TVMFFIObjectHandle self, const TVMFFIByteArray* backtrace, int32_t update_mode ) + TVMFFIObjectHandle cause_chain + TVMFFIObjectHandle extra_context ctypedef int (*TVMFFISafeCallType)( void* handle, const TVMFFIAny* args, int32_t num_args, @@ -204,9 +206,15 @@ cdef extern from "tvm/ffi/c_api.h": kTVMFFIFieldFlagBitMaskWritable = 1 << 0 kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1 kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2 + kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3 + kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4 kTVMFFIFieldFlagBitMaskDefaultFromFactory = 1 << 5 + kTVMFFIFieldFlagBitMaskReprOff = 1 << 6 + kTVMFFIFieldFlagBitMaskCompareOff = 1 << 7 + kTVMFFIFieldFlagBitMaskHashOff = 1 << 8 kTVMFFIFieldFlagBitMaskInitOff = 1 << 9 kTVMFFIFieldFlagBitMaskKwOnly = 1 << 10 + kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11 ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept @@ -221,7 +229,7 @@ cdef extern from "tvm/ffi/c_api.h": int64_t alignment int64_t offset TVMFFIFieldGetter getter - TVMFFIFieldSetter setter + void* setter TVMFFIAny default_value_or_factory int32_t field_static_type_index @@ -232,10 +240,19 @@ cdef extern from "tvm/ffi/c_api.h": int64_t flags TVMFFIAny method + cdef enum TVMFFISEqHashKind: + kTVMFFISEqHashKindUnsupported = 0 + kTVMFFISEqHashKindTreeNode = 1 + kTVMFFISEqHashKindFreeVar = 2 + kTVMFFISEqHashKindDAGNode = 3 + kTVMFFISEqHashKindConstTreeNode = 4 + kTVMFFISEqHashKindUniqueInstance = 5 + ctypedef struct TVMFFITypeMetadata: TVMFFIByteArray doc TVMFFIObjectCreator creator - int64_t total_size + int32_t total_size + TVMFFISEqHashKind structural_eq_hash_kind ctypedef struct TVMFFITypeInfo: int32_t type_index @@ -297,6 +314,20 @@ cdef extern from "tvm/ffi/c_api.h": DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* attr_name) nogil + int32_t TVMFFITypeGetOrAllocIndex( + const TVMFFIByteArray* type_key, + int32_t static_type_index, + int32_t type_depth, + int32_t num_child_slots, + int32_t child_slots_can_overflow, + int32_t parent_type_index + ) nogil + int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info) nogil + int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata) nogil + int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* attr_name, + const TVMFFIAny* attr_value) nogil + void TVMFFIErrorSetRaisedFromCStr(const char* kind, const char* message) nogil + cdef extern from "tvm/ffi/extra/c_env_api.h": ctypedef void* TVMFFIStreamHandle @@ -357,7 +388,8 @@ cdef extern from "tvm_ffi_python_helpers.h": int TVMFFIPyCallFieldSetter( TVMFFIPyArgSetterFactory setter_factory, - TVMFFIFieldSetter field_setter, + void* field_setter, + int64_t field_flags, void* field_ptr, PyObject* py_arg, int* c_api_ret_code diff --git a/python/tvm_ffi/cython/core.pyx b/python/tvm_ffi/cython/core.pyx index d755da63..76bc1ceb 100644 --- a/python/tvm_ffi/cython/core.pyx +++ b/python/tvm_ffi/cython/core.pyx @@ -38,6 +38,7 @@ include "./tensor.pxi" _register_object_by_index(kTVMFFITensor, Tensor) include "./function.pxi" _register_object_by_index(kTVMFFIFunction, Function) +include "./type_converter.pxi" # Global invalid/missing object singleton MISSING = _get_global_func("ffi.GetInvalidObject", False)() diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi index b02c95e0..0ac31116 100644 --- a/python/tvm_ffi/cython/object.pxi +++ b/python/tvm_ffi/cython/object.pxi @@ -488,6 +488,7 @@ cdef _type_info_create_from_type_key(object type_cls, str type_key): setter = FieldSetter.__new__(FieldSetter) (setter).setter = field.setter (setter).offset = field.offset + (setter).flags = field.flags metadata_obj = json.loads(bytearray_to_str(&field.metadata)) if field.metadata.size != 0 else {} fields.append( TypeField( @@ -572,6 +573,94 @@ def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo: return info +def _register_py_class(parent_type_info, str type_key, object type_cls): + """Register a new Python-defined TVM-FFI type. + + Allocates a dynamic type index for *type_key* as a child of + *parent_type_info* and registers it in the global type tables. + + Parameters + ---------- + parent_type_info : TypeInfo + The parent type's TypeInfo (e.g., Object's TypeInfo). + type_key : str + The unique type key string for the new type. + type_cls : type + The Python class to associate with this type. + + Returns + ------- + TypeInfo + The newly created TypeInfo with ``fields=None`` (pending registration). + + Raises + ------ + ValueError + If *type_key* is already registered. + """ + # Reject duplicate type keys + if type_key in TYPE_KEY_TO_INFO: + raise ValueError( + f"Type key '{type_key}' is already registered" + ) + + cdef int32_t parent_type_index = parent_type_info.type_index + cdef int32_t parent_type_depth = len(parent_type_info.type_ancestors) + cdef int32_t type_depth = parent_type_depth + 1 + cdef ByteArrayArg type_key_arg = ByteArrayArg(c_str(type_key)) + cdef int32_t type_index + + # Allocate a new type index + # static_type_index=-1 means dynamic allocation + # num_child_slots=0, child_slots_can_overflow=1 + type_index = TVMFFITypeGetOrAllocIndex( + type_key_arg.cptr(), + -1, # static_type_index (dynamic) + type_depth, + 0, # num_child_slots + 1, # child_slots_can_overflow + parent_type_index, + ) + + # Build ancestors list + cdef list ancestors = list(parent_type_info.type_ancestors) + ancestors.append(parent_type_index) + + # Create TypeInfo with fields=None (pending _register_fields call) + cdef object info = TypeInfo( + type_cls=type_cls, + type_index=type_index, + type_key=type_key, + type_ancestors=ancestors, + fields=None, + methods=[], + parent_type_info=parent_type_info, + ) + + _update_registry(type_index, type_key, info, type_cls) + return info + + +def _rollback_py_class(object type_info): + """Roll back a ``_register_py_class`` call from the Python-level registry. + + Called by ``@py_class`` when phase-2 (field validation) fails, so + the type key can be reused after the user fixes the error. The + C-level type index is permanently consumed (cannot be reclaimed), + but the Python dicts are cleaned up so that a retry does not hit + "already registered". + """ + cdef int32_t idx = type_info.type_index + cdef str key = type_info.type_key + cdef object cls = type_info.type_cls + TYPE_KEY_TO_INFO.pop(key, None) + if cls is not None: + TYPE_CLS_TO_INFO.pop(cls, None) + if 0 <= idx < len(TYPE_INDEX_TO_INFO): + TYPE_INDEX_TO_INFO[idx] = None + TYPE_INDEX_TO_CLS[idx] = None + + def _lookup_type_attr(type_index: int32_t, attr_key: str) -> Any: cdef ByteArrayArg attr_key_bytes = ByteArrayArg(c_str(attr_key)) cdef const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&attr_key_bytes.cdata) @@ -582,7 +671,8 @@ def _lookup_type_attr(type_index: int32_t, attr_key: str) -> Any: offset = type_index - column.begin_index if offset < 0 or offset >= column.size: return None - return make_ret(column.data[offset]) + CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(column.data[offset]), &data)) + return make_ret(data) def _type_cls_to_type_info(type_cls: type) -> TypeInfo | None: @@ -595,3 +685,114 @@ cdef dict TYPE_CLS_TO_INFO = {} cdef dict TYPE_KEY_TO_INFO = {} _set_class_object(Object) + + +# --------------------------------------------------------------------------- +# CAny: Owned TVMFFIAny value container +# --------------------------------------------------------------------------- +cdef class CAny: + """Owned :c:type:`TVMFFIAny` value container. + + Holds sole ownership of the underlying value. For object types + (``type_index >= kTVMFFIStaticObjectBegin``), the reference is + properly ref-counted and released in ``__dealloc__``. + + Use :meth:`to_py` to recover the Python object. + """ + + cdef TVMFFIAny cdata + + def __cinit__(self): + """Initialize the contained value to ``None``.""" + self.cdata.type_index = kTVMFFINone + self.cdata.v_int64 = 0 + + def __init__(self, value=None): + """Pack a Python value into an owned :c:type:`TVMFFIAny`. + + Uses ``TVMFFIPyPyObjectToFFIAny`` to produce a non-owning AnyView, + then ``TVMFFIAnyViewToOwnedAny`` to convert to an owned Any. + + Parameters + ---------- + value : object, optional + The Python value to pack. When ``None`` (the default), the + container stays in the ``kTVMFFINone`` state set by ``__cinit__``. + """ + if value is None: + return + cdef TVMFFIAny temp + cdef int c_api_ret_code + temp.type_index = kTVMFFINone + temp.v_int64 = 0 + TVMFFIPyPyObjectToFFIAny( + TVMFFIPyArgSetterFactory_, + value, + &temp, + &c_api_ret_code + ) + CHECK_CALL(c_api_ret_code) + CHECK_CALL(TVMFFIAnyViewToOwnedAny(&temp, &self.cdata)) + + def __dealloc__(self): + """Release owned object reference, if any.""" + if self.cdata.type_index >= kTVMFFIStaticObjectBegin: + if self.cdata.v_obj != NULL: + CHECK_CALL(TVMFFIObjectDecRef(self.cdata.v_obj)) + self.cdata.v_obj = NULL + + @property + def type_index(self) -> int: + """The TVM FFI type index of the contained value.""" + return self.cdata.type_index + + def to_py(self) -> object: + """Convert the contained value to a Python object. + + For non-object types (int, float, bool, None, etc.), returns + the corresponding Python scalar. For object types + (``type_index >= kTVMFFIStaticObjectBegin``), returns the + appropriate :class:`CObject` subclass (e.g. ``Array``, ``List``, + ``String``, ``Function``). + + Short strings and bytes (stored inline as ``SmallStr``/``SmallBytes``) + are promoted to :class:`String` / :class:`Bytes` to guarantee + that every value in the TVM FFI type system round-trips as its + canonical FFI type. + + Safe to call multiple times — each call produces an + independent Python object with its own reference. + + Returns + ------- + object + The Python representation of the contained value. + """ + cdef TVMFFIAny copy = self.cdata + if copy.type_index >= kTVMFFIStaticObjectBegin: + if copy.v_obj != NULL: + TVMFFIObjectIncRef(copy.v_obj) + cdef object result = make_ret(copy) + # Promote inline SmallStr/SmallBytes to their FFI wrapper types + # so that convert().to_py() always yields tvm_ffi.String / tvm_ffi.Bytes. + if copy.type_index == kTVMFFISmallStr: + return String(result) + if copy.type_index == kTVMFFISmallBytes: + return Bytes(result) + return result + + def __repr__(self) -> str: + """Return a developer-friendly representation.""" + cdef int32_t ti = self.cdata.type_index + if ti == kTVMFFINone: + return "CAny(None)" + elif ti == kTVMFFIInt: + return f"CAny(int={self.cdata.v_int64})" + elif ti == kTVMFFIFloat: + return f"CAny(float={self.cdata.v_float64})" + elif ti == kTVMFFIBool: + return f"CAny(bool={bool(self.cdata.v_int64)})" + elif ti >= kTVMFFIStaticObjectBegin: + return f"CAny(object, type_index={ti})" + else: + return f"CAny(type_index={ti})" diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h index 666bb6e9..281863a9 100644 --- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h +++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h @@ -398,14 +398,28 @@ class TVMFFIPyCallManager { } } - TVM_FFI_INLINE int SetField(TVMFFIPyArgSetterFactory setter_factory, - TVMFFIFieldSetter field_setter, void* field_ptr, PyObject* py_arg, + TVM_FFI_INLINE int SetField(TVMFFIPyArgSetterFactory setter_factory, void* field_setter, + int64_t field_flags, void* field_ptr, PyObject* py_arg, int* c_api_ret_code) { try { TVMFFIPyCallContext ctx(&call_stack_, 1); TVMFFIAny* c_arg = ctx.packed_args; if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; - c_api_ret_code[0] = (*field_setter)(field_ptr, c_arg); + if (!(field_flags & kTVMFFIFieldFlagBitSetterIsFunctionObj)) { + auto setter = reinterpret_cast(field_setter); + c_api_ret_code[0] = (*setter)(field_ptr, c_arg); + } else { + TVMFFIAny args[2]; + args[0].type_index = kTVMFFIOpaquePtr; + args[0].zero_padding = 0; + args[0].v_ptr = field_ptr; + args[1] = *c_arg; + TVMFFIAny result; + result.type_index = kTVMFFINone; + result.v_int64 = 0; + c_api_ret_code[0] = + TVMFFIFunctionCall(static_cast(field_setter), args, 2, &result); + } return 0; } catch (const std::exception& ex) { // very rare, catch c++ exception and set python error @@ -534,17 +548,18 @@ TVM_FFI_INLINE int TVMFFIPyConstructorCall(TVMFFIPyArgSetterFactory setter_facto /*! * \brief Set a field of a FFI object * \param setter_factory The factory function to create the setter - * \param field_setter The field setter function + * \param field_setter The field setter (function pointer or FunctionObj handle) + * \param field_flags The field flags (to dispatch between function pointer and FunctionObj) * \param field_ptr The pointer to the field * \param py_arg The python argument to be set * \param c_api_ret_code The return code of the function * \return 0 on success, nonzero on failure */ TVM_FFI_INLINE int TVMFFIPyCallFieldSetter(TVMFFIPyArgSetterFactory setter_factory, - TVMFFIFieldSetter field_setter, void* field_ptr, + void* field_setter, int64_t field_flags, void* field_ptr, PyObject* py_arg, int* c_api_ret_code) { - return TVMFFIPyCallManager::ThreadLocal()->SetField(setter_factory, field_setter, field_ptr, - py_arg, c_api_ret_code); + return TVMFFIPyCallManager::ThreadLocal()->SetField(setter_factory, field_setter, field_flags, + field_ptr, py_arg, c_api_ret_code); } /*! diff --git a/python/tvm_ffi/cython/type_converter.pxi b/python/tvm_ffi/cython/type_converter.pxi new file mode 100644 index 00000000..a47d35a2 --- /dev/null +++ b/python/tvm_ffi/cython/type_converter.pxi @@ -0,0 +1,806 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Type converter implementation for TypeSchema. + +Provides the ``_type_convert_impl`` function used by +``TypeSchema.convert`` and ``TypeSchema.check_value``. + +Each ``_TypeConverter`` dispatches directly to a Cython cdef function that +returns a fully materialized :class:`CAny`. Container converters recurse into +their child schemas, rebuild the target FFI container shape, and then wrap the +final result in ``CAny``. +""" + +import ctypes +import os +from numbers import Integral, Real +from collections.abc import Mapping + + +cdef object _INT64_MIN = -(1 << 63) +cdef object _INT64_MAX = (1 << 63) - 1 +cdef int _VALUE_PROTOCOL_MAX_DEPTH = 64 +cdef str _TYPE_ATTR_FFI_CONVERT = "__ffi_convert__" +cdef class _TypeConverter +ctypedef CAny (*_dispatch_fn_t)(_TypeConverter, object, bint*) except * + + +# --------------------------------------------------------------------------- +# cdef class _TypeConverter — holds dispatch state as C-level struct fields +# --------------------------------------------------------------------------- +cdef class _TypeConverter: + """Pre-built converter holding a C function pointer and sub-converters.""" + + cdef _dispatch_fn_t dispatch + cdef int32_t type_index + cdef tuple subs + cdef str err_hint + cdef Function _fn_convert + cdef bint _fn_convert_ready + + @property + def fn_convert(self): + cdef object attr + assert self.type_index >= kTVMFFIStaticObjectBegin + if not self._fn_convert_ready: + attr = _lookup_type_attr(self.type_index, _TYPE_ATTR_FFI_CONVERT) + if attr is not None: + assert isinstance(attr, Function) + self._fn_convert = attr + else: + self._fn_convert = None + self._fn_convert_ready = True + return self._fn_convert + + +class _ConvertError(Exception): + """Internal exception used to signal conversion failure.""" + + __slots__ = () + + def __init__(self, message): + super().__init__(message) + + @property + def message(self): + return self.args[0] + + +# --------------------------------------------------------------------------- +# Converters (1/N): Simple value converters +# --------------------------------------------------------------------------- + +cdef CAny _tc_convert_any(_TypeConverter _conv, object value, bint* changed) except *: + """Any: accept any marshalable FFI value.""" + cdef CAny packed + assert _CLASS_DEVICE is not None + assert _CLASS_DTYPE is not None + packed = CAnyChecked(value, "Any", value) + if not isinstance( + value, + ( + type(None), + bool, + int, + float, + ctypes.c_void_p, + String, + Bytes, + Tensor, + DataType, + CObject, + _CLASS_DEVICE, + _CLASS_DTYPE, + ), + ): + changed[0] = True + return packed + + +cdef CAny _tc_convert_none(_TypeConverter _conv, object value, bint* changed) except *: + """None accepts: None only.""" + if value is None: + return CAny(None) + raise _ConvertError(f"expected None, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_int(_TypeConverter _conv, object value, bint* changed) except *: + """int accepts: int, bool, Integral, __tvm_ffi_int__ protocol.""" + cdef object ivalue + if isinstance(value, bool): + changed[0] = True + return CAny(int(value)) + if isinstance(value, int): + if not (_INT64_MIN <= value <= _INT64_MAX): + raise _ConvertError( + f"integer {value} out of int64 range [{_INT64_MIN}, {_INT64_MAX}]" + ) + return CAny(value) + if isinstance(value, Integral): + try: + ivalue = int(value) + except Exception as err: + raise _ConvertError(f"int() failed for {type(value).__qualname__}: {err}") from None + if not (_INT64_MIN <= ivalue <= _INT64_MAX): + raise _ConvertError( + f"integer {ivalue} out of int64 range [{_INT64_MIN}, {_INT64_MAX}]" + ) + changed[0] = True + return CAny(ivalue) + if hasattr(type(value), "__tvm_ffi_int__"): + changed[0] = True + return CAnyChecked(value, "int", value) + raise _ConvertError(f"expected int, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_float(_TypeConverter _conv, object value, bint* changed) except *: + """float accepts: float, int, bool, Real, __tvm_ffi_float__ protocol.""" + cdef object fvalue + if isinstance(value, float): + return CAny(value) + if isinstance(value, (int, bool)): + changed[0] = True + return CAny(float(value)) + if isinstance(value, (Integral, Real)): + try: + fvalue = float(value) + except Exception as err: + raise _ConvertError(f"float() failed for {type(value).__qualname__}: {err}") from None + changed[0] = True + return CAny(fvalue) + if hasattr(type(value), "__tvm_ffi_float__"): + changed[0] = True + return CAnyChecked(value, "float", value) + raise _ConvertError(f"expected float, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_bool(_TypeConverter _conv, object value, bint* changed) except *: + """bool accepts: bool, int, Integral.""" + cdef object bvalue + if isinstance(value, bool): + return CAny(value) + if isinstance(value, Integral): # TODO: do we coerce Integral to bool? + try: + bvalue = bool(value) + except Exception as err: + raise _ConvertError(f"bool() failed for {type(value).__qualname__}: {err}") from None + changed[0] = True + return CAny(bvalue) + raise _ConvertError(f"expected bool, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_str(_TypeConverter _conv, object value, bint* changed) except *: + """str accepts: str only.""" + if isinstance(value, str): + if not isinstance(value, String): + value = String(value) + changed[0] = True + return CAny(value) + raise _ConvertError(f"expected str, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_bytes(_TypeConverter _conv, object value, bint* changed) except *: + """bytes accepts: bytes, bytearray.""" + if isinstance(value, Bytes): + return CAny(value) + if isinstance(value, bytes): + changed[0] = True + return CAny(Bytes(value)) + if isinstance(value, bytearray): + changed[0] = True + return CAny(Bytes(bytes(value))) + raise _ConvertError(f"expected bytes, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_device(_TypeConverter _conv, object value, bint* changed) except *: + """Device accepts: Device and __dlpack_device__ without __dlpack__.""" + cdef object vtype = type(value) + assert _CLASS_DEVICE is not None + if isinstance(value, _CLASS_DEVICE): + return CAny(value) + if hasattr(vtype, "__dlpack_device__") and not hasattr(vtype, "__dlpack__"): + changed[0] = True + return CAnyChecked(value, "Device", value) + raise _ConvertError(f"expected Device, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_dtype(_TypeConverter _conv, object value, bint* changed) except *: + """dtype accepts: DataType, dtype wrapper, str and dtype protocols.""" + cdef object dtype_value + assert _CLASS_DTYPE is not None + if isinstance(value, (DataType, _CLASS_DTYPE)): + return CAny(value) + if isinstance(value, str): + try: + dtype_value = DataType(value) + except Exception: + raise _ConvertError(f"expected dtype, got invalid dtype string {value!r}") from None + changed[0] = True + return CAny(dtype_value) + if ( + (torch is not None and isinstance(value, torch.dtype)) + or (numpy is not None and isinstance(value, numpy.dtype)) + or hasattr(value, "__dlpack_data_type__") + ): + changed[0] = True + return CAnyChecked(value, "dtype", value) + raise _ConvertError(f"expected dtype, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_opaque_ptr(_TypeConverter _conv, object value, bint* changed) except *: + """ctypes.c_void_p accepts ctypes.c_void_p, None and opaque pointer protocols.""" + cdef object vtype = type(value) + if isinstance(value, ctypes.c_void_p): + return CAny(value) + # TODO: noticed that `OpaquePtr(nullptr) != None` - need to double check if this is correct + if value is None: + changed[0] = True + return CAny(ctypes.c_void_p(None)) + if hasattr(vtype, "__tvm_ffi_opaque_ptr__") or hasattr(vtype, "__cuda_stream__"): + changed[0] = True + return CAnyChecked(value, "ctypes.c_void_p", value) + raise _ConvertError(f"expected ctypes.c_void_p, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_tensor(_TypeConverter _conv, object value, bint* changed) except *: + """Tensor accepts Tensor, Tensor subtypes and DLPack exporters.""" + cdef object vtype = type(value) + if isinstance(value, Tensor): + return CAny(value) + if hasattr(vtype, "__dlpack__"): + changed[0] = True + return CAnyChecked(value, "Tensor", value) + if os.environ.get("TVM_FFI_SKIP_DLPACK_C_EXCHANGE_API", "0") != "1": + if hasattr(vtype, "__dlpack_c_exchange_api__"): + changed[0] = True + return CAnyChecked(value, "Tensor", value) + raise _ConvertError(f"expected Tensor, got {_tc_describe_value_type(value)}") + + +cdef CAny _tc_convert_callable(_TypeConverter _conv, object value, bint* changed) except *: + """Callable accepts Function and any Python callable.""" + cdef Function func + if isinstance(value, Function): + return CAny(value) + if callable(value): + if isinstance(value, CObject): + raise _ConvertError(f"expected Callable, got {_tc_describe_value_type(value)}") + changed[0] = True + return CAnyChecked(value, "Callable", value) + raise _ConvertError(f"expected Callable, got {_tc_describe_value_type(value)}") + + +# --------------------------------------------------------------------------- +# Converters (2/N): Sequence/Mapping Containers +# --------------------------------------------------------------------------- + + +cdef CAny _tc_convert_array(_TypeConverter conv, object value, bint* changed) except *: + """Dispatch for Array[T]. Accepts Array or List CObjects (cross-type).""" + from tvm_ffi.container import Array + + return _tc_convert_seq(conv, value, changed, Array) + + +cdef CAny _tc_convert_list(_TypeConverter conv, object value, bint* changed) except *: + """Dispatch for List[T]. Accepts List or Array CObjects (cross-type).""" + from tvm_ffi.container import List + + return _tc_convert_seq(conv, value, changed, List) + + +cdef CAny _tc_convert_map(_TypeConverter conv, object value, bint* changed) except *: + """Dispatch for Map[K, V]. Accepts Map or Dict CObjects (cross-type).""" + from tvm_ffi import _ffi_api + from tvm_ffi.container import Map + + return _tc_convert_mapping(conv, value, changed, Map, _ffi_api.Map) + + +cdef CAny _tc_convert_dict(_TypeConverter conv, object value, bint* changed) except *: + """Dispatch for Dict[K, V]. Accepts Dict or Map CObjects (cross-type).""" + from tvm_ffi import _ffi_api + from tvm_ffi.container import Dict + + return _tc_convert_mapping(conv, value, changed, Dict, _ffi_api.Dict) + + +cdef CAny _tc_convert_seq(_TypeConverter conv, object value, bint* changed, object seq_type) except *: + from tvm_ffi.container import Array, List + + cdef _TypeConverter elem_conv = conv.subs[0] if conv.subs else None + + if not isinstance(value, (Array, List, list, tuple)): + raise _ConvertError(f"expected {seq_type.__name__}, got {_tc_describe_value_type(value)}") + + if elem_conv is None and isinstance(value, seq_type): + return CAny(value) + + cdef list converted = [] + cdef bint item_changed + cdef object raw_item + cdef int i = 0 + cdef CAny item + for raw_item in value: + if elem_conv is not None: + item_changed = False + try: + item = _type_convert_dispatch_with_fallback(elem_conv, raw_item, &item_changed) + except _ConvertError as err: + raise _ConvertError(f"element [{i}]: {err.message}") from None + if item_changed: + changed[0] = True + converted.append(item.to_py()) + else: + converted.append(raw_item) + else: + converted.append(raw_item) + i += 1 + if isinstance(value, seq_type) and not changed[0]: + return CAny(value) + changed[0] = True + return CAnyChecked(seq_type(converted), seq_type.__name__, value) + + +cdef CAny _tc_convert_tuple(_TypeConverter conv, object value, bint* changed) except *: + """Dispatch for tuple[T1, T2, ...] or bare tuple.""" + cdef int i + cdef int n + cdef list converted = [] + cdef CAny item + cdef bint item_changed + cdef object raw_item + + from tvm_ffi.container import Array, List + + if not isinstance(value, (Array, List, list, tuple)): + raise _ConvertError(f"expected tuple, got {_tc_describe_value_type(value)}") + + if conv.subs is None: + if isinstance(value, Array): + return CAny(value) + changed[0] = True + return CAnyChecked(Array(value), "tuple", value) + + n = len(conv.subs) + if len(value) != n: + raise _ConvertError( + f"expected tuple of length {n}, got {type(value).__name__} of length {len(value)}" + ) + + for i in range(n): + raw_item = value[i] + item_changed = False + try: + item = _type_convert_dispatch_with_fallback( + <_TypeConverter>(conv.subs[i]), + raw_item, + &item_changed, + ) + except _ConvertError as err: + raise _ConvertError(f"element [{i}]: {err.message}") from None + if item_changed: + changed[0] = True + converted.append(item.to_py()) + else: + converted.append(raw_item) + if isinstance(value, Array) and not changed[0]: + return CAny(value) + changed[0] = True + return CAnyChecked(Array(converted), "tuple", value) + + +cdef CAny _tc_convert_mapping( + _TypeConverter conv, + object value, + bint* changed, + object mapping_type, + object constructor, +) except *: + cdef _TypeConverter key_conv = conv.subs[0] if conv.subs else None + cdef _TypeConverter val_conv = conv.subs[1] if conv.subs else None + cdef list list_kvs = [] + cdef CAny item + cdef bint item_changed + cdef object raw_key + cdef object raw_val + cdef object new_key + cdef object new_val + cdef object mapping + cdef str expected = mapping_type.__name__ + + if not isinstance(value, Mapping): + raise _ConvertError(f"expected {expected}, got {_tc_describe_value_type(value)}") + + if key_conv is None and val_conv is None and isinstance(value, mapping_type): + return CAny(value) + + for raw_key, raw_val in value.items(): + new_key = raw_key + if key_conv is not None: + item_changed = False + try: + item = _type_convert_dispatch_with_fallback(key_conv, raw_key, &item_changed) + except _ConvertError as err: + raise _ConvertError(f"key {raw_key!r}: {err.message}") from None + if item_changed: + changed[0] = True + new_key = item.to_py() + new_val = raw_val + if val_conv is not None: + item_changed = False + try: + item = _type_convert_dispatch_with_fallback(val_conv, raw_val, &item_changed) + except _ConvertError as err: + raise _ConvertError(f"value for key {raw_key!r}: {err.message}") from None + if item_changed: + changed[0] = True + new_val = item.to_py() + list_kvs.append(new_key) + list_kvs.append(new_val) + if isinstance(value, mapping_type) and not changed[0]: + return CAny(value) + changed[0] = True + try: + mapping = constructor(*list_kvs) + except _ConvertError: + raise + except Exception as err: + raise _ConvertError( + f"expected {expected}, got {_tc_describe_value_type(value)}: {err}" + ) from None + return CAnyChecked(mapping, expected, value) + + +# --------------------------------------------------------------------------- +# Converters (3/N): Optional and Union +# --------------------------------------------------------------------------- + +cdef CAny _tc_convert_optional(_TypeConverter conv, object value, bint* changed) except *: + """Dispatch for Optional[T]: None passthrough or inner dispatch.""" + if value is None: + return CAny(None) + return _type_convert_dispatch_with_fallback(<_TypeConverter>(conv.subs[0]), value, changed) + + +cdef CAny _tc_convert_union(_TypeConverter conv, object value, bint* changed) except *: + """Dispatch for Union[T1, T2, ...].""" + cdef _TypeConverter alt + cdef bint alt_changed + cdef CAny result + for alt_obj in conv.subs: + alt = <_TypeConverter>alt_obj + try: + alt_changed = False + result = alt.dispatch(alt, value, &alt_changed) + changed[0] = alt_changed + return result + except _ConvertError: + pass + raise _ConvertError(f"expected {conv.err_hint}, got {_tc_describe_value_type(value)}") + + +# --------------------------------------------------------------------------- +# Converters (4/N): Object Types +# --------------------------------------------------------------------------- + +cdef CAny _tc_convert_object(_TypeConverter conv, object value, bint* changed) except *: + """Convert *value* to an object compatible with ``conv.type_index``.""" + # TODO: SmallStr and SmallBytes => ObjectRef conversion is not supported yet + cdef int32_t actual_type_index = kTVMFFINone + cdef CAny packed + cdef CAny converted + cdef Function fn_convert + cdef object err = None + + # Step 1: existing FFI objects that already satisfy the target schema are passthrough. + assert conv.type_index >= kTVMFFIStaticObjectBegin + if isinstance(value, CObject): + actual_type_index = TVMFFIObjectGetTypeIndex((value).chandle) + if _tc_type_index_is_instance(actual_type_index, conv.type_index): + return CAny(value) + changed[0] = True + + # Step 2: pack, and convert to the target type. + packed = CAnyChecked(value, conv.err_hint, value) + fn_convert = conv.fn_convert + try: + if fn_convert is not None: + converted = CAny.__new__(CAny) + CHECK_CALL(TVMFFIFunctionCall(fn_convert.chandle, &packed.cdata, 1, &converted.cdata)) + else: + converted = packed + except Exception as err_: + err = err_ + else: + actual_type_index = converted.type_index + if actual_type_index >= kTVMFFIStaticObjectBegin: + if _tc_type_index_is_instance(actual_type_index, conv.type_index): + return converted + raise _ConvertError(f"expected {conv.err_hint}, got {_tc_describe_value_type(value)}") from err + + +cdef inline bint _tc_type_index_is_instance(int32_t actual_tindex, int32_t target_tindex) noexcept: + """Check if *actual_tindex* is *target_tindex* or a subclass thereof.""" + # TODO: this can be optimized by looking up `TYPE_INDEX_TO_INFO` + if actual_tindex == target_tindex: + return True + cdef const TVMFFITypeInfo* actual_info = TVMFFIGetTypeInfo(actual_tindex) + if actual_info == NULL: + return False + cdef const TVMFFITypeInfo* target_info = TVMFFIGetTypeInfo(target_tindex) + if target_info == NULL: + return False + cdef int32_t target_depth = target_info.type_depth + if actual_info.type_depth <= target_depth: + return False + return actual_info.type_ancestors[target_depth].type_index == target_tindex + + +# --------------------------------------------------------------------------- +# Helper: describe the Python type of a value for error messages +# --------------------------------------------------------------------------- +cdef str _tc_describe_value_type(object value): + """Return a human-readable type description for *value*.""" + cdef object type_info + if value is None: + return "None" + if isinstance(value, bool): + return "bool" + if isinstance(value, int): + return "int" + if isinstance(value, float): + return "float" + if isinstance(value, str): + return "str" + if isinstance(value, (bytes, bytearray)): + return "bytes" + if isinstance(value, CObject): + type_info = getattr(type(value), "__tvm_ffi_type_info__", None) + if type_info is not None: + return type_info.type_key + return _type_index_to_key(TVMFFIObjectGetTypeIndex((value).chandle)) + return type(value).__qualname__ + + +cdef CAny CAnyChecked(object value, str expected, object original_value) except *: + """Pack *value* into CAny and normalize packing failures to _ConvertError.""" + try: + return CAny(value) + except _ConvertError: + raise + except Exception as err: + raise _ConvertError( + f"expected {expected}, got {_tc_describe_value_type(original_value)}: {err}" + ) from None + + +# --------------------------------------------------------------------------- +# Builder (runs once per TypeSchema at construction time) +# --------------------------------------------------------------------------- + +def _build_converter(schema): + """Build a ``_TypeConverter`` for *schema*.""" + cdef _TypeConverter conv = _TypeConverter.__new__(_TypeConverter) + cdef _TypeConverter sub_conv + cdef _TypeConverter key_conv + cdef _TypeConverter val_conv + origin = schema.origin + args = schema.args + origin_tindex = schema.origin_type_index + conv.err_hint = origin + + def _to_type_converter_or_none(object schema_arg): + if schema_arg.origin == "Any": + return None + return <_TypeConverter>(schema_arg._converter) + + if origin_tindex == kTVMFFIAny or origin == "Any": + conv.dispatch = _tc_convert_any + conv.err_hint = "Any" + return conv + + if origin == "Optional": + conv.dispatch = _tc_convert_optional + conv.subs = (<_TypeConverter>(args[0]._converter),) + return conv + if origin == "Union": + conv.dispatch = _tc_convert_union + conv.subs = tuple(<_TypeConverter>(a._converter) for a in args) + conv.err_hint = " | ".join(repr(a) for a in args) + return conv + + if origin == "int": + conv.dispatch = _tc_convert_int + return conv + if origin == "float": + conv.dispatch = _tc_convert_float + return conv + if origin == "bool": + conv.dispatch = _tc_convert_bool + return conv + if origin == "None": + conv.dispatch = _tc_convert_none + return conv + if origin == "str": + conv.dispatch = _tc_convert_str + return conv + if origin == "bytes": + conv.dispatch = _tc_convert_bytes + return conv + if origin == "Device": + conv.dispatch = _tc_convert_device + return conv + if origin == "dtype": + conv.dispatch = _tc_convert_dtype + return conv + if origin == "ctypes.c_void_p": + conv.dispatch = _tc_convert_opaque_ptr + return conv + if origin == "Tensor": + conv.dispatch = _tc_convert_tensor + return conv + if origin == "Callable": + conv.dispatch = _tc_convert_callable + return conv + + if origin == "Array": + conv.dispatch = _tc_convert_array + if len(args) > 0: + sub_conv = _to_type_converter_or_none(args[0]) + if sub_conv is not None: + conv.subs = (sub_conv,) + return conv + if origin in ("List", "list"): + conv.dispatch = _tc_convert_list + conv.err_hint = "List" + if len(args) > 0: + sub_conv = _to_type_converter_or_none(args[0]) + if sub_conv is not None: + conv.subs = (sub_conv,) + return conv + if origin == "Map": + conv.dispatch = _tc_convert_map + if len(args) == 2: + key_conv = _to_type_converter_or_none(args[0]) + val_conv = _to_type_converter_or_none(args[1]) + if key_conv is not None or val_conv is not None: + conv.subs = (key_conv, val_conv) + return conv + if origin in ("Dict", "dict"): + conv.dispatch = _tc_convert_dict + conv.err_hint = "Dict" + if len(args) == 2: + key_conv = _to_type_converter_or_none(args[0]) + val_conv = _to_type_converter_or_none(args[1]) + if key_conv is not None or val_conv is not None: + conv.subs = (key_conv, val_conv) + return conv + if origin == "tuple": + conv.dispatch = _tc_convert_tuple + if args is not None: + conv.subs = tuple(<_TypeConverter>(a._converter) for a in args) + return conv + + if origin == "Object": + conv.dispatch = _tc_convert_object + conv.type_index = kTVMFFIObject + conv.err_hint = "Object" + return conv + if origin_tindex >= kTVMFFIStaticObjectBegin: + conv.dispatch = _tc_convert_object + conv.type_index = origin_tindex + conv.err_hint = origin + return conv + + tindex = _object_type_key_to_index(origin) + if tindex is not None: + conv.dispatch = _tc_convert_object + conv.type_index = tindex + conv.err_hint = origin + return conv + + raise TypeError(f"unknown TypeSchema origin: {origin!r}") + + +# --------------------------------------------------------------------------- +# Eager protocol normalization and dispatch +# --------------------------------------------------------------------------- + + +cdef void _tc_raise_eager_value_protocol_error(_TypeConverter conv, object value) except *: + if conv.dispatch == _tc_convert_optional: + _tc_raise_eager_value_protocol_error(<_TypeConverter>(conv.subs[0]), value) + if conv.err_hint == "Any": + raise _ConvertError(f"failed to convert Any from {_tc_describe_value_type(value)}") + raise _ConvertError(f"expected {conv.err_hint}, got {_tc_describe_value_type(value)}") + + +cdef object _tc_eager_protocol_step(object value, bint* stalled_value_protocol) except *: + cdef object vtype + cdef object inner + if isinstance(value, (Tensor, CObject, ObjectRValueRef, PyNativeObject)): + return value + vtype = type(value) + if hasattr(vtype, "__tvm_ffi_object__"): + try: + return value.__tvm_ffi_object__() + except Exception: + raise _ConvertError( + f"__tvm_ffi_object__() failed for {_tc_describe_value_type(value)}" + ) from None + if hasattr(vtype, "__tvm_ffi_value__"): + try: + inner = value.__tvm_ffi_value__() + except Exception: + # Report the schema mismatch instead of leaking the raw + # __tvm_ffi_value__ implementation error. + stalled_value_protocol[0] = True + return value + if inner is value: + stalled_value_protocol[0] = True + return inner + if isinstance(value, ObjectConvertible): + # Normalize ObjectConvertible eagerly so nested Union/container dispatch + # sees the inner FFI object instead of the Python wrapper. + try: + inner = value.asobject() + except Exception: + raise _ConvertError(f"asobject() failed for {_tc_describe_value_type(value)}") from None + if not isinstance(inner, CObject): + raise _ConvertError( + f"asobject() returned {_tc_describe_value_type(inner)} " + f"for {_tc_describe_value_type(value)}" + ) + return inner + return value + + +cdef CAny _type_convert_dispatch_with_fallback(_TypeConverter conv, object value, bint* changed) except *: + """Dispatch after eager protocol normalization with cycle protection.""" + cdef int depth = 0 + cdef object inner + cdef bint stalled_value_protocol + cdef bint used_value_protocol = False + cdef CAny result + while True: + stalled_value_protocol = False + inner = _tc_eager_protocol_step(value, &stalled_value_protocol) + if stalled_value_protocol: + _tc_raise_eager_value_protocol_error(conv, value) + if inner is value: + break + depth += 1 + if depth > _VALUE_PROTOCOL_MAX_DEPTH: + raise _ConvertError("infinite __tvm_ffi_value__ cycle detected") from None + used_value_protocol = True + value = inner + changed[0] = False + result = conv.dispatch(conv, value, changed) + if used_value_protocol: + changed[0] = True + return result + + +# --------------------------------------------------------------------------- +# Main dispatcher (thin entry point from Python-level TypeSchema methods) +# --------------------------------------------------------------------------- +cdef CAny _type_convert_impl(_TypeConverter converter, object value) except *: + """Dispatch to the C-level converter.""" + cdef bint changed + return _type_convert_dispatch_with_fallback(converter, value, &changed) diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi index ab4cdc9b..2df8a5c6 100644 --- a/python/tvm_ffi/cython/type_info.pxi +++ b/python/tvm_ffi/cython/type_info.pxi @@ -16,9 +16,17 @@ # under the License. import dataclasses import json +import typing +import collections.abc +from functools import cached_property from typing import Optional, Any from io import StringIO +try: + from types import UnionType as _UnionType +except ImportError: + _UnionType = None + cdef class FieldGetter: cdef dict __dict__ @@ -38,8 +46,9 @@ cdef class FieldGetter: cdef class FieldSetter: cdef dict __dict__ - cdef TVMFFIFieldSetter setter + cdef void* setter cdef int64_t offset + cdef int64_t flags def __call__(self, CObject obj, value): cdef int c_api_ret_code @@ -47,6 +56,7 @@ cdef class FieldSetter: TVMFFIPyCallFieldSetter( TVMFFIPyArgSetterFactory_, self.setter, + self.flags, field_ptr, value, &c_api_ret_code @@ -76,6 +86,8 @@ _TYPE_SCHEMA_ORIGIN_CONVERTER = { "ffi.List": "List", "ffi.Map": "Map", "ffi.Dict": "Dict", + # OpaquePyObject accepts any Python value at the FFI boundary (the C++ + # side wraps it opaquely), so mapping to "Any" is semantically correct. "ffi.OpaquePyObject": "Any", "ffi.Object": "Object", "ffi.Tensor": "Tensor", @@ -92,8 +104,55 @@ _TYPE_SCHEMA_ORIGIN_CONVERTER = { "ffi.SmallStr": "str", "ffi.String": "str", "DataType": "dtype", + # C++ STL types (emitted by TypeTraits in include/tvm/ffi/extra/stl.h) + "std::vector": "Array", + "std::optional": "Optional", + "std::variant": "Union", + "std::tuple": "tuple", + "std::map": "Map", + "std::unordered_map": "Map", + "std::function": "Callable", + # Rvalue reference (C++ move semantics). Python has no move semantics, + # so the checker treats it as a plain Object reference. + "ObjectRValueRef": "Object", +} + +# Sentinel for structural types (Optional, Union) that have no single type_index +_ORIGIN_TYPE_INDEX_STRUCTURAL = -2 +# Sentinel for unknown/unresolved origins +_ORIGIN_TYPE_INDEX_UNKNOWN = -3 + +# Map origin string -> type_index for known types +_ORIGIN_TO_TYPE_INDEX = { + "None": kTVMFFINone, + "int": kTVMFFIInt, + "bool": kTVMFFIBool, + "float": kTVMFFIFloat, + "str": kTVMFFIStr, + "bytes": kTVMFFIBytes, + "Device": kTVMFFIDevice, + "dtype": kTVMFFIDataType, + "ctypes.c_void_p": kTVMFFIOpaquePtr, + "Tensor": kTVMFFITensor, + "Object": kTVMFFIObject, + "Callable": kTVMFFIFunction, + "Array": kTVMFFIArray, + "List": kTVMFFIList, + "Map": kTVMFFIMap, + "Dict": kTVMFFIDict, + "Any": kTVMFFIAny, } +# Reverse map: type_index -> origin string +_TYPE_INDEX_TO_ORIGIN = {v: k for k, v in _ORIGIN_TO_TYPE_INDEX.items()} +# Low-level type indices that alias canonical origins +_TYPE_INDEX_TO_ORIGIN[kTVMFFIDLTensorPtr] = "Tensor" +_TYPE_INDEX_TO_ORIGIN[kTVMFFIRawStr] = "str" +_TYPE_INDEX_TO_ORIGIN[kTVMFFIByteArrayPtr] = "bytes" +_TYPE_INDEX_TO_ORIGIN[kTVMFFISmallStr] = "str" +_TYPE_INDEX_TO_ORIGIN[kTVMFFISmallBytes] = "bytes" +_TYPE_INDEX_TO_ORIGIN[kTVMFFIObjectRValueRef] = "Object" + @dataclasses.dataclass(repr=False) class TypeSchema: @@ -104,37 +163,83 @@ class TypeSchema: :py:meth:`repr`. """ origin: str - args: tuple[TypeSchema, ...] = () + args: tuple["TypeSchema", ...] | None = None + origin_type_index: int = dataclasses.field(default=_ORIGIN_TYPE_INDEX_UNKNOWN, repr=False) def __post_init__(self): origin = self.origin args = self.args + if args is not None and not isinstance(args, tuple): + args = tuple(args) + self.args = args + if origin != "tuple" and args is None: + args = () + self.args = args if origin == "Union": - assert len(args) >= 2, "Union must have at least two arguments" + if len(args) < 2: + raise ValueError("Union must have at least two arguments") elif origin == "Optional": - assert len(args) == 1, "Optional must have exactly one argument" + if len(args) != 1: + raise ValueError("Optional must have exactly one argument") elif origin in ("list", "Array", "List"): - assert len(args) in (0, 1), f"{origin} must have 0 or 1 argument" + if len(args) not in (0, 1): + raise ValueError(f"{origin} must have 0 or 1 argument") if args == (): self.args = (TypeSchema("Any"),) elif origin in ("dict", "Map", "Dict"): - assert len(args) in (0, 2), f"{origin} must have 0 or 2 arguments" + if len(args) not in (0, 2): + raise ValueError(f"{origin} must have 0 or 2 arguments") if args == (): self.args = (TypeSchema("Any"), TypeSchema("Any")) elif origin == "tuple": pass # tuple can have arbitrary number of arguments + # Compute origin_type_index if not already set + if self.origin_type_index == _ORIGIN_TYPE_INDEX_UNKNOWN: + if origin in ("Optional", "Union"): + self.origin_type_index = _ORIGIN_TYPE_INDEX_STRUCTURAL + elif origin in _ORIGIN_TO_TYPE_INDEX: + self.origin_type_index = _ORIGIN_TO_TYPE_INDEX[origin] + else: + # Try to resolve as a registered object type key + tindex = _object_type_key_to_index(origin) + if tindex is not None: + self.origin_type_index = tindex + + @cached_property + def _converter(self): + """Lazily build the type converter on first use. + + Deferred construction ensures all object types are registered + by the time the converter is built. Raises TypeError for + unresolvable origins. + """ + return _build_converter(self) def __repr__(self) -> str: return self.repr(ty_map=None) @staticmethod def from_json_obj(obj: dict[str, Any]) -> "TypeSchema": - """Construct a :class:`TypeSchema` from a parsed JSON object.""" - assert isinstance(obj, dict) and "type" in obj, obj + """Construct a :class:`TypeSchema` from a parsed JSON object. + + Non-dict elements in the ``"args"`` list (e.g., numeric lengths + emitted by ``std::array`` TypeTraits) are silently skipped. + """ + if not isinstance(obj, dict) or "type" not in obj: + raise TypeError( + f"expected schema dict with 'type' key, got {type(obj).__name__}" + ) origin = obj["type"] origin = _TYPE_SCHEMA_ORIGIN_CONVERTER.get(origin, origin) - args = obj.get("args", ()) - args = tuple(TypeSchema.from_json_obj(a) for a in args) + if "args" not in obj: + return TypeSchema(origin) + raw_args = obj["args"] + if not isinstance(raw_args, (list, tuple)): + raw_args = () + args = tuple( + TypeSchema.from_json_obj(a) for a in raw_args + if isinstance(a, dict) + ) return TypeSchema(origin, args) @staticmethod @@ -142,6 +247,238 @@ class TypeSchema: """Construct a :class:`TypeSchema` from a JSON string.""" return TypeSchema.from_json_obj(json.loads(s)) + @staticmethod + def from_type_index(type_index: int, args: "tuple[TypeSchema, ...]" = ()) -> "TypeSchema": + """Construct a :class:`TypeSchema` from a type_index and optional args. + + Parameters + ---------- + type_index : int + A valid TVM FFI type index (e.g., ``kTVMFFIInt``, ``kTVMFFIArray``, + or an object type index from ``_object_type_key_to_index``). + Passing an unregistered index triggers a fatal C++ assertion; + callers must ensure the index was obtained from the type registry. + args : tuple[TypeSchema, ...], optional + Type arguments for parameterized types (e.g., element type for Array). + + Returns + ------- + TypeSchema + A new schema with the origin resolved from the type index. + """ + origin = _TYPE_INDEX_TO_ORIGIN.get(type_index, None) + if origin is None: + origin = _type_index_to_key(type_index) + return TypeSchema(origin, args, origin_type_index=type_index) + + @staticmethod + def from_annotation(annotation: object) -> "TypeSchema": + """Construct a :class:`TypeSchema` from a Python type annotation. + + Parameters + ---------- + annotation : object + A Python type annotation such as ``int``, ``list[int]``, + ``Optional[str]``, ``Union[int, str]``, ``Callable[[int], str]``, + or a registered :class:`CObject` subclass. + + Returns + ------- + TypeSchema + The corresponding schema. + + Raises + ------ + TypeError + If the annotation cannot be mapped to a TypeSchema. + + Examples + -------- + >>> TypeSchema.from_annotation(int) + int + >>> TypeSchema.from_annotation(list[int]) + List[int] + >>> TypeSchema.from_annotation(tuple[int, ...]) + Array[int] + """ + # --- Singletons --- + if annotation is type(None) or annotation is None: + return TypeSchema("None") + if annotation is typing.Any: + return TypeSchema("Any") + + # --- Bare builtin scalar types --- + if annotation is bool: + return TypeSchema("bool") + if annotation is int: + return TypeSchema("int") + if annotation is float: + return TypeSchema("float") + if annotation is str: + return TypeSchema("str") + if annotation is bytes: + return TypeSchema("bytes") + + # --- Bare container types (unparameterised) --- + if annotation is list: + return TypeSchema("List") + if annotation is dict: + return TypeSchema("Dict") + if annotation is tuple: + return TypeSchema("tuple") + if annotation is collections.abc.Callable: + return TypeSchema("Callable") + + # --- Python 3.10+ union syntax (X | Y) --- + if _UnionType is not None and isinstance(annotation, _UnionType): + return _annotation_union(typing.get_args(annotation)) + + # --- Generic aliases (list[int], Optional[T], etc.) --- + origin = typing.get_origin(annotation) + targs = typing.get_args(annotation) + + if origin is typing.Union: + return _annotation_union(targs) + + if origin is list: + if len(targs) > 1: + raise TypeError( + f"list takes at most 1 type argument, got {len(targs)}" + ) + if targs: + return TypeSchema("List", (TypeSchema.from_annotation(targs[0]),)) + return TypeSchema("List") + + if origin is dict: + if len(targs) == 1 or len(targs) > 2: + raise TypeError( + f"dict requires 0 or 2 type arguments, got {len(targs)}" + ) + if len(targs) == 2: + return TypeSchema("Dict", ( + TypeSchema.from_annotation(targs[0]), + TypeSchema.from_annotation(targs[1]), + )) + return TypeSchema("Dict") + + if origin is tuple: + if len(targs) == 2 and targs[1] is Ellipsis: + # tuple[T, ...] → homogeneous variable-length → Array + return TypeSchema("Array", (TypeSchema.from_annotation(targs[0]),)) + if targs: + return TypeSchema( + "tuple", + tuple(TypeSchema.from_annotation(a) for a in targs), + ) + if annotation is not tuple: + return TypeSchema("tuple", ()) + return TypeSchema("tuple") + + if origin is collections.abc.Callable: + if len(targs) == 2: + params, ret = targs + ret_schema = TypeSchema.from_annotation(ret) + if isinstance(params, list): + # Callable[[P1, P2], R] → (R, P1, P2) + param_schemas = tuple( + TypeSchema.from_annotation(p) for p in params + ) + return TypeSchema("Callable", (ret_schema,) + param_schemas) + # Callable[..., R] + return TypeSchema("Callable", (ret_schema,)) + return TypeSchema("Callable") + + # --- Parameterised CObject subclasses (Array[int], Dict[str, V], …) --- + if isinstance(origin, type) and issubclass(origin, CObject): + return _annotation_cobject(origin, targs) + + # --- Bare (unparameterised) CObject subclasses --- + if isinstance(annotation, type) and issubclass(annotation, CObject): + return _annotation_cobject(annotation, ()) + + # --- PyNativeObject subclasses (String, Bytes) --- + if isinstance(annotation, type) and issubclass(annotation, PyNativeObject): + if issubclass(annotation, str): + return TypeSchema("str") + if issubclass(annotation, bytes): + return TypeSchema("bytes") + + # --- Non-CObject cdef classes with known origins --- + if annotation is DataType: + return TypeSchema("dtype") + if annotation is Device: + return TypeSchema("Device") + + # --- ctypes.c_void_p --- + import ctypes as _ctypes + if annotation is _ctypes.c_void_p: + return TypeSchema("ctypes.c_void_p") + + # --- Types with __dlpack__ protocol (e.g. torch.Tensor) → Tensor --- + if isinstance(annotation, type) and hasattr(annotation, "__dlpack__"): + return TypeSchema("Tensor") + + raise TypeError( + f"Cannot convert {annotation!r} to TypeSchema" + ) + + def check_value(self, value: object) -> None: + """Validate that *value* is compatible with this type schema. + + Parameters + ---------- + value : object + The Python value to check. + + Raises + ------ + TypeError + If the value is not compatible with the schema, with a + human-readable error message describing the mismatch. + """ + try: + _type_convert_impl(self._converter, value) + except RecursionError: + raise TypeError( + f"type check failed for {self!r}: " + f"infinite __tvm_ffi_value__ cycle detected" + ) from None + except _ConvertError as err: + raise TypeError(f"type check failed for {self!r}: {err.message}") from None + + def convert(self, value: object) -> "CAny": + """Convert *value* according to this type schema, returning a :class:`CAny`. + + Applies the same implicit conversions as the C++ FFI + ``TypeTraits::TryCastFromAnyView`` rules. The result is + always a :class:`CAny` instance that owns the converted value. + Use ``result.to_py()`` to recover the Python object. + + Parameters + ---------- + value : object + The Python value to convert. + + Returns + ------- + CAny + The converted value wrapped in a CAny. + + Raises + ------ + TypeError + If the value cannot be converted to this schema's type. + """ + try: + return _type_convert_impl(self._converter, value) + except RecursionError: + raise TypeError( + f"type conversion failed for {self!r}: " + f"infinite __tvm_ffi_value__ cycle detected" + ) from None + except _ConvertError as err: + raise TypeError(f"type conversion failed for {self!r}: {err.message}") from None + def repr(self, ty_map: "Optional[Callable[[str], str]]" = None) -> str: """Render a human-readable representation of this schema. @@ -185,7 +522,8 @@ class TypeSchema: origin = self.origin else: origin = ty_map(self.origin) - args = [i.repr(ty_map) for i in self.args] + schema_args = self.args + args = [i.repr(ty_map) for i in (() if schema_args is None else schema_args)] if origin == "Union": return " | ".join(args) elif origin == "Optional": @@ -197,12 +535,63 @@ class TypeSchema: ret = args[0] args = ", ".join(args[1:]) return f"Callable[[{args}], {ret}]" + elif origin == "tuple" and schema_args == (): + return "tuple[()]" elif not args: return origin else: args = ", ".join(args) return f"{origin}[{args}]" + def to_json(self) -> dict[str, Any]: + """Convert a TypeSchema to a JSON-compatible dict.""" + if self.args is not None and (self.args or self.origin == "tuple"): + return { + "type": self.origin, + "args": [a.to_json() for a in self.args], + } + return {"type": self.origin} + + +def _annotation_union(args): + """Convert Union type args to a TypeSchema (Optional or Union).""" + non_none = tuple(a for a in args if a is not type(None)) + has_none = len(non_none) < len(args) + converted = tuple(TypeSchema.from_annotation(a) for a in non_none) + if has_none: + if len(non_none) == 1: + return TypeSchema("Optional", converted) + return TypeSchema("Optional", (TypeSchema("Union", converted),)) + return TypeSchema("Union", converted) + + +def _annotation_cobject(cls, targs): + """Handle a CObject subclass (bare or parameterised) in from_annotation.""" + info = TYPE_CLS_TO_INFO.get(cls) + if info is None: + raise TypeError( + f"CObject subclass {cls!r} is not registered " + f"in TYPE_CLS_TO_INFO; use @register_object to register it" + ) + # Prefer canonical short origin from _TYPE_INDEX_TO_ORIGIN (e.g. "Array") + # over the registered type_key (e.g. "ffi.Array") when available. + origin = _TYPE_INDEX_TO_ORIGIN.get(info.type_index, info.type_key) + n = len(targs) + if n > 0: + if origin in ("Array", "List"): + if n != 1: + raise TypeError( + f"{origin} requires 1 type argument, got {n}" + ) + elif origin in ("Map", "Dict"): + if n != 2: + raise TypeError( + f"{origin} requires 2 type arguments, got {n}" + ) + arg_schemas = tuple(TypeSchema.from_annotation(a) for a in targs) + return TypeSchema(origin, arg_schemas, origin_type_index=info.type_index) + return TypeSchema(origin, origin_type_index=info.type_index) + @dataclasses.dataclass(eq=False) class TypeField: @@ -216,6 +605,7 @@ class TypeField: metadata: dict[str, Any] getter: FieldGetter setter: FieldSetter + ty: Optional[TypeSchema] = None c_init: bool = True c_kw_only: bool = False c_has_default: bool = False @@ -282,7 +672,7 @@ class TypeInfo: type_index: int type_key: str type_ancestors: list[int] - fields: list[TypeField] + fields: Optional[list[TypeField]] methods: list[TypeMethod] parent_type_info: Optional[TypeInfo] @@ -296,6 +686,315 @@ class TypeInfo: # ensure parent is registered self.parent_type_info = _lookup_or_register_type_info_from_type_key(parent_type_key) + @cached_property + def total_size(self) -> int: + """Total object size in bytes (header + all fields). + + For native C++ types with metadata, returns metadata.total_size. + For Python-defined types, computes from field layout. + """ + cdef const TVMFFITypeInfo* c_info = TVMFFIGetTypeInfo(self.type_index) + if c_info != NULL and c_info.metadata != NULL: + return c_info.metadata.total_size + cdef int64_t end = sizeof(TVMFFIObject) + if self.fields: + for f in self.fields: + f_end = f.offset + f.size + if f_end > end: + end = f_end + return (end + 7) & ~7 # align to 8 bytes + + def _register_fields(self, fields): + """Register Field descriptors and set up __ffi_new__/__ffi_init__. + + Delegates to the module-level _register_fields function, + stores the resulting list[TypeField] on self.fields, + then reads back methods registered by C++ via _register_methods. + + Can only be called once (fields must be None beforehand). + """ + assert self.fields is None, ( + f"_register_fields already called for {self.type_key!r}" + ) + self.fields = _register_fields(self, fields) + self._register_methods() + + def _register_methods(self): + """Read methods from the C type table into self.methods. + + Called after C++ registers __ffi_init__, __ffi_shallow_copy__, etc. + """ + cdef const TVMFFITypeInfo* c_info = TVMFFIGetTypeInfo(self.type_index) + cdef const TVMFFIMethodInfo* mi + self.methods = [] + for i in range(c_info.num_methods): + mi = &(c_info.methods[i]) + self.methods.append(TypeMethod( + name=bytearray_to_str(&mi.name), + doc=bytearray_to_str(&mi.doc) if mi.doc.size != 0 else None, + func=_get_method_from_method_info(mi), + is_static=(mi.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod) != 0, + metadata=json.loads(bytearray_to_str(&mi.metadata)) if mi.metadata.size != 0 else {}, + )) + + +# --------------------------------------------------------------------------- +# Python-defined type field registration helpers +# --------------------------------------------------------------------------- + +# Native layout for each TypeSchema origin: (size, alignment, field_static_type_index) +_ORIGIN_NATIVE_LAYOUT = { + "int": (8, 8, kTVMFFIInt), + "float": (8, 8, kTVMFFIFloat), + "bool": (1, 1, kTVMFFIBool), + "ctypes.c_void_p": (8, 8, kTVMFFIOpaquePtr), + "dtype": (4, 2, kTVMFFIDataType), + "Device": (8, 4, kTVMFFIDevice), + "Any": (16, 8, -1), # kTVMFFIAny = -1 + # str/bytes can be SmallStr/SmallBytes (inline, not ObjectRef), + # so store as Any (16 bytes) to handle both inline and heap variants. + "str": (16, 8, -1), + "bytes": (16, 8, -1), + # Optional/Union can hold any type including inline scalars + "Optional": (16, 8, -1), + "Union": (16, 8, -1), +} + +cdef _register_one_field( + int32_t type_index, + object py_field, + int64_t offset, + int64_t size, + int64_t alignment, + int32_t field_type_index, + TVMFFIFieldGetter getter, + CObject setter_fn, +): + """Build a TVMFFIFieldInfo and register it for the given type.""" + cdef TVMFFIFieldInfo info + cdef int c_api_ret_code + + # --- name --- + name_bytes = c_str(py_field.name) + cdef ByteArrayArg name_arg = ByteArrayArg(name_bytes) + info.name = name_arg.cdata + + # --- doc --- + cdef ByteArrayArg doc_arg + if py_field.doc is not None: + doc_bytes = c_str(py_field.doc) + doc_arg = ByteArrayArg(doc_bytes) + info.doc = doc_arg.cdata + else: + info.doc.data = NULL + info.doc.size = 0 + + # --- metadata (JSON with type_schema) --- + metadata_str = json.dumps({"type_schema": py_field.ty.to_json()}) + metadata_bytes = c_str(metadata_str) + cdef ByteArrayArg metadata_arg = ByteArrayArg(metadata_bytes) + info.metadata = metadata_arg.cdata + + # --- flags --- + cdef int64_t flags = kTVMFFIFieldFlagBitMaskWritable | kTVMFFIFieldFlagBitSetterIsFunctionObj + if py_field.default is not MISSING or py_field.default_factory is not MISSING: + flags |= kTVMFFIFieldFlagBitMaskHasDefault + if py_field.default_factory is not MISSING: + flags |= kTVMFFIFieldFlagBitMaskDefaultFromFactory + if not py_field.init: + flags |= kTVMFFIFieldFlagBitMaskInitOff + if not py_field.repr: + flags |= kTVMFFIFieldFlagBitMaskReprOff + if not py_field.hash: + flags |= kTVMFFIFieldFlagBitMaskHashOff + if not py_field.compare: + flags |= kTVMFFIFieldFlagBitMaskCompareOff + if py_field.kw_only: + flags |= kTVMFFIFieldFlagBitMaskKwOnly + info.flags = flags + + # --- native layout --- + info.size = size + info.alignment = alignment + info.offset = offset + + # --- getter / setter --- + info.getter = getter + info.setter = setter_fn.chandle + + # --- default value --- + cdef TVMFFIAny default_any + default_any.type_index = kTVMFFINone + default_any.v_int64 = 0 + # Determine which Python object (if any) to store as the default. + # No memory leak: TVMFFIAny is a POD struct; TVMFFITypeRegisterField + # copies the bytes into the type table, which owns the reference. + cdef object default_obj = MISSING + if py_field.default is not MISSING: + default_obj = py_field.default + elif py_field.default_factory is not MISSING: + default_obj = py_field.default_factory + if default_obj is not MISSING: + TVMFFIPyPyObjectToFFIAny( + TVMFFIPyArgSetterFactory_, + default_obj, + &default_any, + &c_api_ret_code + ) + CHECK_CALL(c_api_ret_code) + info.default_value_or_factory = default_any + + # --- field_static_type_index --- + info.field_static_type_index = field_type_index + + CHECK_CALL(TVMFFITypeRegisterField(type_index, &info)) + + +cdef int _f_type_convert(void* type_converter, const TVMFFIAny* value, TVMFFIAny* result) noexcept with gil: + """C callback for type conversion, called from C++ MakeFieldSetter. + + Parameters + ---------- + type_converter : void* + A PyObject* pointing to a _TypeConverter instance (borrowed reference). + value : const TVMFFIAny* + The packed value to convert (borrowed from the caller). + result : TVMFFIAny* + Output: the converted value (caller takes ownership). + + Returns 0 on success, -1 on error (error stored in TLS via set_last_ffi_error). + """ + cdef TVMFFIAny temp + cdef _TypeConverter conv + cdef CAny cany + try: + # Unpack the packed AnyView to a Python object. + # We must IncRef if it's an object, because make_ret takes ownership. + temp = value[0] + if temp.type_index >= kTVMFFIStaticObjectBegin: + if temp.v_obj != NULL: + TVMFFIObjectIncRef(temp.v_obj) + py_value = make_ret(temp) + # Dispatch directly through the C-level converter + conv = <_TypeConverter>type_converter + cany = _type_convert_impl(conv, py_value) + # Transfer ownership from CAny to result (zero cany to prevent double-free) + result[0] = cany.cdata + cany.cdata.type_index = kTVMFFINone + cany.cdata.v_int64 = 0 + return 0 + except Exception as err: + set_last_ffi_error(err) + return -1 + + +def _register_fields(type_info, fields): + """Register Field descriptors for a Python-defined type and set up __ffi_new__/__ffi_init__. + + For each Field: + 1. Computes native layout (size, alignment, offset) + 2. Obtains a C getter function pointer + 3. Creates a FunctionObj setter with type conversion + 4. Registers via TVMFFITypeRegisterField + + After all fields, registers __ffi_new__ (object allocator) and + __ffi_init__ (auto-generated constructor). + + Parameters + ---------- + type_info : TypeInfo + The TypeInfo of the type being defined. + fields : list[Field] + The Field descriptors to register. + + Returns + ------- + list[TypeField] + The registered field descriptors. + """ + cdef int32_t type_index = type_info.type_index + # Start field offsets AFTER all parent fields (not at fixed offset 24). + # This is critical for inheritance: child fields must not overlap parent memory. + cdef int64_t current_offset = type_info.parent_type_info.total_size + cdef int64_t size, alignment + cdef int32_t field_type_index + cdef TVMFFIFieldGetter getter + cdef FieldGetter fgetter + cdef FieldSetter fsetter + + # Get global functions + _get_field_getter = _get_global_func("ffi.GetFieldGetter", False) + _make_field_setter = _get_global_func("ffi.MakeFieldSetter", False) + _make_ffi_new = _get_global_func("ffi.MakeFFINew", False) + _register_auto_init = _get_global_func("ffi.RegisterAutoInit", False) + + cdef list type_fields = [] + + for py_field in fields: + # 1. Get layout + layout = _ORIGIN_NATIVE_LAYOUT.get(py_field.ty.origin, (8, 8, kTVMFFIObject)) + size = layout[0] + alignment = layout[1] + field_type_index = layout[2] + + # 2. Compute offset (align up) + current_offset = (current_offset + alignment - 1) & ~(alignment - 1) + field_offset = current_offset + current_offset += size + + # 3. Get getter (C function pointer) and setter (FunctionObj). + # Pointers are transported as int64_t through the FFI boundary. + getter = _get_field_getter(field_type_index) + setter_fn = _make_field_setter( + field_type_index, + py_field.ty._converter, + &_f_type_convert, + ) + + # 4. Register field in the C type table + _register_one_field( + type_index, py_field, field_offset, size, alignment, + field_type_index, getter, setter_fn, + ) + + # 5. Build the Python-side TypeField descriptor + fgetter = FieldGetter.__new__(FieldGetter) + fgetter.getter = getter + fgetter.offset = field_offset + fsetter = FieldSetter.__new__(FieldSetter) + fsetter.setter = setter_fn.chandle + fsetter.offset = field_offset + fsetter.flags = (kTVMFFIFieldFlagBitMaskWritable | kTVMFFIFieldFlagBitSetterIsFunctionObj) + type_fields.append( + TypeField( + name=py_field.name, + doc=py_field.doc, + size=size, + offset=field_offset, + frozen=False, + metadata={"type_schema": py_field.ty.to_json()}, + getter=fgetter, + setter=fsetter, + ty=py_field.ty, + c_init=(py_field.init if hasattr(py_field, "init") else True), + c_kw_only=(py_field.kw_only if hasattr(py_field, "kw_only") else False), + c_has_default=(py_field.default is not MISSING or py_field.default_factory is not MISSING), + ) + ) + + # Align total size to 8 bytes + cdef int64_t total_size = (current_offset + 7) & ~7 + if total_size < sizeof(TVMFFIObject): + total_size = sizeof(TVMFFIObject) + + # 7. Register __ffi_new__ + deleter + _make_ffi_new(type_index, total_size) + + # 8. Register __ffi_init__ (auto-generated constructor) + _register_auto_init(type_index) + + return type_fields + def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., Any]: def wrapper(self: Any, *args: Any) -> Any: diff --git a/python/tvm_ffi/dataclasses/__init__.py b/python/tvm_ffi/dataclasses/__init__.py index 8ea00ab3..bb6a0391 100644 --- a/python/tvm_ffi/dataclasses/__init__.py +++ b/python/tvm_ffi/dataclasses/__init__.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""C++ FFI classes with structural comparison and hashing.""" +"""FFI dataclass decorators: ``c_class`` for C++-backed types, ``py_class`` for Python-defined types.""" from .c_class import c_class +from .field import KW_ONLY, Field, field +from .py_class import py_class -__all__ = ["c_class"] +__all__ = ["KW_ONLY", "Field", "c_class", "field", "py_class"] diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py new file mode 100644 index 00000000..9295f13a --- /dev/null +++ b/python/tvm_ffi/dataclasses/field.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Field descriptor and ``field()`` helper for Python-defined TVM-FFI types.""" + +from __future__ import annotations + +import sys +from collections.abc import Callable +from typing import Any + +from ..core import MISSING, TypeSchema + +# Re-export the stdlib KW_ONLY sentinel so type checkers recognise +# ``_: KW_ONLY`` as a keyword-only boundary rather than a real field. +# dataclasses.KW_ONLY was added in Python 3.10; on older runtimes we +# define a class sentinel (a class, not an instance, so that ``_: KW_ONLY`` +# is a valid type annotation for static analysers targeting 3.9). +if sys.version_info >= (3, 10): + from dataclasses import KW_ONLY +else: + + class KW_ONLY: + """Sentinel type: annotations after ``_: KW_ONLY`` are keyword-only.""" + + +class Field: + """Descriptor for a single field in a Python-defined TVM-FFI type. + + When constructed directly (low-level API), *name* and *ty* should be + provided. When returned by :func:`field` (``@py_class`` workflow), + *name* and *ty* are ``None`` and filled in by the decorator. + + Parameters + ---------- + name : str | None + The field name. ``None`` when created via :func:`field`; filled + in by the ``@py_class`` decorator. + ty : TypeSchema | None + The type schema. ``None`` when created via :func:`field`; filled + in by the ``@py_class`` decorator. + default : object + Default value for the field. Mutually exclusive with *default_factory*. + ``MISSING`` when not set. + default_factory : Callable[[], object] | None + A zero-argument callable that produces the default value. + Mutually exclusive with *default*. ``None`` when not set. + init : bool + Whether this field appears in the auto-generated ``__init__``. + repr : bool + Whether this field appears in ``__repr__`` output. + hash : bool | None + Whether this field participates in recursive hashing. + ``None`` means "follow *compare*" (the native dataclass default). + compare : bool + Whether this field participates in recursive comparison. + kw_only : bool | None + Whether this field is keyword-only in ``__init__``. + ``None`` means "inherit from the decorator-level *kw_only* flag". + doc : str | None + Optional docstring for the field. + + """ + + __slots__ = ( + "compare", + "default", + "default_factory", + "doc", + "hash", + "init", + "kw_only", + "name", + "repr", + "ty", + ) + name: str | None + ty: TypeSchema | None + default: object + default_factory: Callable[[], object] | None + init: bool + repr: bool + hash: bool | None + compare: bool + kw_only: bool | None + doc: str | None + + def __init__( + self, + name: str | None = None, + ty: TypeSchema | None = None, + *, + default: object = MISSING, + default_factory: Callable[[], object] | None = MISSING, # type: ignore[assignment] + init: bool = True, + repr: bool = True, + hash: bool | None = True, + compare: bool = False, + kw_only: bool | None = False, + doc: str | None = None, + ) -> None: + # MISSING means "parameter not provided". + # An explicit None from the user fails the callable() check, + # matching stdlib dataclasses semantics. + if default_factory is not MISSING: + if default is not MISSING: + raise ValueError("cannot specify both default and default_factory") + if not callable(default_factory): + raise TypeError( + f"default_factory must be a callable, got {type(default_factory).__name__}" + ) + self.name = name + self.ty = ty + self.default = default + self.default_factory = default_factory + self.init = init + self.repr = repr + self.hash = hash + self.compare = compare + self.kw_only = kw_only + self.doc = doc + + +def field( + *, + default: object = MISSING, + default_factory: Callable[[], object] | None = MISSING, # type: ignore[assignment] + init: bool = True, + repr: bool = True, + hash: bool | None = None, + compare: bool = True, + kw_only: bool | None = None, + doc: str | None = None, +) -> Any: + """Customize a field in a ``@py_class``-decorated class. + + Returns a :class:`Field` sentinel whose *name* and *ty* are + ``None``. The ``@py_class`` decorator fills them in later + from the class annotations. + + The return type is ``Any`` because ``dataclass_transform`` field + specifiers must be assignable to any annotated type (e.g. + ``x: int = field(default=0)``). + + Parameters + ---------- + default + Default value for the field. Mutually exclusive with *default_factory*. + default_factory + A zero-argument callable that produces the default value. + Mutually exclusive with *default*. + init + Whether this field appears in the auto-generated ``__init__``. + repr + Whether this field appears in ``__repr__`` output. + hash + Whether this field participates in recursive hashing. + ``None`` (default) means "follow *compare*". + compare + Whether this field participates in recursive comparison. + kw_only + Whether this field is keyword-only in ``__init__``. + ``None`` means "inherit from the decorator-level ``kw_only`` flag". + doc + Optional docstring for the field. + + Returns + ------- + Any + A :class:`Field` sentinel recognised by ``@py_class``. + + Examples + -------- + .. code-block:: python + + @py_class + class Point(Object): + x: float + y: float = field(default=0.0, repr=False) + + """ + return Field( + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + kw_only=kw_only, + doc=doc, + ) diff --git a/python/tvm_ffi/dataclasses/py_class.py b/python/tvm_ffi/dataclasses/py_class.py new file mode 100644 index 00000000..9080fae3 --- /dev/null +++ b/python/tvm_ffi/dataclasses/py_class.py @@ -0,0 +1,459 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The ``py_class`` decorator: Python-defined FFI classes with dataclass semantics.""" + +from __future__ import annotations + +import sys +import typing +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, ClassVar, TypeVar + +from typing_extensions import dataclass_transform + +from .. import core +from ..core import MISSING, TypeSchema +from ..registry import _add_class_attrs, _install_dataclass_dunders +from .field import KW_ONLY, Field, field + +_T = TypeVar("_T", bound=type) + + +# --------------------------------------------------------------------------- +# Module-level state +# --------------------------------------------------------------------------- +# +# Registration happens in two phases: +# +# Phase 1 (_phase1_register_type) +# Allocates a C-level type index and inserts the class into the +# global type registry. This must happen early so that self- +# referential and mutually-referential annotations can resolve +# the class via ``TypeSchema.from_annotation()``. Phase 1 always +# succeeds (or raises immediately for non-Object parents). +# +# Phase 2 (_phase2_register_fields) +# Resolves string annotations via ``typing.get_type_hints``, +# converts them to ``TypeSchema`` / ``Field`` objects, validates +# field ordering, registers fields with the Cython layer, and +# installs ``__init__``, ``__repr__``, ``__eq__``, etc. +# +# If ``get_type_hints`` raises ``NameError`` (forward reference +# not yet defined), the class is added to ``_PENDING_CLASSES`` +# and retried after each successful phase-2. If phase-2 fails +# for any other reason, ``_rollback_registration`` undoes phase-1 +# so the type key can be reused. +# --------------------------------------------------------------------------- + + +@dataclass +class _PendingClass: + """Bookkeeping for a class whose annotations couldn't be resolved yet.""" + + cls: type + type_info: Any # core.TypeInfo + globalns: dict[str, Any] + params: dict[str, Any] + + +#: Classes whose phase-2 (field registration) was deferred because +#: ``typing.get_type_hints`` raised ``NameError`` on an unresolved +#: forward reference. Retried after each successful phase-2 via +#: :func:`_flush_pending`. +_PENDING_CLASSES: list[_PendingClass] = [] + +#: Per-module mapping of ``class.__name__ → class`` for every +#: ``@py_class``-decorated type. Used as *localns* when resolving +#: annotations so that mutual references between classes in the same +#: module work even before the second class is assigned to the module +#: variable by Python. +_PY_CLASS_BY_MODULE: dict[str, dict[str, type]] = {} + + +# --------------------------------------------------------------------------- +# Phase 1: type registration +# --------------------------------------------------------------------------- + + +def _phase1_register_type(cls: type, type_key: str | None) -> Any: + """Phase 1: allocate type index and register the type (always succeeds).""" + parent_info: core.TypeInfo | None = None + for base in cls.__bases__: + parent_info = core._type_cls_to_type_info(base) + if parent_info is not None: + break + if parent_info is None: + raise TypeError( + f"{cls.__name__} must inherit from a registered FFI Object type (e.g. tvm_ffi.Object)" + ) + if type_key is None: + type_key = f"{cls.__module__}.{cls.__qualname__}" + info = core._register_py_class(parent_info, type_key, cls) + setattr(cls, "__tvm_ffi_type_info__", info) + # Register in resolution namespace so sibling classes can find us + _PY_CLASS_BY_MODULE.setdefault(cls.__module__, {})[cls.__name__] = cls + return info + + +def _rollback_registration(cls: type, type_info: Any) -> None: + """Undo :func:`_phase1_register_type` after a phase-2 failure. + + The C-level type index is permanently consumed (cannot be reclaimed), + but the Python-level registry dicts are cleaned up so a retry with + the same type key does not hit "already registered". + """ + # Remove from the Cython-level registry dicts (TYPE_KEY_TO_INFO, + # TYPE_CLS_TO_INFO, TYPE_INDEX_TO_INFO, TYPE_INDEX_TO_CLS). + core._rollback_py_class(type_info) # ty: ignore[unresolved-attribute] + # Remove from our own module-level resolution namespace. + _PY_CLASS_BY_MODULE.get(cls.__module__, {}).pop(cls.__name__, None) + try: + delattr(cls, "__tvm_ffi_type_info__") + except AttributeError: + pass + + +# --------------------------------------------------------------------------- +# Phase 2: annotation resolution, field registration, dunder installation +# --------------------------------------------------------------------------- + + +def _collect_own_fields( + cls: type, + hints: dict[str, Any], + decorator_kw_only: bool, +) -> list[Field]: + """Parse own annotations into :class:`Field` objects. + + - Skips ``ClassVar`` annotations. + - Handles ``KW_ONLY`` sentinel. + - Extracts ``Field`` metadata from class attributes (set via :func:`field`). + - Handles bare defaults (non-``Field`` values). + - Converts resolved types to ``TypeSchema``. + - Resolves ``hash=None`` to follow ``compare``. + """ + fields: list[Field] = [] + kw_only_active = decorator_kw_only + own_annotations: dict[str, str] = getattr(cls, "__annotations__", {}) + + for name in own_annotations: + resolved_type = hints.get(name) + # Skip ClassVar + if ( + resolved_type is None + or resolved_type is ClassVar + or typing.get_origin(resolved_type) is ClassVar + ): + continue + + # KW_ONLY sentinel + if resolved_type is KW_ONLY: + kw_only_active = True + if name in cls.__dict__: + try: + delattr(cls, name) + except AttributeError: + pass + continue + + # Extract Field from class dict (inline of _pop_field_from_class) + class_val = cls.__dict__.get(name, MISSING) + if isinstance(class_val, Field): + f = class_val + elif class_val is not MISSING: + f = field(default=class_val) + else: + f = field() + if class_val is not MISSING: + try: + delattr(cls, name) + except AttributeError: + pass + + # Fill in name and ty (set by the decorator, not the user) + f.name = name + f.ty = TypeSchema.from_annotation(resolved_type) + + # Resolve kw_only: None means "inherit from decorator" + if f.kw_only is None: + f.kw_only = kw_only_active + + # Resolve hash=None → follow compare (native dataclass semantics) + if f.hash is None: + f.hash = f.compare + + fields.append(f) + + return fields + + +def _phase2_register_fields( + cls: type, + type_info: Any, + globalns: dict[str, Any], + params: dict[str, Any], +) -> bool: + """Phase 2: resolve annotations, register fields, install dunders. + + Returns True on success, False if forward references are unresolved. + """ + # Resolve string annotations to types; return False (defer) on NameError. + localns = dict(_PY_CLASS_BY_MODULE.get(cls.__module__, {})) + localns[cls.__name__] = cls + try: + kwargs: dict[str, Any] = {"globalns": globalns, "localns": localns} + if sys.version_info >= (3, 11): + kwargs["include_extras"] = True + hints = typing.get_type_hints(cls, **kwargs) + except (NameError, AttributeError): + return False + + own_fields = _collect_own_fields(cls, hints, params["kw_only"]) + + type_info._register_fields(own_fields) + _add_class_attrs(cls, type_info) + + # Remove deferred __init__ and restore user-defined __init__ if saved + if "_py_class_deferred_init" in cls.__dict__: + # Always remove the deferred wrapper + if "__init__" in cls.__dict__: + delattr(cls, "__init__") + try: + delattr(cls, "_py_class_deferred_init") + except AttributeError: + pass + # Restore user-defined __init__ if it was saved + user_init = cls.__dict__.get("_py_class_user_init") + if user_init is not None: + cls.__init__ = user_init + delattr(cls, "_py_class_user_init") + + _install_dataclass_dunders( + cls, + init=params["init"], + repr=params["repr"], + eq=params["eq"], + order=params["order"], + unsafe_hash=params["unsafe_hash"], + ) + return True + + +# --------------------------------------------------------------------------- +# Deferred resolution (when phase 2 cannot run at decoration time) +# --------------------------------------------------------------------------- + + +def _flush_pending() -> None: + """Retry all pending classes. Called after each successful phase 2.""" + changed = True + while changed: + changed = False + remaining: list[_PendingClass] = [] + for entry in _PENDING_CLASSES: + if _phase2_register_fields(entry.cls, entry.type_info, entry.globalns, entry.params): + changed = True + else: + remaining.append(entry) + _PENDING_CLASSES[:] = remaining + + +def _raise_unresolved_forward_reference(cls: type, globalns: dict[str, Any]) -> None: + """Raise :class:`TypeError` listing the annotations that cannot be resolved.""" + localns = dict(_PY_CLASS_BY_MODULE.get(cls.__module__, {})) + localns[cls.__name__] = cls + unresolved: list[str] = [] + for name, ann_str in getattr(cls, "__annotations__", {}).items(): + if isinstance(ann_str, str): + try: + eval(ann_str, globalns, localns) + except NameError: + unresolved.append(f"{name}: {ann_str}") + raise TypeError( + f"Cannot instantiate {cls.__name__}: unresolved forward references: {unresolved}" + ) + + +def _make_temporary_init( + cls: type, type_info: Any, globalns: dict[str, Any], params: dict[str, Any] +) -> Callable[[...], None]: + def __init__(self: Any, *args: Any, **kwargs: Any) -> None: + if type_info.fields is None: + try: + if not _phase2_register_fields(cls, type_info, globalns, params): + _raise_unresolved_forward_reference(cls, globalns) + _flush_pending() + except Exception: + # Remove from pending list and roll back so the type key can be reused. + _PENDING_CLASSES[:] = [p for p in _PENDING_CLASSES if p.cls is not cls] + _rollback_registration(cls, type_info) + raise + # cls.__init__ has been replaced by the real init (or restored user init) + cls.__init__(self, *args, **kwargs) + + __init__.__qualname__ = f"{cls.__qualname__}.__init__" + __init__.__module__ = cls.__module__ + return __init__ + + +def _install_deferred_init( + cls: type, + type_info: Any, + globalns: dict[str, Any], + params: dict[str, Any], +) -> None: + """Install a temporary ``__init__`` that completes registration on first call. + + Preserves a user-defined ``__init__`` if present in *cls.__dict__*; + it is restored by :func:`_phase2_register_fields` after registration + completes so that ``_install_dataclass_dunders`` sees it and skips + auto-generation. + """ + # Save user-defined __init__ before overwriting + user_init = cls.__dict__.get("__init__") + if user_init is not None: + cls._py_class_user_init = user_init # type: ignore[attr-defined] + + cls.__init__ = _make_temporary_init(cls, type_info, globalns, params) # type: ignore[assignment] + cls._py_class_deferred_init = True # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# Main decorator +# --------------------------------------------------------------------------- + + +@dataclass_transform( + eq_default=False, + order_default=False, + field_specifiers=(field, Field), +) +def py_class( + cls_or_type_key: type | str | None = None, + /, + *, + type_key: str | None = None, + init: bool = True, + repr: bool = True, + eq: bool = False, + order: bool = False, + unsafe_hash: bool = False, + kw_only: bool = False, + slots: bool = True, +) -> Callable[[_T], _T] | _T: + """Register a Python-defined FFI class with dataclass-style semantics. + + Can be used as: + + .. code-block:: python + + @py_class # bare decorator + class Point(Object): + x: float + y: float + + + @py_class("my.Point") # with explicit type_key + class Point(Object): ... + + + @py_class(eq=True, order=True) # with options + class Point(Object): ... + + + @py_class("my.Point", eq=True) # both + class Point(Object): ... + + Parameters + ---------- + cls_or_type_key + When a string, used as the FFI type key. When a type (bare + decorator usage), the class to decorate. + type_key + Explicit FFI type key. Auto-generated from + ``{module}.{qualname}`` when omitted. + init + If True (default), generate ``__init__`` from field annotations. + repr + If True (default), generate ``__repr__``. + eq + If True, generate ``__eq__`` and ``__ne__``. + order + If True, generate ``__lt__``, ``__le__``, ``__gt__``, ``__ge__``. + Requires ``eq=True``. + unsafe_hash + If True, generate ``__hash__`` (unsafe for mutable objects). + kw_only + If True, all fields are keyword-only in ``__init__`` by default. + slots + Accepted for ``dataclass_transform`` compatibility. Object + subclasses always use ``__slots__ = ()`` via the metaclass. + + Returns + ------- + Callable | type + A class decorator, or the decorated class (bare usage). + + """ + if order and not eq: + raise ValueError("order=True requires eq=True") + + effective_type_key = type_key + params: dict[str, Any] = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + "kw_only": kw_only, + } + + def decorator(cls: _T) -> _T: + nonlocal effective_type_key + globalns = getattr(sys.modules.get(cls.__module__, None), "__dict__", {}) + + info = _phase1_register_type(cls, effective_type_key) + + try: + if _phase2_register_fields(cls, info, globalns, params): + _flush_pending() + else: + _PENDING_CLASSES.append(_PendingClass(cls, info, globalns, params)) + _install_deferred_init(cls, info, globalns, params) + except Exception: + # Phase-2 failed (bad annotation, field ordering, etc.). + # Roll back phase-1 so the type key can be reused after + # the user fixes the error. + _rollback_registration(cls, info) + raise + + return cls + + # Handle different calling conventions: + # @py_class → cls_or_type_key is the class + # @py_class("key") → cls_or_type_key is a string + # @py_class() → cls_or_type_key is None + # @py_class(eq=True) → cls_or_type_key is None + if cls_or_type_key is None: + return decorator + if isinstance(cls_or_type_key, str): + effective_type_key = cls_or_type_key + return decorator + if isinstance(cls_or_type_key, type): + return decorator(cls_or_type_key) + raise TypeError(f"py_class: expected str or type, got {type(cls_or_type_key)}") diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py index 82540494..8a84bff1 100644 --- a/python/tvm_ffi/registry.py +++ b/python/tvm_ffi/registry.py @@ -356,16 +356,30 @@ def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]: """ sig = _make_init_signature(type_info) kwargs_obj = core.KWARGS + has_post_init = hasattr(type_cls, "__post_init__") - def __init__(self: Any, *args: Any, **kwargs: Any) -> None: - ffi_args: list[Any] = list(args) - ffi_args.append(kwargs_obj) - for key, val in kwargs.items(): - ffi_args.append(key) - ffi_args.append(val) - self.__ffi_init__(*ffi_args) - - __init__.__signature__ = sig # ty: ignore[unresolved-attribute] + if has_post_init: + + def __init__(self: Any, *args: Any, **kwargs: Any) -> None: + ffi_args: list[Any] = list(args) + ffi_args.append(kwargs_obj) + for key, val in kwargs.items(): + ffi_args.append(key) + ffi_args.append(val) + self.__ffi_init__(*ffi_args) + self.__post_init__() + + else: + + def __init__(self: Any, *args: Any, **kwargs: Any) -> None: + ffi_args: list[Any] = list(args) + ffi_args.append(kwargs_obj) + for key, val in kwargs.items(): + ffi_args.append(key) + ffi_args.append(val) + self.__ffi_init__(*ffi_args) + + __init__.__signature__ = sig # ty: ignore[invalid-assignment] __init__.__qualname__ = f"{type_cls.__qualname__}.__init__" __init__.__module__ = type_cls.__module__ return __init__ @@ -666,7 +680,19 @@ def __ge__(self: Any, other: Any) -> bool: dunders["__gt__"] = __gt__ dunders["__ge__"] = __ge__ + # Install dunders respecting user-defined overrides. + # Semantic families (__eq__/__ne__, __lt__/__le__/__gt__/__ge__) are + # treated as a unit: if the user defines any member, the whole family + # is skipped so generated and user-defined methods don't disagree. + _eq_family = {"__eq__", "__ne__"} + _order_family = {"__lt__", "__le__", "__gt__", "__ge__"} + skip_eq = bool(_eq_family & set(cls.__dict__)) + skip_order = bool(_order_family & set(cls.__dict__)) for name, impl in dunders.items(): + if name in _eq_family and skip_eq: + continue + if name in _order_family and skip_order: + continue if name not in cls.__dict__: setattr(cls, name, impl) diff --git a/rust/tvm-ffi-sys/src/c_api.rs b/rust/tvm-ffi-sys/src/c_api.rs index 2035eda9..56209bfa 100644 --- a/rust/tvm-ffi-sys/src/c_api.rs +++ b/rust/tvm-ffi-sys/src/c_api.rs @@ -283,9 +283,15 @@ pub struct TVMFFIFieldInfo { pub offset: i64, /// The getter to access the field pub getter: Option, - /// The setter to access the field - /// The setter is set even if the field is readonly for serialization - pub setter: Option, + /// The setter to access the field. + /// + /// When kTVMFFIFieldFlagBitSetterIsFunctionObj is NOT set (default), + /// this is a TVMFFIFieldSetter function pointer cast to *mut c_void. + /// When kTVMFFIFieldFlagBitSetterIsFunctionObj IS set, + /// this is a TVMFFIObjectHandle pointing to a FunctionObj. + /// + /// The setter is set even if the field is readonly for serialization. + pub setter: *mut c_void, /// The default value or factory of the field, this field holds AnyView. /// Valid when flags set kTVMFFIFieldFlagBitMaskHasDefault. /// When kTVMFFIFieldFlagBitMaskDefaultFromFactory is also set, diff --git a/src/ffi/extra/dataclass.cc b/src/ffi/extra/dataclass.cc index 9ea64cf2..8b5eebba 100644 --- a/src/ffi/extra/dataclass.cc +++ b/src/ffi/extra/dataclass.cc @@ -1662,6 +1662,254 @@ class RecursiveComparer : public ObjectGraphDFS(self_void); + if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index); + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* finfo) { + void* field_addr = reinterpret_cast(self) + finfo->offset; + int32_t ti = finfo->field_static_type_index; + if (ti == TypeIndex::kTVMFFIAny) { + // Any field: call destructor to release owned references + reinterpret_cast(field_addr)->~Any(); + } else if (ti >= TypeIndex::kTVMFFIStaticObjectBegin) { + // ObjectRef field: call destructor to DecRef + reinterpret_cast(field_addr)->~ObjectRef(); + } + // POD types (int, float, bool, etc.): no cleanup needed + }); + } + if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { + std::free(self_void); + } +} + +/*! + * \brief Generic field getter for Python-defined types. + * + * Reads a value of type T from the given field address and packs it into + * a TVMFFIAny result. + * + * \tparam T The C++ type stored at the field address. + */ +template +int PyClassFieldGetter(void* field, TVMFFIAny* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); + TVM_FFI_SAFE_CALL_END(); +} + +/*! + * \brief Return the TVMFFIFieldGetter function pointer for a given field type index. + * + * \param field_type_index The static type index of the field. + * \return The function pointer as int64_t for FFI transport. + */ +int64_t GetFieldGetter(int32_t field_type_index) { + TVMFFIFieldGetter getter = nullptr; + switch (field_type_index) { + case TypeIndex::kTVMFFIInt: + getter = &PyClassFieldGetter; + break; + case TypeIndex::kTVMFFIFloat: + getter = &PyClassFieldGetter; + break; + case TypeIndex::kTVMFFIBool: + getter = &PyClassFieldGetter; + break; + case TypeIndex::kTVMFFIOpaquePtr: + getter = &PyClassFieldGetter; + break; + case TypeIndex::kTVMFFIDataType: + getter = &PyClassFieldGetter; + break; + case TypeIndex::kTVMFFIDevice: + getter = &PyClassFieldGetter; + break; + default: + if (field_type_index == TypeIndex::kTVMFFIAny || field_type_index == TypeIndex::kTVMFFINone) { + getter = &PyClassFieldGetter; + } else if (field_type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + getter = &PyClassFieldGetter; + } else { + TVM_FFI_THROW(ValueError) << "Unsupported field type index for getter: " + << field_type_index; + } + break; + } + return reinterpret_cast(getter); +} + +/*! + * \brief Write a converted value to a field of the appropriate C++ type. + * + * Dispatches on field_type_index to reinterpret the destination address and + * assign from the converted Any value. + */ +void WriteFieldValue(void* field_addr, int32_t field_type_index, Any value) { + switch (field_type_index) { + case TypeIndex::kTVMFFIInt: + *reinterpret_cast(field_addr) = value.cast(); + return; + case TypeIndex::kTVMFFIFloat: + *reinterpret_cast(field_addr) = value.cast(); + return; + case TypeIndex::kTVMFFIBool: + *reinterpret_cast(field_addr) = value.cast(); + return; + case TypeIndex::kTVMFFIOpaquePtr: + *reinterpret_cast(field_addr) = value.cast(); + return; + case TypeIndex::kTVMFFIDataType: + *reinterpret_cast(field_addr) = value.cast(); + return; + case TypeIndex::kTVMFFIDevice: + *reinterpret_cast(field_addr) = value.cast(); + return; + default: + break; + } + if (field_type_index == TypeIndex::kTVMFFIAny || field_type_index == TypeIndex::kTVMFFINone) { + *reinterpret_cast(field_addr) = std::move(value); + } else if (field_type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + *reinterpret_cast(field_addr) = value.cast(); + } else { + TVM_FFI_THROW(ValueError) << "Unsupported field type index for setter: " << field_type_index; + } +} + +/*! + * \brief Create a FunctionObj setter for a Python-defined field. + * + * The returned Function accepts (OpaquePtr field_addr, AnyView value), + * calls f_convert to coerce the value via the type_converter, and writes + * the result to the field. + * + * \param field_type_index The static type index of the field. + * \param type_converter_int Opaque pointer (as int64_t) to the Python _TypeConverter (borrowed). + * \param f_convert_int C function pointer (as int64_t): int(void*, const TVMFFIAny*, TVMFFIAny*). + * Returns 0 on success, -1 on error (error stored in TLS). + * \return A packed Function suitable for use as a FunctionObj setter. + */ +Function MakeFieldSetter(int32_t field_type_index, int64_t type_converter_int, + int64_t f_convert_int) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + void* type_converter = reinterpret_cast(type_converter_int); + using FConvert = int (*)(void*, const TVMFFIAny*, TVMFFIAny*); + // NOLINTNEXTLINE(performance-no-int-to-ptr) + auto f_convert = reinterpret_cast(f_convert_int); + + return Function::FromPacked([field_type_index, type_converter, f_convert]( + const AnyView* args, int32_t num_args, Any* rv) { + void* field_addr = args[0].cast(); + // Call the Cython-level type converter via C function pointer. + TVMFFIAny converted; + converted.type_index = TypeIndex::kTVMFFINone; + converted.v_int64 = 0; + int err = f_convert(type_converter, reinterpret_cast(&args[1]), &converted); + if (err != 0) { + throw details::MoveFromSafeCallRaised(); + } + // Take ownership of the converted value and write to the field. + Any owned = details::AnyUnsafe::MoveTVMFFIAnyToAny(&converted); + WriteFieldValue(field_addr, field_type_index, std::move(owned)); + }); +} + +/*! + * \brief Register a ``__ffi_new__`` type attribute for a Python-defined type. + * + * Creates a factory Function that allocates zero-initialized memory of the + * given size, sets up the TVMFFIObject header (type_index, ref counts, + * deleter), and returns an ObjectRef. Also registers this factory as the + * ``__ffi_new__`` type attribute so that ``CreateEmptyObject`` can find it. + * + * \param type_index The type index of the Python-defined type. + * \param total_size The total object size in bytes (header + fields). + */ +void MakeFFINew(int32_t type_index, int32_t total_size) { + // Pre-compute type_info pointer (stable for the process lifetime). + // Used by the shallow-copy lambda below; new_fn doesn't need it since + // calloc zero-initialization suffices (no placement construction needed). + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); + Function new_fn = Function::FromTyped([type_index, total_size]() -> ObjectRef { + void* obj_ptr = std::calloc(1, static_cast(total_size)); + if (!obj_ptr) { + TVM_FFI_THROW(RuntimeError) << "Failed to allocate " << total_size << " bytes for type " + << TypeIndexToTypeKey(type_index); + } + TVMFFIObject* ffi_obj = reinterpret_cast(obj_ptr); + ffi_obj->type_index = type_index; + ffi_obj->combined_ref_count = details::kCombinedRefCountBothOne; + ffi_obj->deleter = PyClassDeleter; + // calloc zero-initializes all bytes. For non-trivial field types: + // - Any: zero state is {type_index=kTVMFFINone, v_int64=0}, representing None. + // - ObjectRef: zero state is a null pointer. + // Both are valid initial states whose destructors and assignment operators + // handle correctly, so no placement construction is needed. + Object* obj = reinterpret_cast(obj_ptr); + return ObjectRef(details::ObjectUnsafe::ObjectPtrFromOwned(obj)); + }); + // Register as __ffi_new__ type attribute + reflection::EnsureTypeAttrColumn("__ffi_new__"); + TVMFFIByteArray attr_name = {"__ffi_new__", 11}; + TVMFFIAny attr_value = AnyView(new_fn).CopyToTVMFFIAny(); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index, &attr_name, &attr_value)); + // Register __ffi_shallow_copy__ for deep-copy support. + // The shallow copy allocates a new object and copies all fields by value + // (IncRef-ing any ObjectRef/Any fields). + Function copy_fn = + Function::FromTyped([type_index, total_size, type_info](const Object* src) -> ObjectRef { + void* obj_ptr = std::calloc(1, static_cast(total_size)); + if (!obj_ptr) { + TVM_FFI_THROW(RuntimeError) << "Failed to allocate for shallow copy"; + } + TVMFFIObject* ffi_obj = reinterpret_cast(obj_ptr); + ffi_obj->type_index = type_index; + ffi_obj->combined_ref_count = details::kCombinedRefCountBothOne; + ffi_obj->deleter = PyClassDeleter; + reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* finfo) { + void* dst = reinterpret_cast(obj_ptr) + finfo->offset; + const void* field_src = reinterpret_cast(src) + finfo->offset; + int32_t ti = finfo->field_static_type_index; + if (ti == TypeIndex::kTVMFFIAny) { + new (dst) Any(*reinterpret_cast(field_src)); + } else if (ti >= TypeIndex::kTVMFFIStaticObjectBegin) { + new (dst) ObjectRef(*reinterpret_cast(field_src)); + } else { + // POD: memcpy + std::memcpy(dst, field_src, static_cast(finfo->size)); + } + }); + Object* obj = reinterpret_cast(obj_ptr); + return ObjectRef(details::ObjectUnsafe::ObjectPtrFromOwned(obj)); + }); + // Register as type attribute for generic deep copy lookup + reflection::EnsureTypeAttrColumn(reflection::type_attr::kShallowCopy); + TVMFFIByteArray copy_attr_name = { + reflection::type_attr::kShallowCopy, + std::char_traits::length(reflection::type_attr::kShallowCopy)}; + TVMFFIAny copy_attr_value = AnyView(copy_fn).CopyToTVMFFIAny(); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index, ©_attr_name, ©_attr_value)); + // Also register as an instance method so Python can call __ffi_shallow_copy__ + TVMFFIMethodInfo copy_method; + copy_method.name = copy_attr_name; + copy_method.doc = TVMFFIByteArray{nullptr, 0}; + copy_method.flags = 0; + copy_method.method = AnyView(copy_fn).CopyToTVMFFIAny(); + copy_method.metadata = TVMFFIByteArray{nullptr, 0}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index, ©_method)); +} + } // namespace // ============================================================================ @@ -1725,6 +1973,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; // MakeInit refl::GlobalDef().def("ffi.MakeInit", refl::MakeInit); + // Python-defined type support + refl::EnsureTypeAttrColumn("__ffi_new__"); + refl::GlobalDef().def("ffi.GetFieldGetter", GetFieldGetter); + refl::GlobalDef().def("ffi.MakeFieldSetter", MakeFieldSetter); + refl::GlobalDef().def("ffi.MakeFFINew", MakeFFINew); + refl::GlobalDef().def("ffi.RegisterAutoInit", refl::RegisterAutoInit); // Deep copy refl::EnsureTypeAttrColumn(refl::type_attr::kShallowCopy); refl::GlobalDef().def("ffi.DeepCopy", DeepCopy); diff --git a/src/ffi/extra/reflection_extra.cc b/src/ffi/extra/reflection_extra.cc index 5182f1df..44e8ac3c 100644 --- a/src/ffi/extra/reflection_extra.cc +++ b/src/ffi/extra/reflection_extra.cc @@ -42,15 +42,7 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { TVM_FFI_ICHECK(args.size() % 2 == 1); const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - - if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support reflection creation"; - } - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + ObjectPtr ptr = CreateEmptyObject(type_info); std::vector keys; std::vector keys_found; @@ -77,7 +69,8 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; if (arg_index < keys.size()) { AnyView field_value = args[static_cast(arg_index * 2 + 2)]; - field_info->setter(field_addr, reinterpret_cast(&field_value)); + reflection::CallFieldSetter(field_info, field_addr, + reinterpret_cast(&field_value)); keys_found[arg_index] = true; } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { reflection::SetFieldToDefault(field_info, field_addr); diff --git a/src/ffi/extra/serialization.cc b/src/ffi/extra/serialization.cc index 2351f7ca..c1fb6211 100644 --- a/src/ffi/extra/serialization.cc +++ b/src/ffi/extra/serialization.cc @@ -371,15 +371,7 @@ class ObjectGraphDeserializer { } // otherwise, we go over the fields and create the data. const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support default constructor" - << ", so ToJSONGraph is not supported for this type"; - } - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + ObjectPtr ptr = CreateEmptyObject(type_info); auto decode_field_value = [&](const TVMFFIFieldInfo* field_info, const json::Value& data) -> Any { @@ -411,7 +403,8 @@ class ObjectGraphDeserializer { void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; if (data_object.count(field_name) != 0) { Any field_value = decode_field_value(field_info, data_object[field_name]); - field_info->setter(field_addr, reinterpret_cast(&field_value)); + reflection::CallFieldSetter(field_info, field_addr, + reinterpret_cast(&field_value)); } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { reflection::SetFieldToDefault(field_info, field_addr); } else { diff --git a/src/ffi/function.cc b/src/ffi/function.cc index 4f83d28a..4b378f74 100644 --- a/src/ffi/function.cc +++ b/src/ffi/function.cc @@ -229,11 +229,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def("ffi.String", [](tvm::ffi::String val) -> tvm::ffi::String { return val; }) .def("ffi.Bytes", [](tvm::ffi::Bytes val) -> tvm::ffi::Bytes { return val; }) - .def("ffi.GetGlobalFuncMetadata", [](const tvm::ffi::String& name) -> tvm::ffi::String { - const auto* f = tvm::ffi::GlobalFunctionTable::Global()->Get(name); - if (f == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Global Function is not found: " << name; - } - return f->metadata_data; - }); + .def("ffi.GetGlobalFuncMetadata", + [](const tvm::ffi::String& name) -> tvm::ffi::String { + const auto* f = tvm::ffi::GlobalFunctionTable::Global()->Get(name); + if (f == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Global Function is not found: " << name; + } + return f->metadata_data; + }) + .def("ffi.FunctionFromExternC", + [](void* self, void* safe_call, void* deleter) -> tvm::ffi::Function { + return tvm::ffi::Function::FromExternC(self, + reinterpret_cast(safe_call), + reinterpret_cast(deleter)); + }); } diff --git a/src/ffi/object.cc b/src/ffi/object.cc index 94ad9815..7606ba04 100644 --- a/src/ffi/object.cc +++ b/src/ffi/object.cc @@ -23,7 +23,10 @@ #include #include #include +#include +#include #include +#include #include #include #include @@ -79,6 +82,16 @@ class TypeTable { /*! \brief Whether child can overflow. */ bool child_slots_can_overflow{true}; + ~Entry() { + // Release FunctionObj setter handles that were IncRef'd during RegisterTypeField. + for (TVMFFIFieldInfo& field : type_fields_data) { + if ((field.flags & kTVMFFIFieldFlagBitSetterIsFunctionObj) && field.setter != nullptr) { + TVMFFIObjectDecRef(static_cast(field.setter)); + field.setter = nullptr; + } + } + } + Entry(int32_t type_index, int32_t type_depth, String type_key, int32_t num_slots, bool child_slots_can_overflow, const Entry* parent) { // setup fields in the class @@ -216,6 +229,12 @@ class TypeTable { void RegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) { Entry* entry = GetTypeEntry(type_index); TVMFFIFieldInfo field_data = *info; + // IncRef FunctionObj setter so it stays alive in the type table + if (field_data.flags & kTVMFFIFieldFlagBitSetterIsFunctionObj) { + if (field_data.setter != nullptr) { + TVMFFIObjectIncRef(static_cast(field_data.setter)); + } + } field_data.name = this->CopyString(info->name); field_data.doc = this->CopyString(info->doc); field_data.metadata = this->CopyString(info->metadata); @@ -359,6 +378,7 @@ class TypeTable { -1); TVMFFITypeMetadata info; info.total_size = sizeof(Object); + info.structural_eq_hash_kind = kTVMFFISEqHashKindUnsupported; info.creator = nullptr; info.doc = TVMFFIByteArray{nullptr, 0}; RegisterTypeMetadata(Object::_type_index, &info); @@ -587,6 +607,20 @@ namespace { TVM_FFI_STATIC_INIT_BLOCK() { using namespace tvm::ffi; namespace refl = tvm::ffi::reflection; + refl::RegisterConvertTypeAttr(TypeIndex::kTVMFFIObject, StaticTypeKey::kTVMFFIObject); + refl::RegisterConvertTypeAttr(TypeIndex::kTVMFFIStr, StaticTypeKey::kTVMFFIStr); + refl::RegisterConvertTypeAttr(TypeIndex::kTVMFFIBytes, StaticTypeKey::kTVMFFIBytes); + refl::RegisterConvertTypeAttr(TypeIndex::kTVMFFIError, StaticTypeKey::kTVMFFIError); + refl::RegisterConvertTypeAttr(TypeIndex::kTVMFFIFunction, + StaticTypeKey::kTVMFFIFunction); + refl::RegisterConvertTypeAttr(TypeIndex::kTVMFFIShape, StaticTypeKey::kTVMFFIShape); + refl::RegisterConvertTypeAttr(TypeIndex::kTVMFFITensor, StaticTypeKey::kTVMFFITensor); + refl::RegisterConvertTypeAttr>(TypeIndex::kTVMFFIArray, StaticTypeKey::kTVMFFIArray); + refl::RegisterConvertTypeAttr>(TypeIndex::kTVMFFIMap, StaticTypeKey::kTVMFFIMap); + // Skipped: TypeIndex::kTVMFFIModule + // Skipped: TypeIndex::kTVMFFIOpaquePyObject + refl::RegisterConvertTypeAttr>(TypeIndex::kTVMFFIList, StaticTypeKey::kTVMFFIList); + refl::RegisterConvertTypeAttr>(TypeIndex::kTVMFFIDict, StaticTypeKey::kTVMFFIDict); refl::GlobalDef() .def_method( "ffi.GetRegisteredTypeKeys", diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc index 111e6447..682d0365 100644 --- a/src/ffi/testing/testing.cc +++ b/src/ffi/testing/testing.cc @@ -72,6 +72,7 @@ class TestIntPair : public tvm::ffi::ObjectRef { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() + .ref() .def(refl::init()) .def_ro("a", &TestIntPairObj::a, "Field `a`") .def_ro("b", &TestIntPairObj::b, "Field `b`") diff --git a/tests/python/test_dataclass_py_class.py b/tests/python/test_dataclass_py_class.py new file mode 100644 index 00000000..d08c374f --- /dev/null +++ b/tests/python/test_dataclass_py_class.py @@ -0,0 +1,3533 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for Python-defined TVM-FFI types: ``@py_class`` decorator and low-level Field API.""" + +# ruff: noqa: D102, PLR0124, PLW1641 +from __future__ import annotations + +import copy +import gc +import inspect +import itertools +import math +import sys +from typing import ClassVar + +import pytest +import tvm_ffi +from tvm_ffi import core +from tvm_ffi._ffi_api import DeepCopy, RecursiveEq, RecursiveHash, ReprPrint +from tvm_ffi.core import MISSING, Object, TypeInfo, TypeSchema +from tvm_ffi.dataclasses import KW_ONLY, Field, field, py_class +from tvm_ffi.registry import _add_class_attrs, _install_dataclass_dunders +from tvm_ffi.testing import TestObjectBase as _TestObjectBase + +_needs_310 = pytest.mark.skipif(sys.version_info < (3, 10), reason="X | Y syntax requires 3.10+") + +# --------------------------------------------------------------------------- +# Unique type key generator (avoids collisions across tests) +# --------------------------------------------------------------------------- +_counter = itertools.count() + + +def _unique_key(base: str) -> str: + return f"testing.py_class_dec.{base}_{next(_counter)}" + + +def _get_type_info(cls: type) -> TypeInfo: + ret = cls.__tvm_ffi_type_info__ # ty: ignore[unresolved-attribute] + assert isinstance(ret, TypeInfo), f"Expected TypeInfo, got {type(ret)}" + return ret + + +# --------------------------------------------------------------------------- +# Low-level helpers for _make_type-based tests +# --------------------------------------------------------------------------- +_counter_ff = itertools.count() + + +def _unique_key_ff(base: str) -> str: + """Return a globally unique type key for low-level field tests.""" + return f"testing.py_class.{base}_{next(_counter_ff)}" + + +def _make_type( + name: str, + fields: list[Field], + *, + parent: type = core.Object, + eq: bool = False, + unsafe_hash: bool = False, + repr: bool = True, +) -> type: + """Create, register, and fully set up a Python-defined TVM-FFI type. + + Returns the ready-to-use Python class. + """ + type_key = _unique_key_ff(name) + parent_info = core._type_cls_to_type_info(parent) + assert parent_info is not None + cls = type(name, (parent,), {"__slots__": ()}) + info = core._register_py_class(parent_info, type_key, cls) + info._register_fields(fields) + setattr(cls, "__tvm_ffi_type_info__", info) + _add_class_attrs(cls, info) + _install_dataclass_dunders( + cls, + init=True, + repr=repr, + eq=eq, + order=False, + unsafe_hash=unsafe_hash, + ) + return cls + + +# ########################################################################### +# 1. Basic registration +# ########################################################################### +class TestBasicRegistration: + """@py_class decorator with different calling conventions.""" + + def test_bare_decorator(self) -> None: + @py_class(_unique_key("Bare")) + class Bare(Object): + x: int + + info = _get_type_info(Bare) + assert info is not None + assert len(info.fields) == 1 + assert info.fields[0].name == "x" + + def test_decorator_with_options(self) -> None: + @py_class(_unique_key("Opts"), eq=True) + class Opts(Object): + x: int + + assert hasattr(Opts, "__eq__") + assert Opts(x=1) == Opts(x=1) + + def test_auto_type_key(self) -> None: + @py_class(_unique_key("AutoKey")) + class AutoKey(Object): + x: int + + info = _get_type_info(AutoKey) + assert info.type_key.startswith("testing.") + + def test_explicit_type_key(self) -> None: + key = _unique_key("ExplicitKey") + + @py_class(key) + class ExplicitKey(Object): + x: int + + assert _get_type_info(ExplicitKey).type_key == key + + def test_empty_class(self) -> None: + @py_class(_unique_key("Empty")) + class Empty(Object): + pass + + obj = Empty() + assert obj is not None + + def test_isinstance_check(self) -> None: + @py_class(_unique_key("InstCheck")) + class InstCheck(Object): + x: int + + obj = InstCheck(x=42) + assert isinstance(obj, InstCheck) + assert isinstance(obj, Object) + + +# ########################################################################### +# 2. Field parsing +# ########################################################################### +class TestFieldParsing: + """Annotation-to-Field conversion.""" + + def test_int_field(self) -> None: + @py_class(_unique_key("IntFld")) + class IntFld(Object): + x: int + + obj = IntFld(x=42) + assert obj.x == 42 + + def test_float_field(self) -> None: + @py_class(_unique_key("FltFld")) + class FltFld(Object): + x: float + + obj = FltFld(x=3.14) + assert abs(obj.x - 3.14) < 1e-10 + + def test_str_field(self) -> None: + @py_class(_unique_key("StrFld")) + class StrFld(Object): + x: str + + obj = StrFld(x="hello") + assert obj.x == "hello" + + def test_bool_field(self) -> None: + @py_class(_unique_key("BoolFld")) + class BoolFld(Object): + x: bool + + obj = BoolFld(x=True) + assert obj.x is True + + @_needs_310 + def test_optional_field(self) -> None: + @py_class(_unique_key("OptFld")) + class OptFld(Object): + x: int | None + + obj = OptFld(x=42) + assert obj.x == 42 + obj2 = OptFld(x=None) + assert obj2.x is None + + def test_multiple_fields(self) -> None: + @py_class(_unique_key("Multi")) + class Multi(Object): + a: int + b: float + c: str + + obj = Multi(a=1, b=2.0, c="three") + assert obj.a == 1 + assert obj.b == 2.0 + assert obj.c == "three" + + +# ########################################################################### +# 3. Defaults +# ########################################################################### +class TestDefaults: + """Default values and default_factory.""" + + def test_bare_default(self) -> None: + @py_class(_unique_key("BareDef")) + class BareDef(Object): + x: int + y: int = 10 + + obj = BareDef(x=1) + assert obj.y == 10 + + def test_field_default(self) -> None: + @py_class(_unique_key("FldDef")) + class FldDef(Object): + x: int = field(default=42) + + obj = FldDef() + assert obj.x == 42 + + def test_field_default_factory(self) -> None: + call_count = 0 + + def make_default() -> int: + nonlocal call_count + call_count += 1 + return 99 + + @py_class(_unique_key("FldFact")) + class FldFact(Object): + x: int = field(default_factory=make_default) + + obj1 = FldFact() + assert obj1.x == 99 + obj2 = FldFact() + assert obj2.x == 99 + assert call_count == 2 + + def test_default_and_factory_mutually_exclusive(self) -> None: + with pytest.raises(ValueError, match="cannot specify both"): + field(default=1, default_factory=int) + + def test_non_callable_factory_rejected(self) -> None: + with pytest.raises(TypeError, match="default_factory must be a callable"): + field(default_factory=42) # ty: ignore[invalid-argument-type] + + def test_required_before_optional(self) -> None: + @py_class(_unique_key("ReqOpt")) + class ReqOpt(Object): + a: int + b: int = 10 + + obj = ReqOpt(1) + assert obj.a == 1 + assert obj.b == 10 + + +# ########################################################################### +# 4. KW_ONLY +# ########################################################################### +class TestKwOnly: + """Keyword-only field support.""" + + def test_kw_only_sentinel(self) -> None: + @py_class(_unique_key("KWSent")) + class KWSent(Object): + a: int + _: KW_ONLY + b: int = 10 + + obj = KWSent(1, b=20) # ty: ignore[missing-argument] + assert obj.a == 1 + assert obj.b == 20 + with pytest.raises(TypeError): + KWSent(1, 2) # ty: ignore[invalid-argument-type] + + def test_decorator_level_kw_only(self) -> None: + @py_class(_unique_key("DecKW"), kw_only=True) + class DecKW(Object): + a: int + b: int = 10 + + obj = DecKW(a=1) + assert obj.a == 1 + assert obj.b == 10 + with pytest.raises(TypeError): + DecKW(1) # ty: ignore[missing-argument,too-many-positional-arguments] + + def test_field_level_kw_only_override(self) -> None: + @py_class(_unique_key("FldKW")) + class FldKW(Object): + a: int + b: int = field(default=10, kw_only=True) + + obj = FldKW(1) + assert obj.a == 1 + assert obj.b == 10 + with pytest.raises(TypeError): + FldKW(1, 2) # b is keyword-only + + +# ########################################################################### +# 5. ClassVar +# ########################################################################### +class TestClassVar: + """ClassVar annotations are skipped.""" + + def test_classvar_skipped(self) -> None: + @py_class(_unique_key("CV")) + class CV(Object): + x: int + count: ClassVar[int] = 0 + + info = _get_type_info(CV) + field_names = [f.name for f in info.fields] + assert "x" in field_names + assert "count" not in field_names + + def test_classvar_preserved_on_class(self) -> None: + @py_class(_unique_key("CVPres")) + class CVPres(Object): + x: int + tag: ClassVar[str] = "hello" + + assert CVPres.tag == "hello" + + +# ########################################################################### +# 6. Init generation +# ########################################################################### +class TestInit: + """Auto-generated __init__.""" + + def test_positional_args(self) -> None: + @py_class(_unique_key("Pos")) + class Pos(Object): + a: int + b: str + + obj = Pos(1, "hello") + assert obj.a == 1 + assert obj.b == "hello" + + def test_keyword_args(self) -> None: + @py_class(_unique_key("Kw")) + class Kw(Object): + a: int + b: str + + obj = Kw(a=1, b="hello") + assert obj.a == 1 + assert obj.b == "hello" + + def test_init_false_field(self) -> None: + @py_class(_unique_key("NoInit")) + class NoInit(Object): + a: int + b: int = field(default=99, init=False) + + obj = NoInit(a=1) + assert obj.a == 1 + assert obj.b == 99 + + def test_user_defined_init_preserved(self) -> None: + @py_class(_unique_key("UserInit"), init=False) + class UserInit(Object): + a: int + + def __init__(self, val: int) -> None: + self.__ffi_init__(val) + + obj = UserInit(42) + assert obj.a == 42 + + def test_required_after_optional_reordered(self) -> None: + """Required positional fields are reordered before optional ones in __init__.""" + + @py_class(_unique_key("ReorderOwn")) + class ReorderOwn(Object): + x: int = 0 + y: int # ty: ignore[dataclass-field-order] + + sig = inspect.signature(ReorderOwn.__init__) + param_names = [n for n in sig.parameters if n != "self"] + assert param_names[0] == "y" # required comes first + assert param_names[1] == "x" # optional comes second + + obj = ReorderOwn(y=1) # ty: ignore[missing-argument] + assert obj.x == 0 + assert obj.y == 1 + + def test_required_after_optional_in_parent(self) -> None: + """Child required fields are reordered before parent optional fields.""" + + @py_class(_unique_key("OptParent")) + class OptParent(Object): + x: int + y: int = 0 + + @py_class(_unique_key("ReqChild")) + class ReqChild(OptParent): + z: int + + sig = inspect.signature(ReqChild.__init__) + param_names = [n for n in sig.parameters if n != "self"] + # required (x, z) before optional (y) + assert param_names == ["x", "z", "y"] + + obj = ReqChild(x=1, z=3) + assert obj.x == 1 + assert obj.y == 0 + assert obj.z == 3 + + def test_kw_only_exempt_from_reorder(self) -> None: + """kw_only fields are not reordered with positional fields.""" + + @py_class(_unique_key("KwReorder")) + class KwReorder(Object): + x: int = 0 + _: KW_ONLY # ty: ignore[dataclass-field-order] + y: int # ty: ignore[dataclass-field-order] + + sig = inspect.signature(KwReorder.__init__) + params = sig.parameters + assert params["x"].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + assert params["y"].kind == inspect.Parameter.KEYWORD_ONLY + + obj = KwReorder(y=1) # ty: ignore[missing-argument] + assert obj.x == 0 + assert obj.y == 1 + + def test_mixed_positional_and_kw_only_with_defaults(self) -> None: + """Mixed positional/kw_only fields with defaults produce correct signature.""" + + @py_class(_unique_key("MixedSig")) + class MixedSig(Object): + a: int = 0 + b: int # ty: ignore[dataclass-field-order] + _: KW_ONLY # ty: ignore[dataclass-field-order] + c: int = 10 + d: int # ty: ignore[dataclass-field-order] + + sig = inspect.signature(MixedSig.__init__) + param_names = [n for n in sig.parameters if n != "self"] + # positional: b (required) before a (optional); kw_only: d (required) before c (optional) + assert param_names == ["b", "a", "d", "c"] + + obj = MixedSig(b=2, d=4) # ty: ignore[missing-argument] + assert obj.a == 0 + assert obj.b == 2 + assert obj.c == 10 + assert obj.d == 4 + + def test_init_false_excluded_from_signature(self) -> None: + """init=False fields do not appear in __init__ signature.""" + + @py_class(_unique_key("InitFalseSig")) + class InitFalseSig(Object): + a: int + b: int = field(default=99, init=False) + c: str + + sig = inspect.signature(InitFalseSig.__init__) + param_names = [n for n in sig.parameters if n != "self"] + assert "b" not in param_names + assert "a" in param_names + assert "c" in param_names + + +# ########################################################################### +# 7. __post_init__ +# ########################################################################### +class TestPostInit: + """__post_init__ support.""" + + def test_post_init_called(self) -> None: + post_init_called = False + + @py_class(_unique_key("PostInit")) + class PostInit(Object): + x: int + + def __post_init__(self) -> None: + nonlocal post_init_called + post_init_called = True + + PostInit(x=1) + assert post_init_called + + def test_post_init_sees_field_values(self) -> None: + @py_class(_unique_key("PostInitVal")) + class PostInitVal(Object): + x: int + y: int = 10 + + def __post_init__(self) -> None: + # Fields should be set before __post_init__ is called + assert self.x is not None + assert self.y == 10 + + PostInitVal(x=5) + + +# ########################################################################### +# 8. Repr +# ########################################################################### +class TestRepr: + """__repr__ generation.""" + + def test_repr_generated(self) -> None: + @py_class(_unique_key("Repr")) + class Repr(Object): + x: int + y: str + + obj = Repr(x=1, y="hello") + r = repr(obj) + assert "1" in r + assert "hello" in r + + def test_repr_disabled(self) -> None: + @py_class(_unique_key("NoRepr"), repr=False) + class NoRepr(Object): + x: int + + obj = NoRepr(x=1) + # Should use default object repr + r = repr(obj) + assert "NoRepr" in r or "object at" in r + + +# ########################################################################### +# 9. Equality +# ########################################################################### +class TestEquality: + """__eq__ and __ne__ generation.""" + + def test_eq_enabled(self) -> None: + @py_class(_unique_key("Eq"), eq=True) + class Eq(Object): + x: int + y: str + + assert Eq(x=1, y="a") == Eq(x=1, y="a") + assert Eq(x=1, y="a") != Eq(x=2, y="a") + + def test_eq_disabled_by_default(self) -> None: + @py_class(_unique_key("NoEq")) + class NoEq(Object): + x: int + + a = NoEq(x=1) + b = NoEq(x=1) + # Without eq, identity comparison + assert a != b + assert a == a + + +# ########################################################################### +# 10. Order +# ########################################################################### +class TestOrder: + """Comparison methods.""" + + def test_order_enabled(self) -> None: + @py_class(_unique_key("Ord"), eq=True, order=True) + class Ord(Object): + x: int + + assert Ord(x=1) < Ord(x=2) + assert Ord(x=2) > Ord(x=1) + assert Ord(x=1) <= Ord(x=1) + assert Ord(x=1) >= Ord(x=1) + + +# ########################################################################### +# 11. Hash +# ########################################################################### +class TestHash: + """__hash__ generation.""" + + def test_unsafe_hash(self) -> None: + @py_class(_unique_key("Hash"), eq=True, unsafe_hash=True) + class Hash(Object): + x: int + + a = Hash(x=1) + b = Hash(x=1) + assert hash(a) == hash(b) + # Can be used in sets + s = {a, b} + assert len(s) == 1 + + +# ########################################################################### +# 12. Copy +# ########################################################################### +class TestCopy: + """__copy__, __deepcopy__, __replace__.""" + + def test_shallow_copy(self) -> None: + @py_class(_unique_key("SCopy")) + class SCopy(Object): + x: int + + obj = SCopy(x=42) + obj2 = copy.copy(obj) + assert obj2.x == 42 + + def test_deep_copy(self) -> None: + @py_class(_unique_key("DCopy")) + class DCopy(Object): + x: int + + obj = DCopy(x=42) + obj2 = copy.deepcopy(obj) + assert obj2.x == 42 + + def test_replace(self) -> None: + @py_class(_unique_key("Repl")) + class Repl(Object): + x: int + y: str + + obj = Repl(x=1, y="a") + obj2 = obj.__replace__(x=2) # ty: ignore[unresolved-attribute] + assert obj2.x == 2 + assert obj2.y == "a" + + +# ########################################################################### +# 13. Inheritance +# ########################################################################### +class TestInheritance: + """Inheritance between py_class types.""" + + def test_child_adds_fields(self) -> None: + @py_class(_unique_key("Parent")) + class Parent(Object): + x: int + + @py_class(_unique_key("Child")) + class Child(Parent): + y: str + + obj = Child(x=1, y="hello") + assert obj.x == 1 + assert obj.y == "hello" + + def test_child_isinstance(self) -> None: + @py_class(_unique_key("P2")) + class P2(Object): + x: int + + @py_class(_unique_key("C2")) + class C2(P2): + y: str + + obj = C2(x=1, y="hello") + assert isinstance(obj, C2) + assert isinstance(obj, P2) + assert isinstance(obj, Object) + + def test_three_level_inheritance(self) -> None: + @py_class(_unique_key("L1")) + class L1(Object): + a: int + + @py_class(_unique_key("L2")) + class L2(L1): + b: int + + @py_class(_unique_key("L3")) + class L3(L2): + c: int + + obj = L3(a=1, b=2, c=3) + assert obj.a == 1 + assert obj.b == 2 + assert obj.c == 3 + + +# ########################################################################### +# 14. Forward references / deferred resolution +# ########################################################################### +class TestForwardReferences: + """Deferred annotation resolution for mutual and self-references.""" + + @_needs_310 + def test_self_reference(self) -> None: + @py_class(_unique_key("SelfRef")) + class SelfRef(Object): + value: int + next_node: SelfRef | None + + leaf = SelfRef(value=2, next_node=None) + head = SelfRef(value=1, next_node=leaf) + assert head.next_node is not None + assert head.next_node.value == 2 + + @_needs_310 + def test_mutual_reference(self) -> None: + """Two classes that reference each other.""" + + @py_class(_unique_key("Foo")) + class Foo(Object): + value: int + bar: Bar | None + + @py_class(_unique_key("Bar")) + class Bar(Object): + value: int + foo: Foo | None + + bar = Bar(value=2, foo=None) + foo = Foo(value=1, bar=bar) + assert foo.bar is not None + assert foo.bar.value == 2 + + @_needs_310 + def test_deferred_resolution_on_instantiation(self) -> None: + """Forward ref resolved on first instantiation.""" + + @py_class(_unique_key("Early")) + class Early(Object): + value: int + ref: Late | None + + # At this point, Early's fields are deferred because Late doesn't exist + + @py_class(_unique_key("Late")) + class Late(Object): + value: int + + # Now Early should resolve (either via flush or on instantiation) + obj = Early(value=1, ref=Late(value=2)) + assert obj.ref is not None + assert obj.ref.value == 2 + + +# ########################################################################### +# 15. User-defined dunder preservation +# ########################################################################### +class TestDunderPreservation: + """User-defined dunders are not overwritten.""" + + def test_user_repr_preserved(self) -> None: + @py_class(_unique_key("UserRepr")) + class UserRepr(Object): + x: int + + def __repr__(self) -> str: + return f"Custom({self.x})" + + obj = UserRepr(x=42) + assert repr(obj) == "Custom(42)" + + def test_user_eq_preserved(self) -> None: + @py_class(_unique_key("UserEq"), eq=True) + class UserEq(Object): + x: int + + def __eq__(self, other: object) -> bool: + return False + + assert not (UserEq(x=1) == UserEq(x=1)) + + +# ########################################################################### +# 16. field() API +# ########################################################################### +class TestFieldAPI: + """field() function returns a Field.""" + + def test_field_returns_field(self) -> None: + f = field(default=42) + assert isinstance(f, Field) + assert f.default == 42 + + def test_field_defaults(self) -> None: + f = field() + assert f.init is True + assert f.repr is True + assert f.hash is None # None = follow compare + assert f.compare is True + + def test_field_kw_only_missing_by_default(self) -> None: + f = field() + assert f.kw_only is None + + def test_field_repr_false(self) -> None: + @py_class(_unique_key("FldRepr")) + class FldRepr(Object): + x: int + y: int = field(default=0, repr=False) + + obj = FldRepr(x=1) + r = repr(obj) + assert "1" in r + # y with repr=False should not appear in repr + # (depends on C++ ReprPrint implementation respecting the flag) + + +# ########################################################################### +# 17. Edge cases +# ########################################################################### +class TestEdgeCases: + """Edge cases and error conditions.""" + + def test_no_ffi_parent_raises(self) -> None: + with pytest.raises(TypeError, match="must inherit from"): + + @py_class(_unique_key("NoPar")) + class NoPar: # no Object parent! + x: int + + def test_only_classvar(self) -> None: + @py_class(_unique_key("OnlyCV")) + class OnlyCV(Object): + count: ClassVar[int] = 0 + + obj = OnlyCV() + assert obj is not None + + def test_mutation_after_creation(self) -> None: + @py_class(_unique_key("Mut")) + class Mut(Object): + x: int + + obj = Mut(x=1) + obj.x = 42 + assert obj.x == 42 + + +# ########################################################################### +# 18. hash=None tri-state +# ########################################################################### +class TestHashTriState: + """field(hash=None) means 'follow compare' (native dataclass semantics).""" + + def test_hash_none_follows_compare_true(self) -> None: + """hash=None + compare=True → field participates in hash.""" + + @py_class(_unique_key("HNT"), eq=True, unsafe_hash=True) + class HNT(Object): + x: int # default: compare=True, hash=None → hash=True + + a = HNT(x=1) + b = HNT(x=1) + assert hash(a) == hash(b) + + def test_hash_none_follows_compare_false(self) -> None: + """hash=None + compare=False → field excluded from hash.""" + + @py_class(_unique_key("HNF"), eq=True, unsafe_hash=True) + class HNF(Object): + x: int + y: int = field(compare=False) # hash=None → follows compare=False + + # y doesn't participate in hash, so different y values → same hash + a = HNF(x=1, y=10) + b = HNF(x=1, y=20) + assert hash(a) == hash(b) + + def test_hash_explicit_true_with_compare_true(self) -> None: + """hash=True + compare=True → field participates in hash.""" + + @py_class(_unique_key("HET"), eq=True, unsafe_hash=True) + class HET(Object): + x: int = field(hash=True) # compare=True (default) + + a = HET(x=1) + b = HET(x=2) + assert hash(a) != hash(b) + + def test_hash_explicit_false(self) -> None: + """hash=False excludes field from hashing even with compare=True.""" + + @py_class(_unique_key("HEF"), eq=True, unsafe_hash=True) + class HEF(Object): + x: int + y: int = field(hash=False) # compare=True but hash=False + + a = HEF(x=1, y=10) + b = HEF(x=1, y=20) + assert hash(a) == hash(b) + + +# ########################################################################### +# 19. Deferred resolution + user __init__ / init=False +# ########################################################################### +class TestDeferredInitPreservation: + """Deferred resolution preserves user-defined __init__ and init=False.""" + + @_needs_310 + def test_deferred_with_user_init(self) -> None: + """User-defined __init__ is preserved after deferred resolution.""" + + @py_class(_unique_key("DefUI")) + class DefUI(Object): + value: int + ref: DefUILate | None + + def __init__(self, value: int) -> None: + self.__ffi_init__(value, None) + + @py_class(_unique_key("DefUILate")) + class DefUILate(Object): + x: int + + # DefUI should use the user-defined __init__ (one positional arg) + obj = DefUI(42) + assert obj.value == 42 + assert obj.ref is None + + @_needs_310 + def test_deferred_with_init_false(self) -> None: + """init=False is respected after deferred resolution.""" + + @py_class(_unique_key("DefNoInit"), init=False) + class DefNoInit(Object): + value: int + ref: DefNoInitLate | None + + def __init__(self, v: int) -> None: + self.__ffi_init__(v, None) + + @py_class(_unique_key("DefNoInitLate")) + class DefNoInitLate(Object): + x: int + + obj = DefNoInit(10) + assert obj.value == 10 + + +# ########################################################################### +# 21. order=True requires eq=True +# ########################################################################### +class TestOrderEqValidation: + """order=True without eq=True is rejected.""" + + def test_order_without_eq_raises(self) -> None: + with pytest.raises(ValueError, match="order=True requires eq=True"): + + @py_class(_unique_key("OrdNoEq"), order=True) + class OrdNoEq(Object): + x: int + + +# ########################################################################### +# 23. Registration rollback on failure +# ########################################################################### +class TestRegistrationRollback: + """Failed decorations don't permanently poison the type registry.""" + + def test_failed_decoration_allows_retry(self) -> None: + key = _unique_key("Rollback") + + with pytest.raises(Exception): + + @py_class(key) + class Bad(Object): + x: object # unsupported annotation type + + # The type key should be available for reuse + @py_class(key) + class Good(Object): + x: int + y: int = 0 + + assert Good(x=1).y == 0 + + +# ########################################################################### +# 24. User-defined __replace__ preserved +# ########################################################################### +class TestUserReplace: + """User-defined __replace__ is not overwritten by py_class.""" + + def test_user_replace_preserved(self) -> None: + @py_class(_unique_key("UserRepl")) + class UserRepl(Object): + x: int + + def __replace__(self, **changes: object) -> str: + return "custom" + + obj = UserRepl(x=1) + assert obj.__replace__(x=2) == "custom" + + +# ########################################################################### +# 25. default_factory=None raises +# ########################################################################### +class TestDefaultFactoryNone: + """Explicit default_factory=None matches stdlib semantics (raises).""" + + def test_explicit_none_raises(self) -> None: + with pytest.raises(TypeError, match="default_factory must be a callable"): + field(default_factory=None) + + +# ########################################################################### +# 26. Adversarial edge cases for init reordering +# ########################################################################### +class TestInitReorderingAdversarial: + """Tricky scenarios that catch bugs in naive init-signature generation.""" + + def test_positional_call_maps_to_required_not_declared_order(self) -> None: + """Positional arg 1 maps to the first *required* field, not the first declared.""" + + @py_class(_unique_key("PosMap")) + class PosMap(Object): + x: int = 0 # optional, declared first + y: int # ty: ignore[dataclass-field-order] # required, declared second + + # Positional call: first arg is y (required), not x (optional) + obj = PosMap(42) # ty: ignore[missing-argument] + assert obj.y == 42 + assert obj.x == 0 + + def test_relative_order_preserved_within_groups(self) -> None: + """Within required and optional groups, declaration order is preserved.""" + + @py_class(_unique_key("RelOrder")) + class RelOrder(Object): + a: int = 0 + b: int # ty: ignore[dataclass-field-order] + c: int = 1 + d: int # ty: ignore[dataclass-field-order] + + sig = inspect.signature(RelOrder.__init__) + param_names = [n for n in sig.parameters if n != "self"] + # required: b, d (declaration order); optional: a, c (declaration order) + assert param_names == ["b", "d", "a", "c"] + + obj = RelOrder(10, 20) # ty: ignore[missing-argument] + assert obj.b == 10 + assert obj.d == 20 + assert obj.a == 0 + assert obj.c == 1 + + def test_default_factory_counts_as_optional(self) -> None: + """default_factory makes a field optional for reordering purposes.""" + + @py_class(_unique_key("DFReorder")) + class DFReorder(Object): + items: str = field(default_factory=lambda: "hello") + count: int # ty: ignore[dataclass-field-order] + + sig = inspect.signature(DFReorder.__init__) + param_names = [n for n in sig.parameters if n != "self"] + assert param_names[0] == "count" # required first + assert param_names[1] == "items" # optional (factory) second + + obj = DFReorder(count=5) + assert obj.count == 5 + assert obj.items == "hello" + + def test_three_level_hierarchy_reorder(self) -> None: + """Required fields from all levels come before optional fields from all levels.""" + + @py_class(_unique_key("G1")) + class G1(Object): + a: int # required + + @py_class(_unique_key("P1")) + class P1(G1): + b: int = 0 # optional + + @py_class(_unique_key("C1")) + class C1(P1): + c: int # required + + sig = inspect.signature(C1.__init__) + param_names = [n for n in sig.parameters if n != "self"] + # required (a, c) before optional (b) + assert param_names == ["a", "c", "b"] + + obj = C1(a=1, c=3) + assert obj.a == 1 + assert obj.b == 0 + assert obj.c == 3 + + def test_kw_only_false_overrides_sentinel(self) -> None: + """kw_only=False on a field after KW_ONLY sentinel makes it positional.""" + + @py_class(_unique_key("KwOverride")) + class KwOverride(Object): + _: KW_ONLY + a: int # kw_only (inherits sentinel) + b: int = field(kw_only=False) # positional (explicit override) + + sig = inspect.signature(KwOverride.__init__) + assert sig.parameters["a"].kind == inspect.Parameter.KEYWORD_ONLY + assert sig.parameters["b"].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + + obj = KwOverride(42, a=1) # ty: ignore[missing-argument,invalid-argument-type] + assert obj.b == 42 + assert obj.a == 1 + + def test_init_false_field_gets_default(self) -> None: + """init=False field with default is set to default, not left uninitialized.""" + + @py_class(_unique_key("InitFalseDef")) + class InitFalseDef(Object): + visible: int + hidden: str = field(default="secret", init=False) + + obj = InitFalseDef(visible=1) + assert obj.hidden == "secret" + + def test_post_init_sees_reordered_fields(self) -> None: + """__post_init__ sees correct values even when __init__ reorders fields.""" + seen: dict[str, int] = {} + + @py_class(_unique_key("PostReorder")) + class PostReorder(Object): + x: int = 0 + y: int # ty: ignore[dataclass-field-order] + + def __post_init__(self) -> None: + seen["x"] = self.x + seen["y"] = self.y + + PostReorder(y=10, x=20) + assert seen == {"x": 20, "y": 10} + + @_needs_310 + def test_deferred_forward_ref_with_reordering(self) -> None: + """Deferred forward-reference resolution still produces correct reordering.""" + + @py_class(_unique_key("DeferReorder")) + class DeferReorder(Object): + opt: DeferLate | None = None + req: int # ty: ignore[dataclass-field-order] + + @py_class(_unique_key("DeferLate")) + class DeferLate(Object): + x: int + + sig = inspect.signature(DeferReorder.__init__) + param_names = [n for n in sig.parameters if n != "self"] + assert param_names[0] == "req" + assert param_names[1] == "opt" + + obj = DeferReorder(req=1) + assert obj.req == 1 + assert obj.opt is None + + def test_all_optional_preserves_declaration_order(self) -> None: + """When all fields are optional, declaration order is preserved.""" + + @py_class(_unique_key("AllOpt")) + class AllOpt(Object): + c: int = 3 + a: int = 1 + b: int = 2 + + sig = inspect.signature(AllOpt.__init__) + param_names = [n for n in sig.parameters if n != "self"] + assert param_names == ["c", "a", "b"] + + obj = AllOpt() + assert obj.c == 3 + assert obj.a == 1 + assert obj.b == 2 + + def test_all_required_preserves_declaration_order(self) -> None: + """When all fields are required, declaration order is preserved.""" + + @py_class(_unique_key("AllReq")) + class AllReq(Object): + c: int + a: int + b: int + + sig = inspect.signature(AllReq.__init__) + param_names = [n for n in sig.parameters if n != "self"] + assert param_names == ["c", "a", "b"] + + obj = AllReq(10, 20, 30) + assert obj.c == 10 + assert obj.a == 20 + assert obj.b == 30 + + +# ########################################################################### +# 1. Registration +# ########################################################################### +class TestRegisterPyClass: + """Low-level _register_py_class: type allocation, ancestors, field lifecycle.""" + + def test_basic_registration(self) -> None: + type_key = _unique_key_ff("RegBasic") + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + cls = type("RegBasic", (core.Object,), {"__slots__": ()}) + info = core._register_py_class(parent_info, type_key, cls) + assert info is not None + assert info.type_key == type_key + + def test_type_index_allocated(self) -> None: + type_key = _unique_key_ff("RegIndex") + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + cls = type("RegIndex", (core.Object,), {"__slots__": ()}) + info = core._register_py_class(parent_info, type_key, cls) + assert isinstance(info.type_index, int) + assert info.type_index > 0 + + def test_ancestors_include_parent(self) -> None: + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + type_key = _unique_key_ff("RegAncestors") + cls = type("RegAncestors", (core.Object,), {"__slots__": ()}) + info = core._register_py_class(parent_info, type_key, cls) + assert parent_info.type_index in info.type_ancestors + + def test_parent_type_info_set(self) -> None: + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + type_key = _unique_key_ff("RegParent") + cls = type("RegParent", (core.Object,), {"__slots__": ()}) + info = core._register_py_class(parent_info, type_key, cls) + assert info.parent_type_info is parent_info + + def test_initial_fields_none_and_methods_empty(self) -> None: + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + type_key = _unique_key_ff("RegEmpty") + cls = type("RegEmpty", (core.Object,), {"__slots__": ()}) + info = core._register_py_class(parent_info, type_key, cls) + assert info.fields is None + assert len(info.methods) == 0 + + def test_two_registrations_different_indices(self) -> None: + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + cls1 = type("RegDiff1", (core.Object,), {"__slots__": ()}) + cls2 = type("RegDiff2", (core.Object,), {"__slots__": ()}) + info1 = core._register_py_class(parent_info, _unique_key_ff("RegDiff1"), cls1) + info2 = core._register_py_class(parent_info, _unique_key_ff("RegDiff2"), cls2) + assert info1.type_index != info2.type_index + + def test_fields_none_before_registration(self) -> None: + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + cls = type("Pending", (core.Object,), {"__slots__": ()}) + info = core._register_py_class(parent_info, _unique_key_ff("Pending"), cls) + assert info.fields is None + + def test_register_fields_is_instance_method(self) -> None: + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + cls = type("PendingM", (core.Object,), {"__slots__": ()}) + info = core._register_py_class(parent_info, _unique_key_ff("PendingM"), cls) + assert hasattr(info, "_register_fields") + + def test_duplicate_type_key_raises(self) -> None: + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + type_key = _unique_key_ff("Dup") + cls1 = type("Dup1", (core.Object,), {"__slots__": ()}) + core._register_py_class(parent_info, type_key, cls1) + cls2 = type("Dup2", (core.Object,), {"__slots__": ()}) + with pytest.raises((RuntimeError, ValueError)): + core._register_py_class(parent_info, type_key, cls2) + + def test_duplicate_type_key_preserves_original(self) -> None: + """After rejected duplicate, original entry is intact.""" + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + type_key = _unique_key_ff("DupPreserve") + cls1 = type("DupPreserve1", (core.Object,), {"__slots__": ()}) + info1 = core._register_py_class(parent_info, type_key, cls1) + info1._register_fields([Field(name="x", ty=TypeSchema("int"))]) + setattr(cls1, "__tvm_ffi_type_info__", info1) + _add_class_attrs(cls1, info1) + + cls2 = type("DupPreserve2", (core.Object,), {"__slots__": ()}) + with pytest.raises((RuntimeError, ValueError)): + core._register_py_class(parent_info, type_key, cls2) + + reloaded = core._lookup_or_register_type_info_from_type_key(type_key) + assert reloaded.type_cls is cls1 + assert [f.name for f in reloaded.fields] == ["x"] + + +# ########################################################################### +# 2. Field Registration +# ########################################################################### +class TestFieldRegistration: + """Low-level _register_fields: field types, metadata, offsets.""" + + def test_int_field_registered(self) -> None: + cls = _make_type( + "FldInt", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + info = getattr(cls, "__tvm_ffi_type_info__") + assert len(info.fields) == 1 + assert info.fields[0].name == "x" + + def test_float_field_registered(self) -> None: + cls = _make_type( + "FldFloat", + [Field(name="val", ty=TypeSchema("float"), default=0.0)], + ) + info = getattr(cls, "__tvm_ffi_type_info__") + assert info.fields[0].name == "val" + + def test_str_field_registered(self) -> None: + cls = _make_type( + "FldStr", + [Field(name="s", ty=TypeSchema("str"), default="hello")], + ) + info = getattr(cls, "__tvm_ffi_type_info__") + assert info.fields[0].name == "s" + + def test_bool_field_registered(self) -> None: + cls = _make_type( + "FldBool", + [Field(name="flag", ty=TypeSchema("bool"), default=False)], + ) + info = getattr(cls, "__tvm_ffi_type_info__") + assert info.fields[0].name == "flag" + + def test_multiple_fields_count(self) -> None: + cls = _make_type( + "FldMulti", + [ + Field(name="a", ty=TypeSchema("int"), default=MISSING), + Field(name="b", ty=TypeSchema("float"), default=0.0), + Field(name="c", ty=TypeSchema("str"), default="x"), + ], + ) + info = getattr(cls, "__tvm_ffi_type_info__") + assert len(info.fields) == 3 + assert [f.name for f in info.fields] == ["a", "b", "c"] + + def test_field_offsets_increasing(self) -> None: + cls = _make_type( + "FldOff", + [ + Field(name="a", ty=TypeSchema("int"), default=MISSING), + Field(name="b", ty=TypeSchema("float"), default=MISSING), + Field(name="c", ty=TypeSchema("str"), default=MISSING), + ], + ) + info = getattr(cls, "__tvm_ffi_type_info__") + offsets = [f.offset for f in info.fields] + for i in range(1, len(offsets)): + assert offsets[i] > offsets[i - 1], f"Field offsets not increasing: {offsets}" + + def test_ffi_init_method_registered(self) -> None: + cls = _make_type( + "FldInit", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + info = getattr(cls, "__tvm_ffi_type_info__") + assert "__ffi_init__" in [m.name for m in info.methods] + + def test_field_metadata_repr_flag(self) -> None: + cls = _make_type( + "FldReprMeta", + [ + Field( + name="visible", + ty=TypeSchema("int"), + default=MISSING, + repr=True, + ), + Field( + name="hidden", + ty=TypeSchema("int"), + default=0, + repr=False, + ), + ], + ) + info = getattr(cls, "__tvm_ffi_type_info__") + assert len(info.fields) == 2 + + +# ########################################################################### +# 3. Field Descriptor +# ########################################################################### +class TestFieldDescriptor: + """Field class: validation, defaults, default_factory checks.""" + + def test_compare_default_is_false(self) -> None: + f = Field(name="x", ty=TypeSchema("int")) + assert f.compare is False + + def test_default_and_factory_mutually_exclusive(self) -> None: + with pytest.raises(ValueError): + Field(name="x", ty=TypeSchema("int"), default=0, default_factory=lambda: 0) + + def test_factory_must_be_callable(self) -> None: + with pytest.raises(TypeError, match="callable"): + Field(name="x", ty=TypeSchema("int"), default_factory=0) # ty: ignore[invalid-argument-type] + + def test_non_callable_factory_rejected(self) -> None: + with pytest.raises(TypeError, match="callable"): + Field(name="x", ty=TypeSchema("int"), default_factory="not_callable") # ty: ignore[invalid-argument-type] + + +# ########################################################################### +# 4. Construction +# ########################################################################### +class TestConstruction: + """Low-level __init__ via _make_type: positional/keyword args, defaults, errors.""" + + def test_keyword_args(self) -> None: + Cls = _make_type( + "ConKw", + [ + Field(name="x", ty=TypeSchema("int"), default=MISSING), + Field(name="y", ty=TypeSchema("float"), default=MISSING), + ], + ) + obj = Cls(x=42, y=3.14) + assert obj.x == 42 + assert obj.y == pytest.approx(3.14) + + def test_positional_args(self) -> None: + Cls = _make_type( + "ConPos", + [ + Field(name="x", ty=TypeSchema("int"), default=MISSING), + Field(name="y", ty=TypeSchema("float"), default=MISSING), + ], + ) + obj = Cls(10, 2.5) + assert obj.x == 10 + assert obj.y == pytest.approx(2.5) + + def test_mixed_positional_and_keyword(self) -> None: + Cls = _make_type( + "ConMixed", + [ + Field(name="x", ty=TypeSchema("int"), default=MISSING), + Field(name="y", ty=TypeSchema("float"), default=MISSING), + ], + ) + obj = Cls(7, y=1.5) + assert obj.x == 7 + assert obj.y == pytest.approx(1.5) + + def test_default_value_int(self) -> None: + Cls = _make_type( + "ConDefInt", + [Field(name="x", ty=TypeSchema("int"), default=99)], + ) + assert Cls().x == 99 + + def test_default_value_float(self) -> None: + Cls = _make_type( + "ConDefFloat", + [Field(name="x", ty=TypeSchema("float"), default=1.5)], + ) + assert Cls().x == pytest.approx(1.5) + + def test_default_value_str(self) -> None: + Cls = _make_type( + "ConDefStr", + [Field(name="s", ty=TypeSchema("str"), default="hello")], + ) + assert Cls().s == "hello" + + def test_override_default(self) -> None: + Cls = _make_type( + "ConOverride", + [Field(name="x", ty=TypeSchema("int"), default=0)], + ) + assert Cls(x=42).x == 42 + + def test_required_and_optional_together(self) -> None: + Cls = _make_type( + "ConReqOpt", + [ + Field(name="required", ty=TypeSchema("int"), default=MISSING), + Field(name="optional", ty=TypeSchema("float"), default=0.0), + ], + ) + obj = Cls(required=5) + assert obj.required == 5 + assert obj.optional == pytest.approx(0.0) + + def test_missing_required_raises(self) -> None: + Cls = _make_type( + "ConMissing", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + with pytest.raises(TypeError): + Cls() + + def test_extra_kwarg_raises(self) -> None: + Cls = _make_type( + "ConExtra", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + with pytest.raises(TypeError): + Cls(x=1, bogus=2) + + def test_str_field_construction(self) -> None: + Cls = _make_type( + "ConStr", + [Field(name="name", ty=TypeSchema("str"), default=MISSING)], + ) + assert Cls(name="world").name == "world" + + def test_bool_field_construction(self) -> None: + Cls = _make_type( + "ConBool", + [Field(name="flag", ty=TypeSchema("bool"), default=MISSING)], + ) + assert Cls(flag=True).flag is True + assert Cls(flag=False).flag is False + + def test_kw_only_field(self) -> None: + Cls = _make_type( + "ConKwOnly", + [ + Field(name="x", ty=TypeSchema("int"), default=MISSING), + Field( + name="y", + ty=TypeSchema("int"), + default=MISSING, + kw_only=True, + ), + ], + ) + obj = Cls(1, y=2) + assert obj.x == 1 + assert obj.y == 2 + + def test_kw_only_rejects_positional(self) -> None: + Cls = _make_type( + "ConKwOnlyReject", + [ + Field(name="x", ty=TypeSchema("int"), default=MISSING), + Field( + name="y", + ty=TypeSchema("int"), + default=MISSING, + kw_only=True, + ), + ], + ) + with pytest.raises(TypeError): + Cls(1, 2) + + def test_isinstance_check(self) -> None: + Cls = _make_type( + "ConIsInstance", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + obj = Cls(x=1) + assert isinstance(obj, Cls) + assert isinstance(obj, core.Object) + + +# ########################################################################### +# 5. Getter / Setter +# ########################################################################### +class TestGetterSetter: + """Field access: get/set POD, str, bool, mutation isolation.""" + + def test_get_int(self) -> None: + Cls = _make_type( + "GSInt", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + assert Cls(x=42).x == 42 + + def test_set_int(self) -> None: + Cls = _make_type( + "GSSetInt", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + obj = Cls(x=1) + obj.x = 100 + assert obj.x == 100 + + def test_get_float(self) -> None: + Cls = _make_type( + "GSFloat", + [Field(name="val", ty=TypeSchema("float"), default=MISSING)], + ) + assert Cls(val=3.14).val == pytest.approx(3.14) + + def test_set_float(self) -> None: + Cls = _make_type( + "GSSetFloat", + [Field(name="val", ty=TypeSchema("float"), default=MISSING)], + ) + obj = Cls(val=1.0) + obj.val = 2.718 + assert obj.val == pytest.approx(2.718) + + def test_get_str(self) -> None: + Cls = _make_type( + "GSStr", + [Field(name="s", ty=TypeSchema("str"), default=MISSING)], + ) + assert Cls(s="hello").s == "hello" + + def test_set_str(self) -> None: + Cls = _make_type( + "GSSetStr", + [Field(name="s", ty=TypeSchema("str"), default=MISSING)], + ) + obj = Cls(s="hello") + obj.s = "world" + assert obj.s == "world" + + def test_get_bool(self) -> None: + Cls = _make_type( + "GSBool", + [Field(name="flag", ty=TypeSchema("bool"), default=MISSING)], + ) + assert Cls(flag=True).flag is True + + def test_set_bool(self) -> None: + Cls = _make_type( + "GSSetBool", + [Field(name="flag", ty=TypeSchema("bool"), default=MISSING)], + ) + obj = Cls(flag=True) + obj.flag = False + assert obj.flag is False + + def test_mutation_isolated(self) -> None: + Cls = _make_type( + "GSIsolate", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + a = Cls(x=1) + b = Cls(x=1) + a.x = 99 + assert a.x == 99 + assert b.x == 1 + + def test_multiple_fields_mutation(self) -> None: + Cls = _make_type( + "GSMultiMut", + [ + Field(name="a", ty=TypeSchema("int"), default=MISSING), + Field(name="b", ty=TypeSchema("float"), default=MISSING), + Field(name="c", ty=TypeSchema("str"), default=MISSING), + ], + ) + obj = Cls(a=1, b=2.0, c="x") + obj.a = 10 + obj.b = 20.0 + obj.c = "y" + assert obj.a == 10 + assert obj.b == pytest.approx(20.0) + assert obj.c == "y" + + def test_set_array_field(self) -> None: + Cls = _make_type( + "GSSetArr", + [ + Field( + name="arr", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + obj = Cls(arr=[1]) + obj.arr = tvm_ffi.Array([4, 5, 6]) + assert len(obj.arr) == 3 + assert obj.arr[0] == 4 + + +# ########################################################################### +# 6. ObjectRef Fields +# ########################################################################### +class TestObjectRefFields: + """Fields holding ObjectRef types: Array, custom objects.""" + + def test_array_field(self) -> None: + Cls = _make_type( + "ObjArr", + [ + Field( + name="arr", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + obj = Cls(arr=tvm_ffi.Array([1, 2, 3])) + assert len(obj.arr) == 3 + assert obj.arr[0] == 1 + assert obj.arr[2] == 3 + + def test_array_field_from_list(self) -> None: + Cls = _make_type( + "ObjArrList", + [ + Field( + name="arr", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + obj = Cls(arr=[10, 20, 30]) + assert len(obj.arr) == 3 + assert obj.arr[1] == 20 + + def test_nested_object_field(self) -> None: + Inner = _make_type( + "ObjInner", + [Field(name="val", ty=TypeSchema("int"), default=MISSING)], + ) + inner_info = getattr(Inner, "__tvm_ffi_type_info__") + inner_schema = TypeSchema(inner_info.type_key, origin_type_index=inner_info.type_index) + Outer = _make_type( + "ObjOuter", + [Field(name="child", ty=inner_schema, default=MISSING)], + ) + assert Outer(child=Inner(val=42)).child.val == 42 + + +# ########################################################################### +# 7. Optional Fields +# ########################################################################### +class TestOptionalFields: + """Optional/Union fields: None and non-None values.""" + + def test_optional_int_with_value(self) -> None: + Cls = _make_type( + "OptIntV", + [ + Field( + name="x", + ty=TypeSchema("Optional", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + assert Cls(x=42).x == 42 + + def test_optional_int_with_none(self) -> None: + Cls = _make_type( + "OptIntN", + [ + Field( + name="x", + ty=TypeSchema("Optional", (TypeSchema("int"),)), + default=None, + ), + ], + ) + assert Cls().x is None + + def test_optional_str_with_value(self) -> None: + Cls = _make_type( + "OptStrV", + [ + Field( + name="s", + ty=TypeSchema("Optional", (TypeSchema("str"),)), + default=MISSING, + ), + ], + ) + assert Cls(s="hello").s == "hello" + + def test_optional_str_with_none(self) -> None: + Cls = _make_type( + "OptStrN", + [ + Field( + name="s", + ty=TypeSchema("Optional", (TypeSchema("str"),)), + default=None, + ), + ], + ) + assert Cls().s is None + + def test_optional_set_to_none(self) -> None: + Cls = _make_type( + "OptSet", + [ + Field( + name="x", + ty=TypeSchema("Optional", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + obj = Cls(x=42) + obj.x = None + assert obj.x is None + + def test_optional_set_back_to_value(self) -> None: + Cls = _make_type( + "OptBack", + [ + Field( + name="x", + ty=TypeSchema("Optional", (TypeSchema("int"),)), + default=None, + ), + ], + ) + obj = Cls() + obj.x = 99 + assert obj.x == 99 + + def test_all_optional_fields_default_none(self) -> None: + Cls = _make_type( + "AllOpt", + [ + Field( + name="a", + ty=TypeSchema("Optional", (TypeSchema("int"),)), + default=None, + ), + Field( + name="b", + ty=TypeSchema("Optional", (TypeSchema("str"),)), + default=None, + ), + Field( + name="c", + ty=TypeSchema("Optional", (TypeSchema("float"),)), + default=None, + ), + ], + ) + obj = Cls() + assert obj.a is None + assert obj.b is None + assert obj.c is None + obj.a = 42 + assert obj.a == 42 + obj.b = "hello" + assert obj.b == "hello" + + def test_optional_object_none_and_back(self) -> None: + Cls = _make_type( + "OptObjRound", + [ + Field( + name="ref", + ty=TypeSchema("Optional", (TypeSchema("Object"),)), + default=None, + ), + ], + ) + obj = Cls() + assert obj.ref is None + obj.ref = tvm_ffi.Array([1]) + assert len(obj.ref) == 1 + obj.ref = None + assert obj.ref is None + + def test_union_int_str(self) -> None: + """Union[int, str] field should accept both types.""" + Cls = _make_type( + "UnionIntStr", + [ + Field( + name="val", + ty=TypeSchema("Union", (TypeSchema("int"), TypeSchema("str"))), + default=MISSING, + ), + ], + ) + obj = Cls(val=42) + assert obj.val == 42 + obj.val = "hello" + assert obj.val == "hello" + + def test_union_int_str_rejects_float(self) -> None: + """Union[int, str] should reject float (not in union).""" + Cls = _make_type( + "UnionReject", + [ + Field( + name="val", + ty=TypeSchema("Union", (TypeSchema("int"), TypeSchema("str"))), + default=MISSING, + ), + ], + ) + obj = Cls(val=1) + with pytest.raises((TypeError, RuntimeError)): + obj.val = 3.14 + + def test_optional_union(self) -> None: + """Optional[Union[int, str]] should accept None, int, and str.""" + Cls = _make_type( + "OptUnion", + [ + Field( + name="val", + ty=TypeSchema( + "Optional", + (TypeSchema("Union", (TypeSchema("int"), TypeSchema("str"))),), + ), + default=None, + ), + ], + ) + obj = Cls() + assert obj.val is None + obj.val = 42 + assert obj.val == 42 + obj.val = "hi" + assert obj.val == "hi" + obj.val = None + assert obj.val is None + + +# ########################################################################### +# 8. Any Fields +# ########################################################################### +class TestAnyField: + """Fields with TypeSchema('Any'): hold any value type.""" + + def test_any_holds_int(self) -> None: + Cls = _make_type( + "AnyI", + [Field(name="val", ty=TypeSchema("Any"), default=None)], + ) + assert Cls(val=42).val == 42 + + def test_any_holds_str(self) -> None: + Cls = _make_type( + "AnyS", + [Field(name="val", ty=TypeSchema("Any"), default=None)], + ) + assert Cls(val="hello").val == "hello" + + def test_any_holds_none(self) -> None: + Cls = _make_type( + "AnyN", + [Field(name="val", ty=TypeSchema("Any"), default=None)], + ) + assert Cls().val is None + + def test_any_holds_object(self) -> None: + Cls = _make_type( + "AnyObj", + [Field(name="val", ty=TypeSchema("Any"), default=None)], + ) + arr = tvm_ffi.Array([1, 2]) + assert len(Cls(val=arr).val) == 2 + + def test_any_type_change(self) -> None: + Cls = _make_type( + "AnyChg", + [Field(name="val", ty=TypeSchema("Any"), default=None)], + ) + obj = Cls() + obj.val = 42 + assert obj.val == 42 + obj.val = "hello" + assert obj.val == "hello" + obj.val = None + assert obj.val is None + obj.val = tvm_ffi.Array([1]) + assert len(obj.val) == 1 + + +# ########################################################################### +# 9. Default Factory +# ########################################################################### +class TestDefaultFactory: + """default_factory support: fresh instances, override, various types.""" + + def test_factory_produces_fresh_instances(self) -> None: + Cls = _make_type( + "DFFresh", + [ + Field( + name="data", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default_factory=lambda: tvm_ffi.Array([]), + ), + ], + ) + a = Cls() + b = Cls() + assert not a.data.same_as(b.data) + + def test_factory_with_content(self) -> None: + Cls = _make_type( + "DFContent", + [ + Field( + name="items", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default_factory=lambda: tvm_ffi.Array([1, 2, 3]), + ), + ], + ) + obj = Cls() + assert len(obj.items) == 3 + assert obj.items[0] == 1 + + def test_factory_override(self) -> None: + Cls = _make_type( + "DFOverride", + [Field(name="x", ty=TypeSchema("int"), default_factory=lambda: 42)], + ) + assert Cls(x=99).x == 99 + + def test_factory_str(self) -> None: + Cls = _make_type( + "DFStr", + [Field(name="s", ty=TypeSchema("str"), default_factory=lambda: "generated")], + ) + assert Cls().s == "generated" + + +# ########################################################################### +# 10. Repr +# ########################################################################### +class TestFieldRepr: + """Low-level repr via _make_type: field values, repr=False exclusion.""" + + def test_repr_includes_fields(self) -> None: + Cls = _make_type( + "ReprBasic", + [ + Field(name="x", ty=TypeSchema("int"), default=MISSING), + Field(name="y", ty=TypeSchema("float"), default=0.0), + ], + ) + r = ReprPrint(Cls(x=42, y=3.14)) + assert "x=42" in r + assert "y=3.14" in r + + def test_repr_str_field(self) -> None: + Cls = _make_type( + "ReprStr", + [Field(name="name", ty=TypeSchema("str"), default=MISSING)], + ) + assert '"hello"' in ReprPrint(Cls(name="hello")) + + def test_repr_bool_field(self) -> None: + Cls = _make_type( + "ReprBool", + [Field(name="flag", ty=TypeSchema("bool"), default=MISSING)], + ) + assert "flag=True" in ReprPrint(Cls(flag=True)) + + def test_repr_false_excluded(self) -> None: + Cls = _make_type( + "ReprExcl", + [ + Field(name="visible", ty=TypeSchema("int"), default=MISSING), + Field( + name="hidden", + ty=TypeSchema("int"), + default=0, + repr=False, + ), + ], + ) + r = ReprPrint(Cls(visible=42)) + assert "visible=42" in r + assert "hidden" not in r + + def test_python_repr_delegates(self) -> None: + Cls = _make_type( + "ReprDeleg", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + assert "x=7" in repr(Cls(x=7)) + + def test_repr_contains_type_key(self) -> None: + Cls = _make_type( + "ReprKey", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + info = getattr(Cls, "__tvm_ffi_type_info__") + assert info.type_key in ReprPrint(Cls(x=1)) + + def test_repr_optional_none(self) -> None: + Cls = _make_type( + "ReprOptNone", + [ + Field( + name="x", + ty=TypeSchema("Optional", (TypeSchema("int"),)), + default=None, + ), + ], + ) + r = ReprPrint(Cls()) + assert isinstance(r, str) + + def test_repr_array_field(self) -> None: + Cls = _make_type( + "ReprArr", + [ + Field( + name="items", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + r = ReprPrint(Cls(items=[1, 2, 3])) + assert isinstance(r, str) + + +# ########################################################################### +# 11. Hash +# ########################################################################### +class TestFieldHash: + """Low-level hash via _make_type: equal objects same hash, hash=False exclusion.""" + + def test_equal_objects_same_hash(self) -> None: + Cls = _make_type( + "HashEq", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + Field( + name="y", + ty=TypeSchema("float"), + default=MISSING, + compare=True, + ), + ], + eq=True, + unsafe_hash=True, + ) + assert RecursiveHash(Cls(x=1, y=2.0)) == RecursiveHash(Cls(x=1, y=2.0)) + + def test_different_objects_different_hash(self) -> None: + Cls = _make_type( + "HashDiff", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + ], + eq=True, + unsafe_hash=True, + ) + assert RecursiveHash(Cls(x=1)) != RecursiveHash(Cls(x=2)) + + def test_hash_false_field_ignored(self) -> None: + Cls = _make_type( + "HashIgn", + [ + Field( + name="key", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + Field( + name="ignored", + ty=TypeSchema("int"), + default=0, + hash=False, + ), + ], + eq=True, + unsafe_hash=True, + ) + assert RecursiveHash(Cls(key=42, ignored=100)) == RecursiveHash(Cls(key=42, ignored=999)) + + def test_hash_dunder_installed(self) -> None: + Cls = _make_type( + "HashDunder", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + eq=True, + unsafe_hash=True, + ) + assert isinstance(hash(Cls(x=42)), int) + + def test_usable_as_dict_key(self) -> None: + Cls = _make_type( + "HashDict", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + ], + eq=True, + unsafe_hash=True, + ) + assert {Cls(x=1): "value"}[Cls(x=1)] == "value" + + def test_usable_in_set(self) -> None: + Cls = _make_type( + "HashSet", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + ], + eq=True, + unsafe_hash=True, + ) + assert len({Cls(x=1), Cls(x=1), Cls(x=2)}) == 2 + + +# ########################################################################### +# 12. Equality +# ########################################################################### +class TestFieldEquality: + """Low-level equality via _make_type: structural compare, compare=False exclusion.""" + + def test_equal_objects(self) -> None: + Cls = _make_type( + "EqEqual", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + Field( + name="y", + ty=TypeSchema("float"), + default=MISSING, + compare=True, + ), + ], + eq=True, + ) + assert Cls(x=1, y=2.0) == Cls(x=1, y=2.0) + + def test_different_objects(self) -> None: + Cls = _make_type( + "EqDiff", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + ], + eq=True, + ) + assert Cls(x=1) != Cls(x=2) + + def test_compare_false_field_ignored(self) -> None: + Cls = _make_type( + "EqIgn", + [ + Field( + name="key", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + Field( + name="ignored", + ty=TypeSchema("int"), + default=0, + compare=False, + ), + ], + eq=True, + ) + assert RecursiveEq(Cls(key=42, ignored=100), Cls(key=42, ignored=999)) + + def test_compare_off_excludes_from_eq(self) -> None: + """Fields with compare=False (default) are ignored by RecursiveEq.""" + Cls = _make_type( + "CmpOff", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + eq=True, + ) + assert RecursiveEq(Cls(x=1), Cls(x=2)) + + def test_compare_true_includes_in_eq(self) -> None: + Cls = _make_type( + "CmpOn", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + ], + eq=True, + ) + assert not RecursiveEq(Cls(x=1), Cls(x=2)) + + def test_eq_reflexive(self) -> None: + Cls = _make_type( + "EqRefl", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + ], + eq=True, + ) + a = Cls(x=42) + assert a == a + + def test_eq_symmetric(self) -> None: + Cls = _make_type( + "EqSym", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + ], + eq=True, + ) + a, b = Cls(x=1), Cls(x=1) + assert a == b + assert b == a + + def test_eq_with_str_field(self) -> None: + Cls = _make_type( + "EqStr", + [ + Field( + name="s", + ty=TypeSchema("str"), + default=MISSING, + compare=True, + ), + ], + eq=True, + ) + assert RecursiveEq(Cls(s="hello"), Cls(s="hello")) + assert not RecursiveEq(Cls(s="hello"), Cls(s="world")) + + def test_eq_hash_consistency(self) -> None: + Cls = _make_type( + "EqHashCon", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + Field( + name="y", + ty=TypeSchema("float"), + default=MISSING, + compare=True, + ), + ], + eq=True, + unsafe_hash=True, + ) + a, b = Cls(x=1, y=2.0), Cls(x=1, y=2.0) + assert RecursiveEq(a, b) + assert RecursiveHash(a) == RecursiveHash(b) + + +# ########################################################################### +# 13. Edge Cases +# ########################################################################### +class TestFieldEdgeCases: + """Low-level edge cases via _make_type: empty class, extreme values, init=False.""" + + def test_empty_class_no_fields(self) -> None: + Cls = _make_type("EdgeEmpty", []) + obj = Cls() + assert isinstance(obj, core.Object) + assert isinstance(obj, Cls) + + def test_empty_class_repr(self) -> None: + Cls = _make_type("EdgeEmptyRepr", []) + info = getattr(Cls, "__tvm_ffi_type_info__") + assert info.type_key in ReprPrint(Cls()) + + def test_bool_true_and_false(self) -> None: + Cls = _make_type( + "EdgeBool", + [Field(name="flag", ty=TypeSchema("bool"), default=MISSING)], + ) + assert Cls(flag=True).flag is True + assert Cls(flag=False).flag is False + + def test_bool_default_false(self) -> None: + Cls = _make_type( + "EdgeBoolDef", + [Field(name="flag", ty=TypeSchema("bool"), default=False)], + ) + assert Cls().flag is False + + def test_multiple_types_together(self) -> None: + Cls = _make_type( + "EdgeMulti", + [ + Field( + name="i", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + Field(name="f", ty=TypeSchema("float"), default=MISSING), + Field(name="s", ty=TypeSchema("str"), default=MISSING), + Field(name="b", ty=TypeSchema("bool"), default=MISSING), + ], + ) + obj = Cls(i=42, f=3.14, s="test", b=True) + assert obj.i == 42 + assert obj.f == pytest.approx(3.14) + assert obj.s == "test" + assert obj.b is True + + def test_pod_and_objectref_mixed(self) -> None: + Cls = _make_type( + "EdgeMixed", + [ + Field(name="count", ty=TypeSchema("int"), default=MISSING), + Field( + name="items", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + Field(name="label", ty=TypeSchema("str"), default=""), + ], + ) + obj = Cls(count=3, items=[1, 2, 3]) + assert obj.count == 3 + assert len(obj.items) == 3 + assert obj.label == "" + + def test_multiple_types_with_defaults(self) -> None: + Cls = _make_type( + "EdgeMultiDef", + [ + Field(name="i", ty=TypeSchema("int"), default=0), + Field(name="f", ty=TypeSchema("float"), default=1.0), + Field(name="s", ty=TypeSchema("str"), default="default"), + Field(name="b", ty=TypeSchema("bool"), default=True), + ], + ) + obj = Cls() + assert obj.i == 0 + assert obj.f == pytest.approx(1.0) + assert obj.s == "default" + assert obj.b is True + + def test_zero_values(self) -> None: + Cls = _make_type( + "EdgeZero", + [ + Field(name="i", ty=TypeSchema("int"), default=MISSING), + Field(name="f", ty=TypeSchema("float"), default=MISSING), + ], + ) + obj = Cls(i=0, f=0.0) + assert obj.i == 0 + assert obj.f == 0.0 + + def test_negative_values(self) -> None: + Cls = _make_type( + "EdgeNeg", + [ + Field(name="i", ty=TypeSchema("int"), default=MISSING), + Field(name="f", ty=TypeSchema("float"), default=MISSING), + ], + ) + obj = Cls(i=-42, f=-3.14) + assert obj.i == -42 + assert obj.f == pytest.approx(-3.14) + + def test_large_int(self) -> None: + Cls = _make_type( + "EdgeLargeInt", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + large = 2**62 + assert Cls(x=large).x == large + + def test_empty_string_field(self) -> None: + Cls = _make_type( + "EdgeEmptyStr", + [Field(name="s", ty=TypeSchema("str"), default=MISSING)], + ) + assert Cls(s="").s == "" + + def test_long_string_field(self) -> None: + Cls = _make_type( + "EdgeLongStr", + [Field(name="s", ty=TypeSchema("str"), default=MISSING)], + ) + long_str = "a" * 1000 + assert Cls(s=long_str).s == long_str + + def test_equality_empty_class(self) -> None: + Cls = _make_type("EdgeEmptyEq", [], eq=True, unsafe_hash=True) + assert RecursiveEq(Cls(), Cls()) + assert RecursiveHash(Cls()) == RecursiveHash(Cls()) + + def test_init_false_field_excluded_from_init(self) -> None: + Cls = _make_type( + "EdgeInitFalse", + [ + Field(name="visible", ty=TypeSchema("int"), default=MISSING), + Field( + name="internal", + ty=TypeSchema("int"), + default=0, + init=False, + ), + ], + ) + obj = Cls(visible=42) + assert obj.visible == 42 + assert obj.internal == 0 + + def test_init_false_field_rejected_as_kwarg(self) -> None: + Cls = _make_type( + "EdgeInitFalseReject", + [ + Field(name="visible", ty=TypeSchema("int"), default=MISSING), + Field( + name="internal", + ty=TypeSchema("int"), + default=0, + init=False, + ), + ], + ) + with pytest.raises(TypeError): + Cls(visible=1, internal=2) + + def test_init_false_field_writable(self) -> None: + Cls = _make_type( + "EdgeInitFalseWrite", + [ + Field(name="visible", ty=TypeSchema("int"), default=MISSING), + Field( + name="internal", + ty=TypeSchema("int"), + default=0, + init=False, + ), + ], + ) + obj = Cls(visible=1) + obj.internal = 99 + assert obj.internal == 99 + + +# ########################################################################### +# 14. Inheritance (Python-defined parent) +# ########################################################################### +class TestFieldInheritance: + """Low-level inheritance via _make_type: field offsets, parent-child layout.""" + + def test_child_fields_after_parent(self) -> None: + Parent = _make_type( + "InhParent", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + Child = _make_type( + "InhChild", + [Field(name="y", ty=TypeSchema("int"), default=MISSING)], + parent=Parent, + ) + obj = Child(1, 2) + assert obj.x == 1 + assert obj.y == 2 + + def test_child_field_offsets_non_overlapping(self) -> None: + Parent = _make_type( + "InhParentOff", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + Child = _make_type( + "InhChildOff", + [Field(name="y", ty=TypeSchema("int"), default=MISSING)], + parent=Parent, + ) + p_info = getattr(Parent, "__tvm_ffi_type_info__") + c_info = getattr(Child, "__tvm_ffi_type_info__") + parent_end = p_info.fields[0].offset + p_info.fields[0].size + assert c_info.fields[0].offset >= parent_end + + def test_mutation_no_aliasing(self) -> None: + Parent = _make_type( + "InhParentAlias", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + Child = _make_type( + "InhChildAlias", + [Field(name="y", ty=TypeSchema("int"), default=MISSING)], + parent=Parent, + ) + obj = Child(1, 2) + obj.y = 9 + assert obj.x == 1 + assert obj.y == 9 + + def test_three_level_inheritance(self) -> None: + """Object → A → B → C: all fields accessible and non-overlapping.""" + A = _make_type( + "InhA", + [Field(name="a", ty=TypeSchema("int"), default=MISSING)], + ) + B = _make_type( + "InhB", + [Field(name="b", ty=TypeSchema("str"), default=MISSING)], + parent=A, + ) + C = _make_type( + "InhC", + [Field(name="c", ty=TypeSchema("float"), default=MISSING)], + parent=B, + ) + obj = C(a=1, b="two", c=3.0) + assert obj.a == 1 + assert obj.b == "two" + assert obj.c == pytest.approx(3.0) + + def test_three_level_offsets_non_overlapping(self) -> None: + A = _make_type( + "InhAOff", + [Field(name="a", ty=TypeSchema("int"), default=MISSING)], + ) + B = _make_type( + "InhBOff", + [Field(name="b", ty=TypeSchema("int"), default=MISSING)], + parent=A, + ) + C = _make_type( + "InhCOff", + [Field(name="c", ty=TypeSchema("int"), default=MISSING)], + parent=B, + ) + a_info = getattr(A, "__tvm_ffi_type_info__") + b_info = getattr(B, "__tvm_ffi_type_info__") + c_info = getattr(C, "__tvm_ffi_type_info__") + a_end = a_info.fields[0].offset + a_info.fields[0].size + b_end = b_info.fields[0].offset + b_info.fields[0].size + assert b_info.fields[0].offset >= a_end + assert c_info.fields[0].offset >= b_end + + def test_three_level_mutation_no_aliasing(self) -> None: + A = _make_type( + "InhAMut", + [Field(name="a", ty=TypeSchema("int"), default=MISSING)], + ) + B = _make_type( + "InhBMut", + [Field(name="b", ty=TypeSchema("int"), default=MISSING)], + parent=A, + ) + C = _make_type( + "InhCMut", + [Field(name="c", ty=TypeSchema("int"), default=MISSING)], + parent=B, + ) + obj = C(a=1, b=2, c=3) + obj.c = 99 + assert obj.a == 1 + assert obj.b == 2 + assert obj.c == 99 + obj.a = 77 + assert obj.a == 77 + assert obj.b == 2 + assert obj.c == 99 + + def test_three_level_isinstance(self) -> None: + A = _make_type( + "InhAIs", + [Field(name="a", ty=TypeSchema("int"), default=MISSING)], + ) + B = _make_type( + "InhBIs", + [Field(name="b", ty=TypeSchema("int"), default=MISSING)], + parent=A, + ) + C = _make_type( + "InhCIs", + [Field(name="c", ty=TypeSchema("int"), default=MISSING)], + parent=B, + ) + obj = C(a=1, b=2, c=3) + assert isinstance(obj, C) + assert isinstance(obj, B) + assert isinstance(obj, A) + assert isinstance(obj, core.Object) + + def test_three_level_deep_copy(self) -> None: + A = _make_type( + "InhACopy", + [Field(name="a", ty=TypeSchema("int"), default=MISSING)], + ) + B = _make_type( + "InhBCopy", + [Field(name="b", ty=TypeSchema("int"), default=MISSING)], + parent=A, + ) + C = _make_type( + "InhCCopy", + [Field(name="c", ty=TypeSchema("int"), default=MISSING)], + parent=B, + ) + obj = C(a=1, b=2, c=3) + obj_copy = DeepCopy(obj) + assert not obj.same_as(obj_copy) + assert obj_copy.a == 1 + assert obj_copy.b == 2 + assert obj_copy.c == 3 + obj_copy.c = 99 + assert obj.c == 3 + + +# ########################################################################### +# 15. Mutual / Self References +# ########################################################################### +class TestMutualReferences: + """Low-level mutual and self-referential type fields via two-phase registration.""" + + def _register_bare(self, name: str) -> tuple[type, core.TypeInfo]: + """Register a type with no fields (phase 1 of two-phase).""" + parent_info = core._type_cls_to_type_info(core.Object) + assert parent_info is not None + cls = type(name, (core.Object,), {"__slots__": ()}) + info = core._register_py_class(parent_info, _unique_key_ff(name), cls) + return cls, info + + def _finalize(self, cls: type, info: core.TypeInfo, fields: list[Field]) -> None: + """Register fields and install class attrs (phase 2 of two-phase).""" + info._register_fields(fields) + setattr(cls, "__tvm_ffi_type_info__", info) + _add_class_attrs(cls, info) + _install_dataclass_dunders( + cls, + init=True, + repr=True, + eq=False, + order=False, + unsafe_hash=False, + ) + + def test_mutual_references(self) -> None: + """Foo has Optional[Bar], Bar has Optional[Foo].""" + Foo, foo_info = self._register_bare("MutFoo") + Bar, bar_info = self._register_bare("MutBar") + foo_schema = TypeSchema(foo_info.type_key, origin_type_index=foo_info.type_index) + bar_schema = TypeSchema(bar_info.type_key, origin_type_index=bar_info.type_index) + self._finalize( + Foo, + foo_info, + [ + Field(name="a", ty=TypeSchema("str"), default=MISSING), + Field( + name="bar", + ty=TypeSchema("Optional", (bar_schema,)), + default=None, + ), + ], + ) + self._finalize( + Bar, + bar_info, + [ + Field( + name="foo", + ty=TypeSchema("Optional", (foo_schema,)), + default=None, + ), + ], + ) + foo = Foo(a="hello") + bar = Bar() + bar.foo = foo + foo.bar = bar + assert foo.bar.foo.a == "hello" + + def test_self_referential_field(self) -> None: + """Bar has Optional[Bar] (self-reference).""" + Bar, bar_info = self._register_bare("SelfRef") + bar_schema = TypeSchema(bar_info.type_key, origin_type_index=bar_info.type_index) + self._finalize( + Bar, + bar_info, + [ + Field(name="val", ty=TypeSchema("int"), default=MISSING), + Field( + name="next", + ty=TypeSchema("Optional", (bar_schema,)), + default=None, + ), + ], + ) + a = Bar(val=1) + b = Bar(val=2, next=a) + assert b.next.val == 1 + assert a.next is None + # Circular: a → b → a + a.next = b + assert a.next.next.val == 1 + + def test_typed_mutual_ref_rejects_wrong_type(self) -> None: + """Optional[Foo] field should reject Bar objects.""" + Foo, foo_info = self._register_bare("TypedFoo") + Bar, bar_info = self._register_bare("TypedBar") + foo_schema = TypeSchema(foo_info.type_key, origin_type_index=foo_info.type_index) + self._finalize( + Foo, + foo_info, + [ + Field(name="x", ty=TypeSchema("int"), default=MISSING), + ], + ) + self._finalize( + Bar, + bar_info, + [ + Field( + name="foo", + ty=TypeSchema("Optional", (foo_schema,)), + default=None, + ), + ], + ) + bar = Bar() + bar.foo = Foo(x=1) # OK + assert bar.foo.x == 1 + with pytest.raises((TypeError, RuntimeError)): + bar.foo = bar # Bar is not Foo + + +# ########################################################################### +# 16. Inheritance (native C++ parent) +# ########################################################################### +class TestNativeParentInheritance: + """Low-level Python child of C++ TestObjectBase: offsets, fields, methods, copy.""" + + def test_non_overlapping_offsets(self) -> None: + parent_info = core._type_cls_to_type_info(_TestObjectBase) + assert parent_info is not None + Child = _make_type( + "InhNativeChild", + [Field(name="extra", ty=TypeSchema("int"), default=MISSING)], + parent=_TestObjectBase, + ) + child_info = getattr(Child, "__tvm_ffi_type_info__") + parent_end = max(f.offset + f.size for f in parent_info.fields) + assert child_info.fields[0].offset >= parent_end + + def test_preserves_parent_fields(self) -> None: + Child = _make_type( + "InhNativePreserve", + [Field(name="extra", ty=TypeSchema("int"), default=MISSING)], + parent=_TestObjectBase, + ) + obj = Child(extra=7, v_i64=1, v_f64=2.0, v_str="x") + assert obj.extra == 7 + assert obj.v_i64 == 1 + assert obj.v_f64 == 2.0 + assert obj.v_str == "x" + + def test_mutation_no_aliasing(self) -> None: + Child = _make_type( + "InhNativeMut", + [Field(name="extra", ty=TypeSchema("int"), default=MISSING)], + parent=_TestObjectBase, + ) + obj = Child(extra=7, v_i64=1, v_f64=2.0, v_str="x") + obj.extra = 33 + assert obj.extra == 33 + assert obj.v_i64 == 1 + assert obj.v_f64 == 2.0 + assert obj.v_str == "x" + + def test_parent_method_uses_parent_state(self) -> None: + Child = _make_type( + "InhNativeMethod", + [Field(name="extra", ty=TypeSchema("int"), default=MISSING)], + parent=_TestObjectBase, + ) + obj = Child(extra=7, v_i64=1, v_f64=2.0, v_str="x") + assert obj.add_i64(5) == 6 + + def test_copy_preserves_all_fields(self) -> None: + Child = _make_type( + "InhNativeCopy", + [Field(name="extra", ty=TypeSchema("int"), default=MISSING)], + parent=_TestObjectBase, + ) + obj = Child(extra=7, v_i64=1, v_f64=2.0, v_str="x") + obj_copy = copy.copy(obj) + assert obj_copy.extra == 7 + assert obj_copy.v_i64 == 1 + assert obj_copy.v_f64 == 2.0 + assert obj_copy.v_str == "x" + + def test_deepcopy_preserves_all_fields(self) -> None: + Child = _make_type( + "InhNativeDeepCopy", + [Field(name="extra", ty=TypeSchema("int"), default=MISSING)], + parent=_TestObjectBase, + ) + obj = Child(extra=7, v_i64=1, v_f64=2.0, v_str="x") + obj_copy = copy.deepcopy(obj) + assert obj_copy.extra == 7 + assert obj_copy.v_i64 == 1 + assert obj_copy.v_f64 == 2.0 + assert obj_copy.v_str == "x" + + +# ########################################################################### +# 16. Deep Copy +# ########################################################################### +class TestDeepCopy: + """Low-level DeepCopy via _make_type: nested ObjectRef, mutation independence.""" + + def test_deep_copy_basic(self) -> None: + Cls = _make_type( + "DCBasic", + [ + Field( + name="x", + ty=TypeSchema("int"), + default=MISSING, + compare=True, + ), + Field( + name="s", + ty=TypeSchema("str"), + default=MISSING, + compare=True, + ), + ], + eq=True, + ) + obj = Cls(x=42, s="hello") + obj_copy = DeepCopy(obj) + assert not obj.same_as(obj_copy) + assert RecursiveEq(obj, obj_copy) + + def test_deep_copy_nested_objectref(self) -> None: + Cls = _make_type( + "DCNested", + [ + Field( + name="items", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + obj = Cls(items=tvm_ffi.Array([1, 2, 3])) + obj_copy = DeepCopy(obj) + assert not obj.items.same_as(obj_copy.items) + assert len(obj_copy.items) == 3 + + def test_deep_copy_mutate_independent(self) -> None: + Cls = _make_type( + "DCMut", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + obj = Cls(x=1) + obj_copy = DeepCopy(obj) + obj_copy.x = 99 + assert obj.x == 1 + + def test_python_deepcopy_dunder(self) -> None: + Cls = _make_type( + "DCPython", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + obj = Cls(x=42) + obj_copy = copy.deepcopy(obj) + assert obj_copy.x == 42 + assert not obj.same_as(obj_copy) + + +# ########################################################################### +# 17. Memory / Lifetime +# ########################################################################### +class TestMemoryLifetime: + """Low-level ref-counting: ObjectRef/Any fields are properly ref-counted.""" + + def test_objectref_field_kept_alive(self) -> None: + Cls = _make_type( + "MemAlive", + [ + Field( + name="arr", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + arr = tvm_ffi.Array([1, 2, 3]) + obj = Cls(arr=arr) + del arr + gc.collect() + assert len(obj.arr) == 3 + + def test_multiple_objects_independent_lifetime(self) -> None: + Cls = _make_type( + "MemIndep", + [ + Field( + name="arr", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + shared = tvm_ffi.Array([10, 20]) + a = Cls(arr=shared) + b = Cls(arr=shared) + del a + gc.collect() + assert len(b.arr) == 2 + assert b.arr[0] == 10 + + def test_str_field_any_storage(self) -> None: + Cls = _make_type( + "MemStr", + [Field(name="s", ty=TypeSchema("str"), default=MISSING)], + ) + assert Cls(s="hi").s == "hi" + long_str = "a" * 500 + assert Cls(s=long_str).s == long_str + + +# ########################################################################### +# 18. Bool Alignment +# ########################################################################### +class TestBoolAlignment: + """Low-level bool field layout: 1-byte packing, padding, alternating layouts.""" + + def test_bool_then_int_alignment(self) -> None: + Cls = _make_type( + "BoolAlign", + [ + Field(name="flag", ty=TypeSchema("bool"), default=MISSING), + Field(name="val", ty=TypeSchema("int"), default=MISSING), + ], + ) + info = getattr(Cls, "__tvm_ffi_type_info__") + assert info.fields[0].offset == 24 + assert info.fields[1].offset % 8 == 0 + assert info.fields[1].offset >= info.fields[0].offset + 1 + + def test_bool_then_int_values(self) -> None: + Cls = _make_type( + "BoolAlignVal", + [ + Field(name="flag", ty=TypeSchema("bool"), default=MISSING), + Field(name="val", ty=TypeSchema("int"), default=MISSING), + ], + ) + obj = Cls(flag=True, val=42) + assert obj.flag is True + assert obj.val == 42 + obj.flag = False + assert obj.flag is False + assert obj.val == 42 + + def test_multiple_bools_packed(self) -> None: + Cls = _make_type( + "MultiBool", + [ + Field(name="a", ty=TypeSchema("bool"), default=MISSING), + Field(name="b", ty=TypeSchema("bool"), default=MISSING), + Field(name="c", ty=TypeSchema("bool"), default=MISSING), + ], + ) + info = getattr(Cls, "__tvm_ffi_type_info__") + assert [f.offset for f in info.fields] == [24, 25, 26] + obj = Cls(a=True, b=False, c=True) + assert obj.a is True + assert obj.b is False + assert obj.c is True + + def test_bool_int_bool_int_alternating(self) -> None: + Cls = _make_type( + "BoolIntBoolInt", + [ + Field(name="b1", ty=TypeSchema("bool"), default=MISSING), + Field(name="i1", ty=TypeSchema("int"), default=MISSING), + Field(name="b2", ty=TypeSchema("bool"), default=MISSING), + Field(name="i2", ty=TypeSchema("int"), default=MISSING), + ], + ) + obj = Cls(b1=True, i1=100, b2=False, i2=200) + assert obj.b1 is True + assert obj.i1 == 100 + assert obj.b2 is False + assert obj.i2 == 200 + + +# ########################################################################### +# 19. Type Conversion Errors +# ########################################################################### +class TestTypeConversionErrors: + """Low-level type conversion: wrong-type setter/construction raises.""" + + def test_set_int_field_to_str_raises(self) -> None: + Cls = _make_type( + "ErrIntStr", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + obj = Cls(x=1) + with pytest.raises((TypeError, RuntimeError)): + obj.x = "not_an_int" + + def test_set_str_field_to_int_raises(self) -> None: + Cls = _make_type( + "ErrStrInt", + [Field(name="s", ty=TypeSchema("str"), default=MISSING)], + ) + obj = Cls(s="hello") + with pytest.raises((TypeError, RuntimeError)): + obj.s = 42 + + def test_construct_with_wrong_type_raises(self) -> None: + Cls = _make_type( + "ErrInit", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + with pytest.raises((TypeError, RuntimeError)): + Cls(x="bad") + + def test_set_wrong_type_preserves_old_value(self) -> None: + """Failed type-checked mutation preserves old value.""" + Cls = _make_type( + "ErrPreserve", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + obj = Cls(x=42) + with pytest.raises((TypeError, RuntimeError)): + obj.x = "bad_value" + assert obj.x == 42 + + def test_type_schema_convert_raises_directly(self) -> None: + """TypeSchema.convert raises TypeError for incompatible values.""" + ts = TypeSchema("int") + assert ts.convert(42).to_py() == 42 + with pytest.raises(TypeError): + ts.convert("not_an_int") + + def test_set_non_optional_object_field_to_none_raises(self) -> None: + """A non-Optional Object field must reject None.""" + Cls = _make_type( + "ErrObjNone", + [ + Field( + name="child", + ty=TypeSchema("Object"), + default=MISSING, + ), + ], + ) + obj = Cls(child=tvm_ffi.Array([1])) + with pytest.raises((TypeError, RuntimeError)): + obj.child = None + + def test_construct_non_optional_object_field_with_none_raises(self) -> None: + """Constructing with None for a non-Optional Object field must fail.""" + Cls = _make_type( + "ErrObjNoneInit", + [ + Field( + name="child", + ty=TypeSchema("Object"), + default=MISSING, + ), + ], + ) + with pytest.raises((TypeError, RuntimeError)): + Cls(child=None) + + def test_optional_object_field_accepts_none(self) -> None: + """An Optional[Object] field should accept None.""" + Cls = _make_type( + "OptObjNone", + [ + Field( + name="child", + ty=TypeSchema("Optional", (TypeSchema("Object"),)), + default=None, + ), + ], + ) + obj = Cls() + assert obj.child is None + obj.child = tvm_ffi.Array([1, 2]) + assert len(obj.child) == 2 + obj.child = None + assert obj.child is None + + def test_set_bool_field_to_str_raises(self) -> None: + Cls = _make_type( + "ErrBoolStr", + [Field(name="b", ty=TypeSchema("bool"), default=MISSING)], + ) + obj = Cls(b=True) + with pytest.raises((TypeError, RuntimeError)): + obj.b = "not_a_bool" + + def test_set_array_field_to_int_raises(self) -> None: + Cls = _make_type( + "ErrArrInt", + [ + Field( + name="arr", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + obj = Cls(arr=[1]) + with pytest.raises((TypeError, RuntimeError)): + obj.arr = 42 + + def test_construct_multiple_wrong_types_first_caught(self) -> None: + """When the first field has a wrong type, the error is caught.""" + Cls = _make_type( + "ErrMulti", + [ + Field(name="x", ty=TypeSchema("int"), default=MISSING), + Field(name="y", ty=TypeSchema("str"), default=MISSING), + ], + ) + with pytest.raises((TypeError, RuntimeError)): + Cls(x="bad", y="ok") + + def test_set_optional_to_wrong_inner_type_raises(self) -> None: + Cls = _make_type( + "ErrOptWrong", + [ + Field( + name="x", + ty=TypeSchema("Optional", (TypeSchema("int"),)), + default=None, + ), + ], + ) + obj = Cls() + with pytest.raises((TypeError, RuntimeError)): + obj.x = "not_an_int" + + +# ########################################################################### +# 20. Setter / Getter Corner Cases +# ########################################################################### +class TestSetterGetterCornerCases: + """Low-level setter/getter corner cases: conversions, nesting, edge values.""" + + # --- Bool / int coercion --- + + def test_bool_field_accepts_true_false(self) -> None: + Cls = _make_type( + "SGBool", + [Field(name="b", ty=TypeSchema("bool"), default=MISSING)], + ) + obj = Cls(b=True) + assert obj.b is True + obj.b = False + assert obj.b is False + + def test_int_field_accepts_bool(self) -> None: + """Python bool is a subclass of int — FFI should accept it.""" + Cls = _make_type( + "SGIntBool", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + obj = Cls(x=True) + assert obj.x == 1 + obj.x = False + assert obj.x == 0 + + # --- Float edge values --- + + def test_float_field_inf_nan(self) -> None: + Cls = _make_type( + "SGFloatEdge", + [Field(name="f", ty=TypeSchema("float"), default=MISSING)], + ) + obj = Cls(f=float("inf")) + assert math.isinf(obj.f) + obj.f = float("-inf") + assert math.isinf(obj.f) and obj.f < 0 + obj.f = float("nan") + assert math.isnan(obj.f) + + def test_float_field_accepts_int(self) -> None: + Cls = _make_type( + "SGFloatInt", + [Field(name="f", ty=TypeSchema("float"), default=MISSING)], + ) + obj = Cls(f=42) + assert obj.f == pytest.approx(42.0) + + # --- String edge values --- + + def test_str_field_unicode(self) -> None: + Cls = _make_type( + "SGStrUni", + [Field(name="s", ty=TypeSchema("str"), default=MISSING)], + ) + obj = Cls(s="日本語テスト 🎉") + assert obj.s == "日本語テスト 🎉" + + def test_str_field_null_bytes(self) -> None: + Cls = _make_type( + "SGStrNull", + [Field(name="s", ty=TypeSchema("str"), default=MISSING)], + ) + s = "hello\x00world" + obj = Cls(s=s) + assert obj.s == s + + # --- Multiple mutations --- + + def test_repeated_mutation_same_field(self) -> None: + Cls = _make_type( + "SGRepeat", + [Field(name="x", ty=TypeSchema("int"), default=MISSING)], + ) + obj = Cls(x=0) + for i in range(100): + obj.x = i + assert obj.x == 99 + + def test_repeated_str_mutation(self) -> None: + """Stress: repeated str assignment should not leak.""" + Cls = _make_type( + "SGRepeatStr", + [Field(name="s", ty=TypeSchema("str"), default=MISSING)], + ) + obj = Cls(s="init") + for i in range(100): + obj.s = f"value_{i}" + assert obj.s == "value_99" + + def test_repeated_objectref_mutation(self) -> None: + """Stress: repeated ObjectRef assignment should properly DecRef old values.""" + Cls = _make_type( + "SGRepeatArr", + [ + Field( + name="arr", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + ], + ) + obj = Cls(arr=[0]) + for i in range(50): + obj.arr = tvm_ffi.Array([i]) + assert obj.arr[0] == 49 + + # --- Nested object fields --- + + def test_nested_two_levels(self) -> None: + Inner = _make_type( + "SGInner", + [Field(name="val", ty=TypeSchema("int"), default=MISSING)], + ) + inner_info = getattr(Inner, "__tvm_ffi_type_info__") + inner_schema = TypeSchema(inner_info.type_key, origin_type_index=inner_info.type_index) + Outer = _make_type( + "SGOuter", + [Field(name="child", ty=inner_schema, default=MISSING)], + ) + obj = Outer(child=Inner(val=42)) + assert obj.child.val == 42 + # Mutate inner through outer + new_inner = Inner(val=99) + obj.child = new_inner + assert obj.child.val == 99 + + def test_self_referential_optional_field(self) -> None: + """A type with an Optional[Self] field (stored as Any).""" + Cls = _make_type( + "SGSelfRef", + [ + Field(name="val", ty=TypeSchema("int"), default=MISSING), + Field( + name="next", + ty=TypeSchema("Optional", (TypeSchema("Object"),)), + default=None, + ), + ], + ) + a = Cls(val=1) + b = Cls(val=2, next=a) + assert b.val == 2 + assert b.next.val == 1 + assert a.next is None + + # --- Default factory edge cases --- + + def test_default_factory_called_each_time(self) -> None: + call_count = [0] + + def factory() -> int: + call_count[0] += 1 + return call_count[0] + + Cls = _make_type( + "SGFactoryCount", + [Field(name="x", ty=TypeSchema("int"), default_factory=factory)], + ) + a = Cls() + b = Cls() + c = Cls() + assert a.x == 1 + assert b.x == 2 + assert c.x == 3 + + # --- Mixed types in one class --- + + def test_all_pod_plus_objectref_plus_optional(self) -> None: + Cls = _make_type( + "SGKitchenSink", + [ + Field(name="i", ty=TypeSchema("int"), default=MISSING), + Field(name="f", ty=TypeSchema("float"), default=MISSING), + Field(name="b", ty=TypeSchema("bool"), default=MISSING), + Field(name="s", ty=TypeSchema("str"), default=MISSING), + Field( + name="arr", + ty=TypeSchema("Array", (TypeSchema("int"),)), + default=MISSING, + ), + Field( + name="opt", + ty=TypeSchema("Optional", (TypeSchema("int"),)), + default=None, + ), + ], + ) + obj = Cls(i=1, f=2.0, b=True, s="hi", arr=[10, 20]) + assert obj.i == 1 + assert obj.f == pytest.approx(2.0) + assert obj.b is True + assert obj.s == "hi" + assert len(obj.arr) == 2 + assert obj.opt is None + # Mutate all fields + obj.i = -1 + obj.f = -2.0 + obj.b = False + obj.s = "bye" + obj.arr = tvm_ffi.Array([30]) + obj.opt = 42 + assert obj.i == -1 + assert obj.f == pytest.approx(-2.0) + assert obj.b is False + assert obj.s == "bye" + assert len(obj.arr) == 1 + assert obj.opt == 42 + + +# ########################################################################### +# 21. FFI Global Function Existence +# ########################################################################### +class TestFFIGlobalFunctions: + """Low-level FFI global function registration checks.""" + + def test_make_ffi_new_exists(self) -> None: + assert tvm_ffi.get_global_func("ffi.MakeFFINew", allow_missing=True) is not None + + def test_register_auto_init_exists(self) -> None: + assert tvm_ffi.get_global_func("ffi.RegisterAutoInit", allow_missing=True) is not None + + def test_get_field_getter_exists(self) -> None: + assert tvm_ffi.get_global_func("ffi.GetFieldGetter", allow_missing=True) is not None + + def test_make_field_setter_exists(self) -> None: + assert tvm_ffi.get_global_func("ffi.MakeFieldSetter", allow_missing=True) is not None + + def test_make_new_removed(self) -> None: + assert tvm_ffi.get_global_func("ffi.MakeNew", allow_missing=True) is None diff --git a/tests/python/test_type_converter.py b/tests/python/test_type_converter.py new file mode 100644 index 00000000..45c1da4e --- /dev/null +++ b/tests/python/test_type_converter.py @@ -0,0 +1,4540 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for TypeSchema type conversion to CAny.""" + +from __future__ import annotations + +import collections.abc +import ctypes +import os +import sys +import typing +from numbers import Integral +from typing import Callable, Iterator, Optional, Union + +import pytest +import tvm_ffi +from tvm_ffi.core import ( + CAny, + ObjectConvertible, + TypeSchema, + _lookup_type_attr, + _object_type_key_to_index, +) + +# Python 3.9+ supports list[int], dict[str, int], tuple[int, ...] at runtime. +# On 3.8, these raise TypeError("'type' object is not subscriptable"). +_PY39 = sys.version_info >= (3, 9) +requires_py39 = pytest.mark.skipif( + not _PY39, reason="builtin generic subscripts require Python 3.9+" +) +from tvm_ffi.testing import ( + TestIntPair, + TestObjectBase, + TestObjectDerived, + _TestCxxClassBase, + _TestCxxClassDerived, + _TestCxxClassDerivedDerived, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def S(origin: str, *args: TypeSchema) -> TypeSchema: + """Shorthand constructor for TypeSchema (string-based).""" + return TypeSchema(origin, tuple(args)) + + +# Annotation-based constructor — the main subject under test. +A = TypeSchema.from_annotation + + +# --------------------------------------------------------------------------- +# Category 1: POD type exact match (check_value) +# --------------------------------------------------------------------------- +class TestPODExactMatch: + def test_int(self) -> None: + """Test int.""" + A(int).check_value(42) + + def test_float(self) -> None: + """Test float.""" + A(float).check_value(3.14) + + def test_bool_true(self) -> None: + """Test bool true.""" + A(bool).check_value(True) + + def test_bool_false(self) -> None: + """Test bool false.""" + A(bool).check_value(False) + + def test_str(self) -> None: + """Test str.""" + A(str).check_value("hello") + + def test_bytes(self) -> None: + """Test bytes.""" + A(bytes).check_value(b"data") + + def test_none(self) -> None: + """Test none.""" + A(type(None)).check_value(None) + + +# --------------------------------------------------------------------------- +# Category 2: Implicit conversions (mirrors TryCastFromAnyView) +# --------------------------------------------------------------------------- +class TestImplicitConversions: + def test_bool_to_int(self) -> None: + """Bool -> int is OK (C++: int accepts bool).""" + A(int).check_value(True) + + def test_int_to_float(self) -> None: + """Int -> float is OK (C++: float accepts int).""" + A(float).check_value(42) + + def test_bool_to_float(self) -> None: + """Bool -> float is OK (C++: float accepts bool).""" + A(float).check_value(True) + + def test_int_to_bool(self) -> None: + """Int -> bool is OK (C++: bool accepts int).""" + A(bool).check_value(1) + + +# --------------------------------------------------------------------------- +# Category 3: Rejection cases +# --------------------------------------------------------------------------- +class TestRejections: + def test_str_not_int(self) -> None: + """Test str not int.""" + with pytest.raises(TypeError, match="expected int"): + A(int).check_value("hello") + + def test_float_not_int(self) -> None: + """Test float not int.""" + with pytest.raises(TypeError): + A(int).check_value(3.14) + + def test_none_not_int(self) -> None: + """Test none not int.""" + with pytest.raises(TypeError): + A(int).check_value(None) + + def test_int_not_str(self) -> None: + """Test int not str.""" + with pytest.raises(TypeError): + A(str).check_value(42) + + def test_str_not_bool(self) -> None: + """Test str not bool.""" + with pytest.raises(TypeError): + A(bool).check_value("hello") + + def test_none_not_str(self) -> None: + """Test none not str.""" + with pytest.raises(TypeError): + A(str).check_value(None) + + def test_int_not_bytes(self) -> None: + """Test int not bytes.""" + with pytest.raises(TypeError): + A(bytes).check_value(42) + + def test_int_not_none(self) -> None: + """Test int not none.""" + with pytest.raises(TypeError): + A(type(None)).check_value(42) + + +# --------------------------------------------------------------------------- +# Category 4: Special types +# --------------------------------------------------------------------------- +class TestSpecialTypes: + def test_device_pass(self) -> None: + """Test device pass.""" + dev = tvm_ffi.Device("cpu", 0) + A(tvm_ffi.Device).check_value(dev) + + def test_device_fail(self) -> None: + """Test device fail.""" + with pytest.raises(TypeError): + A(tvm_ffi.Device).check_value(42) + + def test_dtype_pass(self) -> None: + """Test dtype pass.""" + dt = tvm_ffi.core.DataType("float32") + A(tvm_ffi.core.DataType).check_value(dt) + + def test_dtype_str_pass(self) -> None: + """Str accepted as dtype (will be parsed).""" + A(tvm_ffi.core.DataType).check_value("float32") + + def test_dtype_fail(self) -> None: + """Test dtype fail.""" + with pytest.raises(TypeError): + A(tvm_ffi.core.DataType).check_value(42) + + def test_opaque_ptr_pass(self) -> None: + """Test opaque ptr pass.""" + A(ctypes.c_void_p).check_value(ctypes.c_void_p(0)) + + def test_opaque_ptr_none_pass(self) -> None: + """Test opaque ptr none pass.""" + A(ctypes.c_void_p).check_value(None) + + def test_opaque_ptr_fail(self) -> None: + """Test opaque ptr fail.""" + with pytest.raises(TypeError): + A(ctypes.c_void_p).check_value(42) + + def test_callable_pass_function(self) -> None: + """Test callable pass function.""" + A(Callable).check_value(lambda x: x) + + def test_callable_pass_builtin(self) -> None: + """Test callable pass builtin.""" + A(Callable).check_value(len) + + def test_callable_fail(self) -> None: + """Test callable fail.""" + with pytest.raises(TypeError): + A(Callable).check_value(42) + + def test_collections_abc_callable_pass_function(self) -> None: + """collections.abc.Callable accepts Python functions.""" + A(collections.abc.Callable).check_value(lambda x: x) + + def test_collections_abc_callable_pass_builtin(self) -> None: + """collections.abc.Callable accepts builtins.""" + A(collections.abc.Callable).check_value(len) + + def test_collections_abc_callable_fail(self) -> None: + """collections.abc.Callable rejects non-callables.""" + with pytest.raises(TypeError, match="expected Callable"): + A(collections.abc.Callable).check_value(42) + + def test_callable_cobject_wraps_to_function(self) -> None: + """Callable CObjects are wrapped instead of asserting.""" + + class CallableObj(TestObjectBase): + def __call__(self, x: int) -> int: + return x + 1 + + obj = CallableObj(v_i64=1, v_f64=2.0, v_str="s") + with pytest.raises(TypeError, match=r"expected Callable, got .*TestObjectBase"): + A(Callable).check_value(obj) + + +# --------------------------------------------------------------------------- +# Category 5: Object types +# --------------------------------------------------------------------------- +class TestObjectTypes: + def test_object_pass(self) -> None: + """Any CObject passes TypeSchema('Object').""" + f = tvm_ffi.get_global_func("testing.echo") + A(tvm_ffi.core.Object).check_value(f) + + def test_object_fail(self) -> None: + """Test object fail.""" + with pytest.raises(TypeError): + A(tvm_ffi.core.Object).check_value(42) + + def test_specific_object_pass(self) -> None: + """A Function object should pass its own type schema.""" + f = tvm_ffi.get_global_func("testing.echo") + A(Callable).check_value(f) + + def test_function_from_extern_c_exists(self) -> None: + """ffi.FunctionFromExternC should be registered.""" + fn = tvm_ffi.get_global_func("ffi.FunctionFromExternC", allow_missing=True) + assert fn is not None, "ffi.FunctionFromExternC not registered" + + +# --------------------------------------------------------------------------- +# Category 6: Optional +# --------------------------------------------------------------------------- +class TestOptional: + def test_none_passes(self) -> None: + """Test none passes.""" + A(Optional[int]).check_value(None) + + def test_inner_type_passes(self) -> None: + """Test inner type passes.""" + A(Optional[int]).check_value(42) + + def test_wrong_type_fails(self) -> None: + """Test wrong type fails.""" + with pytest.raises(TypeError, match="expected int"): + A(Optional[int]).check_value("hello") + + def test_nested_optional(self) -> None: + """Test nested optional.""" + schema = A(Optional[Optional[int]]) + schema.check_value(None) + schema.check_value(42) + + +# --------------------------------------------------------------------------- +# Category 7: Union / Variant +# --------------------------------------------------------------------------- +class TestUnion: + def test_first_alt_passes(self) -> None: + """Test first alt passes.""" + A(Union[int, str]).check_value(42) + + def test_second_alt_passes(self) -> None: + """Test second alt passes.""" + A(Union[int, str]).check_value("hello") + + def test_no_alt_matches(self) -> None: + """Test no alt matches.""" + with pytest.raises(TypeError, match="got float"): + A(Union[int, str]).check_value(3.14) + + def test_bool_matches_int_alt(self) -> None: + """Bool is accepted by the int alternative.""" + A(Union[int, str]).check_value(True) + + +# --------------------------------------------------------------------------- +# Category 8: Containers +# --------------------------------------------------------------------------- +class TestContainers: + @requires_py39 + def test_array_list_pass(self) -> None: + """Test array list pass.""" + A(tuple[int, ...]).check_value([1, 2, 3]) + + @requires_py39 + def test_array_tuple_pass(self) -> None: + """Test array tuple pass.""" + A(tuple[int, ...]).check_value((1, 2, 3)) + + @requires_py39 + def test_array_wrong_element(self) -> None: + """Test array wrong element.""" + with pytest.raises(TypeError, match=r"element \[1\].*expected int"): + A(tuple[int, ...]).check_value([1, "x"]) + + @requires_py39 + def test_array_empty_pass(self) -> None: + """Test array empty pass.""" + A(tuple[int, ...]).check_value([]) + + @requires_py39 + def test_array_any_pass(self) -> None: + """Test array any pass.""" + A(tuple[typing.Any, ...]).check_value([1, "x", None]) + + @requires_py39 + def test_array_wrong_container_type(self) -> None: + """Test array wrong container type.""" + with pytest.raises(TypeError, match="expected Array"): + A(tuple[int, ...]).check_value(42) + + @requires_py39 + def test_array_rejects_generator(self) -> None: + """Generators are not accepted by Array schemas.""" + + def gen() -> Iterator[int]: + yield 1 + yield 2 + + with pytest.raises(TypeError, match="expected Array"): + A(tuple[int, ...]).check_value(gen()) + + @requires_py39 + def test_array_rejects_string(self) -> None: + """Strings are not accepted by Array schemas.""" + with pytest.raises(TypeError, match="expected Array"): + A(tuple[int, ...]).check_value("hello") + + @requires_py39 + def test_list_pass(self) -> None: + """Test list pass.""" + A(list[str]).check_value(["a", "b"]) + + @requires_py39 + def test_map_pass(self) -> None: + """Test map pass.""" + A(tvm_ffi.Map[str, int]).check_value({"a": 1, "b": 2}) + + @requires_py39 + def test_map_wrong_key(self) -> None: + """Test map wrong key.""" + with pytest.raises(TypeError, match="expected str"): + A(tvm_ffi.Map[str, int]).check_value({1: 2}) + + @requires_py39 + def test_map_wrong_value(self) -> None: + """Test map wrong value.""" + with pytest.raises(TypeError, match="expected int"): + A(tvm_ffi.Map[str, int]).check_value({"a": "b"}) + + @requires_py39 + def test_map_empty_pass(self) -> None: + """Test map empty pass.""" + A(tvm_ffi.Map[str, int]).check_value({}) + + @requires_py39 + def test_dict_pass(self) -> None: + """Test dict pass.""" + A(dict[str, int]).check_value({"a": 1}) + + @requires_py39 + def test_map_wrong_container(self) -> None: + """Test map wrong container.""" + with pytest.raises(TypeError, match="expected Map"): + A(tvm_ffi.Map[str, int]).check_value([1, 2]) + + @requires_py39 + def test_map_rejects_non_mapping_pairs(self) -> None: + """Lists of pairs are not accepted by Map schemas.""" + with pytest.raises(TypeError, match="expected Map"): + A(tvm_ffi.Map[str, int]).check_value([("a", 1)]) + + +# --------------------------------------------------------------------------- +# Category 9: Nested types +# --------------------------------------------------------------------------- +class TestNestedTypes: + @requires_py39 + def test_array_optional_int(self) -> None: + """Test array optional int.""" + A(tuple[Optional[int], ...]).check_value([1, None, 2]) + + @requires_py39 + def test_map_str_array_int(self) -> None: + """Test map str array int.""" + A(tvm_ffi.Map[str, tuple[int, ...]]).check_value({"a": [1, 2]}) + + @requires_py39 + def test_map_str_array_int_nested_fail(self) -> None: + """Test map str array int nested fail.""" + with pytest.raises(TypeError, match="expected int"): + A(tvm_ffi.Map[str, tuple[int, ...]]).check_value({"a": [1, "x"]}) + + @requires_py39 + def test_union_with_containers(self) -> None: + """Test union with containers.""" + schema = A(Union[int, tuple[str, ...]]) + schema.check_value(42) + schema.check_value(["a", "b"]) + with pytest.raises(TypeError): + schema.check_value(3.14) + + +# --------------------------------------------------------------------------- +# Category 10: Any +# --------------------------------------------------------------------------- +class TestAny: + def test_int(self) -> None: + """Test int.""" + A(typing.Any).check_value(42) + + def test_none(self) -> None: + """Test none.""" + A(typing.Any).check_value(None) + + def test_str(self) -> None: + """Test str.""" + A(typing.Any).check_value("hello") + + def test_list(self) -> None: + """Test list.""" + A(typing.Any).check_value([1, 2, 3]) + + def test_object(self) -> None: + """Test object.""" + A(typing.Any).check_value(object()) + + def test_object_convertible_convert(self) -> None: + """Any eagerly unwraps ObjectConvertible via asobject().""" + inner = TestIntPair(1, 2) + + class Convertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + return inner + + result = A(typing.Any).convert(Convertible()).to_py() + assert result.same_as(inner) + + def test_object_convertible_error(self) -> None: + """Any surfaces asobject() failures during eager normalization.""" + + class BadConvertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + raise RuntimeError("broken") + + with pytest.raises(TypeError, match=r"asobject\(\) failed"): + A(typing.Any).check_value(BadConvertible()) + + def test_object_protocol_convert(self) -> None: + """Any eagerly unwraps __tvm_ffi_object__ before dispatch.""" + inner = TestIntPair(1, 2) + + class ObjProto: + def __tvm_ffi_object__(self) -> object: + return inner + + result = A(typing.Any).convert(ObjProto()).to_py() + assert result.same_as(inner) + + def test_object_protocol_error(self) -> None: + """Any surfaces __tvm_ffi_object__ failures during eager normalization.""" + + class BadProto: + def __tvm_ffi_object__(self) -> object: + raise RuntimeError("broken") + + with pytest.raises(TypeError, match=r"__tvm_ffi_object__\(\) failed"): + A(typing.Any).check_value(BadProto()) + + +# --------------------------------------------------------------------------- +# Category 11: Error message quality +# --------------------------------------------------------------------------- +class TestErrorMessages: + def test_basic_type_mismatch(self) -> None: + """Test basic type mismatch.""" + with pytest.raises(TypeError, match=r"expected int, got str"): + A(int).check_value("hello") + + @requires_py39 + def test_nested_array_error(self) -> None: + """Test nested array error.""" + with pytest.raises(TypeError, match=r"element \[2\].*expected int, got str"): + A(tuple[int, ...]).check_value([1, 2, "x"]) + + @requires_py39 + def test_nested_map_error(self) -> None: + """Test nested map error.""" + with pytest.raises(TypeError, match=r"value for key 'b'.*expected int, got str"): + A(tvm_ffi.Map[str, int]).check_value({"a": 1, "b": "x"}) + + def test_union_error_lists_alternatives(self) -> None: + """Test union error lists alternatives.""" + with pytest.raises(TypeError, match="got float") as exc_info: + A(Union[int, str]).check_value(3.14) + err = str(exc_info.value) + assert "int" in err + assert "str" in err + + def test_schema_in_error_message(self) -> None: + """check_value includes the schema repr in the TypeError.""" + with pytest.raises(TypeError, match=r"type check failed for"): + A(int).check_value("hello") + + def test_convert_error_message(self) -> None: + """Convert includes the schema repr in the TypeError.""" + with pytest.raises(TypeError, match=r"type conversion failed for"): + A(int).convert("hello") + + +# --------------------------------------------------------------------------- +# Category 12: from_type_index factory +# --------------------------------------------------------------------------- +class TestFromTypeIndex: + def test_int(self) -> None: + """Test int.""" + schema = TypeSchema.from_type_index(1) # kTVMFFIInt + assert schema.origin == "int" + assert schema.origin_type_index == 1 + + def test_float(self) -> None: + """Test float.""" + schema = TypeSchema.from_type_index(3) # kTVMFFIFloat + assert schema.origin == "float" + + def test_bool(self) -> None: + """Test bool.""" + schema = TypeSchema.from_type_index(2) # kTVMFFIBool + assert schema.origin == "bool" + + def test_array_with_args(self) -> None: + """Test array with args.""" + schema = TypeSchema.from_type_index(71, (A(int),)) # kTVMFFIArray + assert schema.origin == "Array" + assert len(schema.args) == 1 + assert schema.args[0].origin == "int" + + def test_roundtrip_check(self) -> None: + """from_type_index then check_value works correctly.""" + schema = TypeSchema.from_type_index(1) # int + schema.check_value(42) + with pytest.raises(TypeError): + schema.check_value("hello") + + def test_none(self) -> None: + """Test none.""" + schema = TypeSchema.from_type_index(0) # kTVMFFINone + assert schema.origin == "None" + schema.check_value(None) + + def test_any(self) -> None: + """Test any.""" + schema = TypeSchema.from_type_index(-1) # kTVMFFIAny + assert schema.origin == "Any" + schema.check_value("anything") + + def test_str(self) -> None: + """Test str.""" + schema = TypeSchema.from_type_index(65) # kTVMFFIStr + assert schema.origin == "str" + schema.check_value("hello") + + def test_map_with_args(self) -> None: + """Test map with args.""" + schema = TypeSchema.from_type_index(72, (A(str), A(int))) # kTVMFFIMap + assert schema.origin == "Map" + schema.check_value({"a": 1}) + + +# --------------------------------------------------------------------------- +# Category 13: Edge cases +# --------------------------------------------------------------------------- +class TestEdgeCases: + def test_bytearray_passes_bytes(self) -> None: + """Test bytearray passes bytes.""" + A(bytes).check_value(bytearray(b"data")) + + @requires_py39 + def test_tuple_passes_array(self) -> None: + """Tuple is accepted as a sequence type for Array.""" + A(tuple[int, ...]).check_value((1, 2, 3)) + + def test_empty_union_is_rejected(self) -> None: + """Union requires at least 2 args.""" + with pytest.raises(ValueError, match="at least two"): + TypeSchema("Union", ()) + + def test_origin_type_index_auto_computed(self) -> None: + """origin_type_index is automatically computed from origin string.""" + schema = A(int) + assert schema.origin_type_index == 1 # kTVMFFIInt + schema = A(float) + assert schema.origin_type_index == 3 # kTVMFFIFloat + schema = A(Optional[int]) + assert schema.origin_type_index == -2 # structural + + def test_check_value_succeeds_on_valid(self) -> None: + """Test check value succeeds on valid input.""" + A(int).check_value(42) + + def test_check_value_raises_on_failure(self) -> None: + """Test check value raises TypeError on failure.""" + with pytest.raises(TypeError, match="expected int"): + A(int).check_value("hello") + + @requires_py39 + def test_tuple_type_schema(self) -> None: + """Test tuple type schema.""" + schema = A(tuple[int, str]) + schema.check_value((1, "a")) + with pytest.raises(TypeError): + schema.check_value((1, 2)) + with pytest.raises(TypeError): + schema.check_value((1,)) + + def test_numpy_int_passes_int(self) -> None: + """Numpy integer types should pass int check via Integral.""" + np = pytest.importorskip("numpy") + A(int).check_value(np.int64(42)) + A(float).check_value(np.float64(3.14)) + + +# =========================================================================== +# Type Converter Tests (convert) +# =========================================================================== + + +# --------------------------------------------------------------------------- +# Category 14: POD conversion results +# --------------------------------------------------------------------------- +class TestConvertPOD: + def test_int_passthrough(self) -> None: + """Int -> int returns the same value.""" + result = A(int).convert(42).to_py() + assert result == 42 + assert type(result) is int + + def test_bool_to_int(self) -> None: + """Bool -> int actually converts to int.""" + result = A(int).convert(True).to_py() + assert result == 1 + assert type(result) is int + + def test_bool_false_to_int(self) -> None: + """Test bool false to int.""" + result = A(int).convert(False).to_py() + assert result == 0 + assert type(result) is int + + def test_float_passthrough(self) -> None: + """Test float passthrough.""" + result = A(float).convert(3.14).to_py() + assert result == 3.14 + assert type(result) is float + + def test_int_to_float(self) -> None: + """Int -> float actually converts.""" + result = A(float).convert(42).to_py() + assert result == 42.0 + assert type(result) is float + + def test_bool_to_float(self) -> None: + """Bool -> float actually converts.""" + result = A(float).convert(True).to_py() + assert result == 1.0 + assert type(result) is float + + def test_bool_passthrough(self) -> None: + """Test bool passthrough.""" + result = A(bool).convert(True).to_py() + assert result is True + assert type(result) is bool + + def test_int_to_bool(self) -> None: + """Int -> bool actually converts.""" + result = A(bool).convert(1).to_py() + assert result is True + assert type(result) is bool + + def test_int_zero_to_bool(self) -> None: + """Test int zero to bool.""" + result = A(bool).convert(0).to_py() + assert result is False + assert type(result) is bool + + def test_str_passthrough(self) -> None: + """Test str passthrough — returns tvm_ffi.String (subclass of str).""" + result = A(str).convert("hello").to_py() + assert result == "hello" + assert isinstance(result, str) + assert isinstance(result, tvm_ffi.core.String) + + def test_bytes_passthrough(self) -> None: + """Test bytes passthrough — returns tvm_ffi.Bytes (subclass of bytes).""" + result = A(bytes).convert(b"data").to_py() + assert result == b"data" + assert isinstance(result, bytes) + assert isinstance(result, tvm_ffi.core.Bytes) + + def test_bytearray_to_bytes(self) -> None: + """Bytearray -> bytes converts to tvm_ffi.Bytes.""" + result = A(bytes).convert(bytearray(b"data")).to_py() + assert result == b"data" + assert isinstance(result, bytes) + assert isinstance(result, tvm_ffi.core.Bytes) + + +# --------------------------------------------------------------------------- +# Category 15: None disambiguation (critical design point) +# --------------------------------------------------------------------------- +class TestNoneDisambiguation: + def test_none_converts_successfully_for_none_schema(self) -> None: + """TypeSchema('None').convert(None) returns None as a valid result.""" + result = A(type(None)).convert(None).to_py() + assert result is None + + def test_none_converts_successfully_for_optional(self) -> None: + """Optional[int].convert(None) returns None as a valid result.""" + result = A(Optional[int]).convert(None).to_py() + assert result is None + + def test_none_fails_for_int(self) -> None: + """TypeSchema('int').convert(None) raises TypeError.""" + with pytest.raises(TypeError, match="expected int, got None"): + A(int).convert(None) + + def test_convert_none_success(self) -> None: + """Convert returns None for Optional[int] with None input.""" + result = A(Optional[int]).convert(None).to_py() + assert result is None + + def test_convert_none_failure(self) -> None: + """Convert raises TypeError for failed conversion.""" + with pytest.raises(TypeError, match="expected int"): + A(int).convert(None) + + def test_convert_success_with_value(self) -> None: + """Convert returns converted value on success.""" + result = A(int).convert(True).to_py() + assert result == 1 + assert type(result) is int + + def test_opaque_ptr_none_converts(self) -> None: + """ctypes.c_void_p accepts None and converts it to a null opaque pointer.""" + result = A(ctypes.c_void_p).convert(None).to_py() + assert isinstance(result, ctypes.c_void_p) + assert result.value is None + + def test_convert_opaque_ptr_none(self) -> None: + """Test convert opaque ptr none.""" + result = A(ctypes.c_void_p).convert(None).to_py() + assert isinstance(result, ctypes.c_void_p) + assert result.value is None + + +# --------------------------------------------------------------------------- +# Category 16: Special type conversions +# --------------------------------------------------------------------------- +class TestConvertSpecialTypes: + def test_dtype_str_converts(self) -> None: + """Str -> dtype actually creates a DataType object.""" + result = A(tvm_ffi.core.DataType).convert("float32").to_py() + assert isinstance(result, tvm_ffi.core.DataType) + assert str(result) == "float32" + + def test_dtype_passthrough(self) -> None: + """Test dtype passthrough.""" + dt = tvm_ffi.core.DataType("int32") + result = A(tvm_ffi.core.DataType).convert(dt).to_py() + assert str(result) == str(dt) + + def test_device_passthrough(self) -> None: + """Test device passthrough.""" + dev = tvm_ffi.Device("cpu", 0) + result = A(tvm_ffi.Device).convert(dev).to_py() + assert str(result) == str(dev) + + def test_callable_passthrough(self) -> None: + """Test callable passthrough.""" + fn = lambda x: x + result = A(Callable).convert(fn).to_py() + assert callable(result) + + def test_opaque_ptr_passthrough(self) -> None: + """Test opaque ptr passthrough.""" + ptr = ctypes.c_void_p(42) + result = A(ctypes.c_void_p).convert(ptr).to_py() + assert result is not None + + +# --------------------------------------------------------------------------- +# Category 17: Container conversion results +# --------------------------------------------------------------------------- +class TestConvertContainers: + @requires_py39 + def test_array_converts_bool_elements_to_int(self) -> None: + """Array[int] with bool elements converts them to int.""" + result = A(tuple[int, ...]).convert([True, False, 1]).to_py() + assert list(result) == [1, 0, 1] + assert all(type(x) is int for x in result) + + @requires_py39 + def test_array_int_passthrough(self) -> None: + """Array[int] with int elements returns ffi.Array.""" + result = A(tuple[int, ...]).convert([1, 2, 3]).to_py() + assert list(result) == [1, 2, 3] + + @requires_py39 + def test_array_any_passthrough(self) -> None: + """Array[Any] wraps into ffi.Array.""" + original = [1, "x", None] + result = A(tuple[typing.Any, ...]).convert(original).to_py() + assert isinstance(result, tvm_ffi.Array) + + @requires_py39 + def test_map_converts_values(self) -> None: + """Map[str, float] converts int values to float.""" + result = A(tvm_ffi.Map[str, float]).convert({"a": 1, "b": 2}).to_py() + assert isinstance(result, tvm_ffi.Map) + assert type(result["a"]) is float + assert type(result["b"]) is float + assert result["a"] == 1.0 + assert result["b"] == 2.0 + + @requires_py39 + def test_map_any_float_converts_values(self) -> None: + """Map[Any, float] still converts values when keys are Any.""" + result = A(tvm_ffi.Map[typing.Any, float]).convert({"a": 1, "b": 2}).to_py() + assert isinstance(result, tvm_ffi.Map) + assert type(result["a"]) is float + + @requires_py39 + def test_map_any_any_passthrough(self) -> None: + """Map[Any, Any] wraps into ffi.Map.""" + original = {"a": 1} + result = A(tvm_ffi.Map[typing.Any, typing.Any]).convert(original).to_py() + assert isinstance(result, tvm_ffi.Map) + + @requires_py39 + def test_map_empty_dict_convert(self) -> None: + """Empty dict converts to Map[str, int].""" + result = A(tvm_ffi.Map[str, int]).convert({}).to_py() + assert len(result) == 0 + + @requires_py39 + def test_dict_empty_dict_convert(self) -> None: + """Empty dict converts to Dict[str, int].""" + result = A(dict[str, int]).convert({}).to_py() + assert len(result) == 0 + + @requires_py39 + def test_tuple_converts_elements(self) -> None: + """tuple[int, float] converts elements positionally.""" + result = A(tuple[int, float]).convert((True, 42)).to_py() + assert list(result) == [1, 42.0] + assert type(result[0]) is int + assert type(result[1]) is float + + @requires_py39 + def test_nested_array_in_map(self) -> None: + """Map[str, Array[int]] recursively converts elements.""" + result = A(tvm_ffi.Map[str, tuple[int, ...]]).convert({"a": [True, False]}).to_py() + assert isinstance(result, tvm_ffi.Map) + assert list(result["a"]) == [1, 0] + assert all(type(x) is int for x in result["a"]) + + @requires_py39 + def test_array_optional_int_all_none(self) -> None: + """Array[Optional[int]] accepts an all-None payload.""" + result = A(tuple[Optional[int], ...]).convert([None, None, None]).to_py() + assert list(result) == [None, None, None] + + +# --------------------------------------------------------------------------- +# Category 18: Optional/Union conversion results +# --------------------------------------------------------------------------- +class TestConvertComposite: + def test_optional_converts_inner(self) -> None: + """Optional[float].convert(42) converts int -> float.""" + result = A(Optional[float]).convert(42).to_py() + assert result == 42.0 + assert type(result) is float + + def test_optional_none(self) -> None: + """Test optional none.""" + result = A(Optional[float]).convert(None).to_py() + assert result is None + + def test_union_picks_first_match(self) -> None: + """Union[int, str] converts bool via int alternative.""" + result = A(Union[int, str]).convert(True).to_py() + assert result == 1 + assert type(result) is int + + def test_union_second_match(self) -> None: + """Test union second match.""" + result = A(Union[int, str]).convert("hello").to_py() + assert result == "hello" + + def test_any_passthrough(self) -> None: + """Any returns value as-is.""" + result = A(typing.Any).convert(42).to_py() + assert result == 42 + result = A(typing.Any).convert(None).to_py() + assert result is None + + +# --------------------------------------------------------------------------- +# Category 19: Convert rejection cases +# --------------------------------------------------------------------------- +class TestConvertRejections: + def test_int_rejects_str(self) -> None: + """Test int rejects str.""" + with pytest.raises(TypeError, match="expected int, got str"): + A(int).convert("hello") + + def test_int_rejects_float(self) -> None: + """Test int rejects float.""" + with pytest.raises(TypeError, match="expected int, got float"): + A(int).convert(3.14) + + def test_str_rejects_int(self) -> None: + """Test str rejects int.""" + with pytest.raises(TypeError, match="expected str, got int"): + A(str).convert(42) + + @requires_py39 + def test_array_rejects_wrong_element(self) -> None: + """Test array rejects wrong element.""" + with pytest.raises(TypeError, match=r"element \[1\].*expected int, got str"): + A(tuple[int, ...]).convert([1, "x"]) + + @requires_py39 + def test_map_rejects_wrong_value(self) -> None: + """Test map rejects wrong value.""" + with pytest.raises(TypeError, match=r"value for key 'a'.*expected int, got str"): + A(tvm_ffi.Map[str, int]).convert({"a": "x"}) + + @requires_py39 + def test_tuple_rejects_wrong_length(self) -> None: + """Test tuple rejects wrong length.""" + with pytest.raises(TypeError, match=r"expected tuple of length 2"): + A(tuple[int, str]).convert((1,)) + + def test_convert_failure_raises(self) -> None: + """Test convert failure raises TypeError.""" + with pytest.raises(TypeError, match="expected int"): + A(int).convert("hello") + + +# --------------------------------------------------------------------------- +# Category 20: Numpy conversion +# --------------------------------------------------------------------------- +class TestConvertNumpy: + def test_numpy_int_to_int(self) -> None: + """Test numpy int to int.""" + np = pytest.importorskip("numpy") + result = A(int).convert(np.int64(42)).to_py() + assert result == 42 + assert type(result) is int + + def test_numpy_float_to_float(self) -> None: + """Test numpy float to float.""" + np = pytest.importorskip("numpy") + result = A(float).convert(np.float64(3.14)).to_py() + assert result == pytest.approx(3.14) + # np.float64 is a subclass of float, so isinstance check passes + # and the value is returned as-is (no forced conversion to plain float) + assert isinstance(result, float) + + +# =========================================================================== +# Nested Conversion Tests (with inner-level conversions) +# =========================================================================== + + +# --------------------------------------------------------------------------- +# Category 21: Array nested with Optional/Union (inner conversion) +# --------------------------------------------------------------------------- +class TestNestedArrayComposite: + @requires_py39 + def test_array_optional_float_with_bool(self) -> None: + """Array[Optional[float]] converts bool elements to float.""" + result = A(tuple[Optional[float], ...]).convert([True, None, 3]).to_py() + assert list(result) == [1.0, None, 3.0] + assert type(result[0]) is float + assert result[1] is None + assert type(result[2]) is float + + @requires_py39 + def test_array_optional_int_with_bool(self) -> None: + """Array[Optional[int]] converts bool elements to int.""" + result = A(tuple[Optional[int], ...]).convert([True, None, 2]).to_py() + assert list(result) == [1, None, 2] + assert type(result[0]) is int + assert result[1] is None + + @requires_py39 + def test_array_union_int_str_with_bool(self) -> None: + """Array[Union[int, str]] converts bool via int alternative.""" + result = A(tuple[Union[int, str], ...]).convert([True, "hello", False]).to_py() + assert list(result) == [1, "hello", 0] + assert type(result[0]) is int + assert type(result[1]) is str + assert type(result[2]) is int + + @requires_py39 + def test_array_union_float_str_with_int(self) -> None: + """Array[Union[float, str]] converts int via float alternative.""" + result = A(tuple[Union[float, str], ...]).convert([42, "hi", True]).to_py() + assert list(result) == [42.0, "hi", 1.0] + assert type(result[0]) is float + assert type(result[2]) is float + + @requires_py39 + def test_array_optional_float_all_none(self) -> None: + """Array[Optional[float]] with all None elements.""" + result = A(tuple[Optional[float], ...]).convert([None, None]).to_py() + assert list(result) == [None, None] + + @requires_py39 + def test_array_optional_float_empty(self) -> None: + """Array[Optional[float]] with empty list.""" + result = A(tuple[Optional[float], ...]).convert([]).to_py() + assert list(result) == [] + + @requires_py39 + def test_array_union_failure_in_element(self) -> None: + """Array[Union[int, str]] fails when element matches no alternative.""" + with pytest.raises(TypeError, match=r"element \[1\].*got float"): + A(tuple[Union[int, str], ...]).check_value([1, 3.14]) + + +# --------------------------------------------------------------------------- +# Category 22: Map/Dict nested with Optional/Union (inner conversion) +# --------------------------------------------------------------------------- +class TestNestedMapComposite: + @requires_py39 + def test_map_str_optional_float_with_int(self) -> None: + """Map[str, Optional[float]] converts int values to float.""" + result = A(tvm_ffi.Map[str, Optional[float]]).convert({"a": 1, "b": None}).to_py() + assert type(result["a"]) is float + assert result["a"] == 1.0 + assert result["b"] is None + + @requires_py39 + def test_map_str_union_int_str(self) -> None: + """Map[str, Union[int, str]] converts bool values via int.""" + result = A(tvm_ffi.Map[str, Union[int, str]]).convert({"x": True, "y": "hello"}).to_py() + assert result["x"] == 1 + assert result["y"] == "hello" + assert type(result["x"]) is int + + @requires_py39 + def test_dict_str_optional_int(self) -> None: + """Dict[str, Optional[int]] with bool conversion.""" + result = A(dict[str, Optional[int]]).convert({"a": True, "b": None, "c": 42}).to_py() + assert result["a"] == 1 + assert result["b"] is None + assert result["c"] == 42 + assert type(result["a"]) is int + + @requires_py39 + def test_map_str_optional_float_failure(self) -> None: + """Map[str, Optional[float]] fails for non-float non-None value.""" + with pytest.raises(TypeError, match="expected float"): + A(tvm_ffi.Map[str, Optional[float]]).check_value({"a": "bad"}) + + +# --------------------------------------------------------------------------- +# Category 23: Nested containers (container inside container) +# --------------------------------------------------------------------------- +class TestNestedContainerInContainer: + @requires_py39 + def test_array_of_array_int(self) -> None: + """Array[Array[int]] with inner bool->int conversion.""" + result = A(tuple[tuple[int, ...], ...]).convert([[True, False], [1, 2]]).to_py() + assert [list(row) for row in result] == [[1, 0], [1, 2]] + assert all(type(x) is int for row in result for x in row) + + @requires_py39 + def test_array_of_array_float(self) -> None: + """Array[Array[float]] with inner int->float conversion.""" + result = A(tuple[tuple[float, ...], ...]).convert([[1, 2], [True, 3]]).to_py() + assert [list(row) for row in result] == [[1.0, 2.0], [1.0, 3.0]] + assert all(type(x) is float for row in result for x in row) + + @requires_py39 + def test_map_str_array_float(self) -> None: + """Map[str, Array[float]] with int->float conversion in arrays.""" + result = ( + A(tvm_ffi.Map[str, tuple[float, ...]]).convert({"a": [1, 2], "b": [True, 3]}).to_py() + ) + assert list(result["a"]) == [1.0, 2.0] + assert list(result["b"]) == [1.0, 3.0] + assert all(type(x) is float for x in result["a"]) + assert all(type(x) is float for x in result["b"]) + + @requires_py39 + def test_dict_str_array_int(self) -> None: + """Dict[str, Array[int]] with bool->int conversion.""" + result = A(dict[str, tuple[int, ...]]).convert({"a": [True, False]}).to_py() + assert list(result["a"]) == [1, 0] + assert all(type(x) is int for x in result["a"]) + + @requires_py39 + def test_array_of_map_str_int(self) -> None: + """Array[Map[str, int]] with bool->int value conversion.""" + result = A(tuple[tvm_ffi.Map[str, int], ...]).convert([{"x": True}, {"y": 2}]).to_py() + assert result[0]["x"] == 1 + assert result[1]["y"] == 2 + assert type(result[0]["x"]) is int + + @requires_py39 + def test_map_str_map_str_float(self) -> None: + """Map[str, Map[str, float]] double nested with int->float.""" + result = ( + A(tvm_ffi.Map[str, tvm_ffi.Map[str, float]]).convert({"outer": {"inner": 42}}).to_py() + ) + assert result["outer"]["inner"] == 42.0 + assert type(result["outer"]["inner"]) is float + + @requires_py39 + def test_list_of_list_int(self) -> None: + """List[List[int]] with bool->int conversion.""" + result = A(list[list[int]]).convert([[True, 1], [False, 2]]).to_py() + assert [list(row) for row in result] == [[1, 1], [0, 2]] + assert all(type(x) is int for row in result for x in row) + + @requires_py39 + def test_nested_failure_array_of_array(self) -> None: + """Array[Array[int]] error propagation through nested arrays.""" + with pytest.raises(TypeError, match="expected int"): + A(tuple[tuple[int, ...], ...]).check_value([[1, 2], [3, "bad"]]) + + @requires_py39 + def test_empty_inner_containers(self) -> None: + """Map[str, Array[int]] with empty inner arrays.""" + result = A(tvm_ffi.Map[str, tuple[int, ...]]).convert({"a": [], "b": []}).to_py() + assert list(result["a"]) == [] + assert list(result["b"]) == [] + + @requires_py39 + def test_array_of_array_of_array_int(self) -> None: + """Three-level nested Array[int] conversion still works.""" + schema = A(tuple[tuple[tuple[int, ...], ...], ...]) + data = [[[1, 2], [True, False]], [[3], [4, 5, 6]]] + result = schema.convert(data).to_py() + assert list(result[0][0]) == [1, 2] + assert list(result[0][1]) == [1, 0] + assert type(result[0][1][0]) is int + + @requires_py39 + def test_map_of_map_of_array_float(self) -> None: + """Nested map-to-array conversion still coerces inner values.""" + schema = A(tvm_ffi.Map[str, tvm_ffi.Map[str, tuple[float, ...]]]) + data = {"outer": {"inner": [1, 2, True]}} + result = schema.convert(data).to_py() + assert list(result["outer"]["inner"]) == [1.0, 2.0, 1.0] + assert type(result["outer"]["inner"][0]) is float + + +# --------------------------------------------------------------------------- +# Category 24: Optional/Union wrapping containers +# --------------------------------------------------------------------------- +class TestOptionalUnionWrappingContainers: + @requires_py39 + def test_optional_array_int_with_conversion(self) -> None: + """Optional[Array[int]] converts inner bool elements.""" + schema = A(Optional[tuple[int, ...]]) + result = schema.convert([True, 2]).to_py() + assert list(result) == [1, 2] + assert type(result[0]) is int + + @requires_py39 + def test_optional_array_int_none(self) -> None: + """Optional[Array[int]] accepts None.""" + result = A(Optional[tuple[int, ...]]).convert(None).to_py() + assert result is None + + @requires_py39 + def test_optional_map_str_float(self) -> None: + """Optional[Map[str, float]] converts inner int values.""" + result = A(Optional[tvm_ffi.Map[str, float]]).convert({"a": 1}).to_py() + assert result["a"] == 1.0 + assert type(result["a"]) is float + + @requires_py39 + def test_optional_map_str_float_none(self) -> None: + """Optional[Map[str, float]] accepts None.""" + result = A(Optional[tvm_ffi.Map[str, float]]).convert(None).to_py() + assert result is None + + @requires_py39 + def test_union_array_int_or_map_str_int(self) -> None: + """Union[Array[int], Map[str, int]] matches first with conversion.""" + schema = A(Union[tuple[int, ...], tvm_ffi.Map[str, int]]) + # list matches Array alternative + result = schema.convert([True, 2]).to_py() + assert list(result) == [1, 2] + assert type(result[0]) is int + + @requires_py39 + def test_union_array_int_or_map_str_int_dict(self) -> None: + """Union[Array[int], Map[str, int]] matches Map for dict input.""" + schema = A(Union[tuple[int, ...], tvm_ffi.Map[str, int]]) + result = schema.convert({"a": True}).to_py() + assert result["a"] == 1 + assert type(result["a"]) is int + + @requires_py39 + def test_union_int_or_array_optional_float(self) -> None: + """Union[int, Array[Optional[float]]] matches array with nested conversions.""" + schema = A(Union[int, tuple[Optional[float], ...]]) + result = schema.convert([True, None, 1]).to_py() + assert list(result) == [1.0, None, 1.0] + assert type(result[0]) is float + assert result[1] is None + + @requires_py39 + def test_optional_optional_array_int(self) -> None: + """Optional[Optional[Array[int]]] with inner conversion.""" + schema = A(Optional[Optional[tuple[int, ...]]]) + assert schema.convert(None).to_py() is None + result = schema.convert([True, 2]).to_py() + assert list(result) == [1, 2] + assert type(result[0]) is int + + +# --------------------------------------------------------------------------- +# Category 25: Tuple nested with other types +# --------------------------------------------------------------------------- +class TestNestedTuple: + @requires_py39 + def test_array_of_tuple_int_float(self) -> None: + """Array[tuple[int, float]] with element-wise conversion.""" + result = A(tuple[tuple[int, float], ...]).convert([(True, 1), (2, True)]).to_py() + # Check element values; FFI storage may normalize float 1.0 to int 1 + # when stored inside an ffi.Array, so we only check values not types. + assert result[0][0] == 1 + assert result[0][1] == 1.0 + assert result[1][0] == 2 + assert result[1][1] == 1.0 + + @requires_py39 + def test_map_str_tuple_int_str(self) -> None: + """Map[str, tuple[int, str]] with inner bool->int conversion.""" + result = A(tvm_ffi.Map[str, tuple[int, str]]).convert({"a": (True, "hello")}).to_py() + assert result["a"][0] == 1 + assert str(result["a"][1]) == "hello" + assert type(result["a"][0]) is int + + @requires_py39 + def test_tuple_of_array_int_and_map(self) -> None: + """tuple[Array[int], Map[str, float]] nested conversion.""" + schema = A(tuple[tuple[int, ...], tvm_ffi.Map[str, float]]) + result = schema.convert(([True, 2], {"k": 3})).to_py() + assert list(result[0]) == [1, 2] + assert result[1]["k"] == 3.0 + assert type(result[0][0]) is int + assert type(result[1]["k"]) is float + + @requires_py39 + def test_tuple_of_optional_int_and_optional_float(self) -> None: + """tuple[Optional[int], Optional[float]] with conversions.""" + schema = A(tuple[Optional[int], Optional[float]]) + result = schema.convert((True, None)).to_py() + assert list(result) == [1, None] + assert type(result[0]) is int + assert result[1] is None + + @requires_py39 + def test_tuple_nested_failure(self) -> None: + """tuple[Array[int], str] error propagation from inner array.""" + with pytest.raises(TypeError, match=r"element .0..*element .1..*expected int"): + A(tuple[tuple[int, ...], str]).check_value(([1, "bad"], "ok")) + + +# --------------------------------------------------------------------------- +# Category 26: Deep nesting (3+ levels) +# --------------------------------------------------------------------------- +class TestDeepNesting: + @requires_py39 + def test_map_str_array_optional_int(self) -> None: + """Map[str, Array[Optional[int]]] with 3-level nesting and conversion.""" + result = ( + A(tvm_ffi.Map[str, tuple[Optional[int], ...]]).convert({"a": [1, None, True]}).to_py() + ) + assert list(result["a"]) == [1, None, 1] + assert type(result["a"][0]) is int + assert result["a"][1] is None + assert type(result["a"][2]) is int + + @requires_py39 + def test_array_map_str_optional_float(self) -> None: + """Array[Map[str, Optional[float]]] with 3-level nesting.""" + result = ( + A(tuple[tvm_ffi.Map[str, Optional[float]], ...]) + .convert([{"x": 1, "y": None}, {"z": True}]) + .to_py() + ) + assert result[0]["x"] == 1.0 + assert result[0]["y"] is None + assert result[1]["z"] == 1.0 + assert type(result[0]["x"]) is float + assert type(result[1]["z"]) is float + + @requires_py39 + def test_optional_array_map_str_int(self) -> None: + """Optional[Array[Map[str, int]]] 3 levels deep.""" + schema = A(Optional[tuple[tvm_ffi.Map[str, int], ...]]) + result = schema.convert([{"a": True}, {"b": 2}]).to_py() + assert result[0]["a"] == 1 + assert result[1]["b"] == 2 + assert type(result[0]["a"]) is int + + assert schema.convert(None).to_py() is None + + @requires_py39 + def test_map_str_array_array_int(self) -> None: + """Map[str, Array[Array[int]]] 3-level container nesting.""" + result = ( + A(tvm_ffi.Map[str, tuple[tuple[int, ...], ...]]) + .convert({"m": [[True, 1], [False, 2]]}) + .to_py() + ) + assert [list(row) for row in result["m"]] == [[1, 1], [0, 2]] + assert all(type(x) is int for row in result["m"] for x in row) + + @requires_py39 + def test_array_array_optional_float(self) -> None: + """Array[Array[Optional[float]]] deep nesting with None and conversion.""" + result = ( + A(tuple[tuple[Optional[float], ...], ...]).convert([[1, None], [True, 3.14]]).to_py() + ) + assert list(result[0]) == [1.0, None] + assert list(result[1]) == [1.0, 3.14] + assert type(result[0][0]) is float + assert result[0][1] is None + assert type(result[1][0]) is float + + @requires_py39 + def test_deep_nesting_failure_propagation(self) -> None: + """Error from deepest level propagates with full path info.""" + with pytest.raises(TypeError, match=r"value for key 'key'.*element .1..*expected int"): + A(tvm_ffi.Map[str, tuple[Optional[int], ...]]).check_value({"key": [1, "bad"]}) + + +# --------------------------------------------------------------------------- +# Category 27: FFI container inputs (tvm_ffi.Array/List/Map/Dict) +# --------------------------------------------------------------------------- +class TestFFIContainerInputs: + @requires_py39 + def test_ffi_array_with_element_conversion(self) -> None: + """tvm_ffi.Array([True, 2]) passes Array[int] with bool->int conversion.""" + arr = tvm_ffi.Array([True, 2, 3]) + result = A(tuple[int, ...]).convert(arr).to_py() + assert list(result) == [1, 2, 3] + assert type(result[0]) is int + + @requires_py39 + def test_ffi_array_any_passthrough(self) -> None: + """tvm_ffi.Array passes Array[Any] as-is.""" + arr = tvm_ffi.Array([1, "x", None]) + result = A(tuple[typing.Any, ...]).convert(arr).to_py() + assert result.same_as(arr) + + @requires_py39 + def test_ffi_list_with_list_schema(self) -> None: + """tvm_ffi.List passes List[int] with conversion.""" + lst = tvm_ffi.List([True, 2]) + result = A(list[int]).convert(lst).to_py() + assert list(result) == [1, 2] + assert type(result[0]) is int + + @requires_py39 + def test_ffi_list_accepted_by_array_schema(self) -> None: + """tvm_ffi.List passes Array schema (C++ allows cross-type via kOtherTypeIndex).""" + lst = tvm_ffi.List([1, 2]) + A(tuple[int, ...]).check_value(lst) + + @requires_py39 + def test_ffi_array_accepted_by_list_schema(self) -> None: + """tvm_ffi.Array passes List schema (C++ allows cross-type via kOtherTypeIndex).""" + arr = tvm_ffi.Array([1, 2]) + A(list[int]).check_value(arr) + + @requires_py39 + def test_ffi_map_with_value_conversion(self) -> None: + """tvm_ffi.Map passes Map[str, int] with bool->int conversion.""" + m = tvm_ffi.Map({"a": True, "b": 2}) + result = A(tvm_ffi.Map[str, int]).convert(m).to_py() + assert result["a"] == 1 + assert result["b"] == 2 + assert type(result["a"]) is int + + @requires_py39 + def test_ffi_map_any_any_passthrough(self) -> None: + """tvm_ffi.Map passes Map[Any, Any] as-is.""" + m = tvm_ffi.Map({"a": 1}) + result = A(tvm_ffi.Map[typing.Any, typing.Any]).convert(m).to_py() + assert result.same_as(m) + + @requires_py39 + def test_ffi_dict_with_dict_schema(self) -> None: + """tvm_ffi.Dict passes Dict[str, float] with int->float conversion.""" + d = tvm_ffi.Dict({"x": 1, "y": 2}) + result = A(dict[str, float]).convert(d).to_py() + assert result["x"] == 1.0 + assert result["y"] == 2.0 + assert type(result["x"]) is float + + @requires_py39 + def test_ffi_dict_accepted_by_map_schema(self) -> None: + """tvm_ffi.Dict passes Map schema (C++ allows cross-type via kOtherTypeIndex).""" + d = tvm_ffi.Dict({"a": 1}) + A(tvm_ffi.Map[str, int]).check_value(d) + + @requires_py39 + def test_ffi_map_accepted_by_dict_schema(self) -> None: + """tvm_ffi.Map passes Dict schema (C++ allows cross-type via kOtherTypeIndex).""" + m = tvm_ffi.Map({"a": 1}) + A(dict[str, int]).check_value(m) + + @requires_py39 + def test_ffi_array_nested_optional_float(self) -> None: + """tvm_ffi.Array with nested Optional[float] conversion.""" + arr = tvm_ffi.Array([1, None, True]) + result = A(tuple[Optional[float], ...]).convert(arr).to_py() + assert list(result) == [1.0, None, 1.0] + assert type(result[0]) is float + assert result[1] is None + + @requires_py39 + def test_ffi_map_nested_array_int(self) -> None: + """tvm_ffi.Map with value being a Python list, converted as Array[int].""" + # Map values are already stored; create a map with array values + m = tvm_ffi.Map({"k": tvm_ffi.Array([True, 2])}) + result = A(tvm_ffi.Map[str, tuple[int, ...]]).convert(m).to_py() + assert list(result["k"]) == [1, 2] + assert type(result["k"][0]) is int + + @requires_py39 + def test_ffi_array_wrong_element_type(self) -> None: + """tvm_ffi.Array with wrong element type gives clear error.""" + arr = tvm_ffi.Array([1, "bad", 3]) + with pytest.raises(TypeError, match=r"element \[1\].*expected int"): + A(tuple[int, ...]).check_value(arr) + + @requires_py39 + def test_ffi_map_wrong_value_type(self) -> None: + """tvm_ffi.Map with wrong value type gives clear error.""" + m = tvm_ffi.Map({"a": 1, "b": "bad"}) + with pytest.raises(TypeError, match=r"value for key.*expected int"): + A(tvm_ffi.Map[str, int]).check_value(m) + + def test_ffi_array_object_schema(self) -> None: + """tvm_ffi.Array passes Object schema (it is a CObject).""" + arr = tvm_ffi.Array([1, 2]) + A(tvm_ffi.core.Object).check_value(arr) + + def test_ffi_map_object_schema(self) -> None: + """tvm_ffi.Map passes Object schema (it is a CObject).""" + m = tvm_ffi.Map({"a": 1}) + A(tvm_ffi.core.Object).check_value(m) + + +# --------------------------------------------------------------------------- +# Category 28: Mixed Python and FFI containers in nesting +# --------------------------------------------------------------------------- +class TestMixedPythonFFIContainers: + @requires_py39 + def test_python_list_of_ffi_arrays(self) -> None: + """Python list containing tvm_ffi.Array elements, Array[Array[int]].""" + inner1 = tvm_ffi.Array([True, 2]) + inner2 = tvm_ffi.Array([3, False]) + result = A(tuple[tuple[int, ...], ...]).convert([inner1, inner2]).to_py() + assert [list(row) for row in result] == [[1, 2], [3, 0]] + + @requires_py39 + def test_python_dict_with_ffi_array_values(self) -> None: + """Python dict with tvm_ffi.Array values, Map[str, Array[float]].""" + val = tvm_ffi.Array([1, True]) + result = A(tvm_ffi.Map[str, tuple[float, ...]]).convert({"k": val}).to_py() + assert list(result["k"]) == [1.0, 1.0] + assert all(type(x) is float for x in result["k"]) + + @requires_py39 + def test_ffi_map_with_python_list_in_union(self) -> None: + """Union[Map[str, int], Array[int]] with tvm_ffi.Map input.""" + schema = A(Union[tvm_ffi.Map[str, int], tuple[int, ...]]) + m = tvm_ffi.Map({"a": True}) + result = schema.convert(m).to_py() + assert result["a"] == 1 + assert type(result["a"]) is int + + @requires_py39 + def test_ffi_array_in_optional(self) -> None: + """Optional[Array[int]] with tvm_ffi.Array input.""" + arr = tvm_ffi.Array([True, 2]) + result = A(Optional[tuple[int, ...]]).convert(arr).to_py() + assert list(result) == [1, 2] + assert type(result[0]) is int + + +# --------------------------------------------------------------------------- +# Category 29: Error propagation through deeply nested FFI containers +# --------------------------------------------------------------------------- +class TestNestedErrorPropagation: + @requires_py39 + def test_array_array_int_inner_failure(self) -> None: + """Error path: Array[Array[int]] -> element [1] -> element [0].""" + with pytest.raises(TypeError, match=r"element \[1\].*element \[0\].*expected int, got str"): + A(tuple[tuple[int, ...], ...]).convert([[1], ["bad"]]) + + @requires_py39 + def test_map_array_int_inner_failure(self) -> None: + """Error path: Map -> value for key 'k' -> element [2].""" + with pytest.raises( + TypeError, + match=r"value for key 'k'.*element \[2\].*expected int, got str", + ): + A(tvm_ffi.Map[str, tuple[int, ...]]).convert({"k": [1, 2, "bad"]}) + + @requires_py39 + def test_array_map_int_inner_failure(self) -> None: + """Error path: Array -> element [0] -> value for key 'x'.""" + with pytest.raises( + TypeError, + match=r"element \[0\].*value for key 'x'.*expected int, got str", + ): + A(tuple[tvm_ffi.Map[str, int], ...]).convert([{"x": "bad"}]) + + @requires_py39 + def test_optional_array_int_inner_failure(self) -> None: + """Error path through Optional -> Array -> element.""" + with pytest.raises(TypeError, match=r"element \[1\].*expected int, got str"): + A(Optional[tuple[int, ...]]).convert([1, "bad"]) + + @requires_py39 + def test_tuple_array_int_inner_failure(self) -> None: + """Error path: tuple -> element [0] -> element [1].""" + with pytest.raises(TypeError, match=r"element \[0\].*element \[1\].*expected int, got str"): + A(tuple[tuple[int, ...], str]).convert(([1, "bad"], "ok")) + + @requires_py39 + def test_deep_3_level_error(self) -> None: + """Error at 3 levels deep: Map -> Array -> Optional -> type mismatch.""" + with pytest.raises(TypeError, match=r"value for key 'key'.*element .1..*expected int"): + A(tvm_ffi.Map[str, tuple[Optional[int], ...]]).check_value({"key": [1, "bad"]}) + + @requires_py39 + def test_ffi_array_nested_error(self) -> None: + """Error from tvm_ffi.Array in nested context.""" + arr = tvm_ffi.Array([1, "bad", 3]) + with pytest.raises(TypeError, match=r"element \[1\].*expected int"): + A(tuple[int, ...]).convert(arr) + + +# --------------------------------------------------------------------------- +# Category 30: Custom object type exact match +# --------------------------------------------------------------------------- +class TestCustomObjectExactMatch: + def test_test_int_pair_pass(self) -> None: + """TestIntPair passes TypeSchema('testing.TestIntPair').""" + obj = TestIntPair(1, 2) + A(TestIntPair).check_value(obj) + + def test_test_object_base_pass(self) -> None: + """TestObjectBase passes its own schema.""" + obj = TestObjectBase(v_i64=10, v_f64=1.5, v_str="hi") + A(TestObjectBase).check_value(obj) + + def test_test_object_derived_pass(self) -> None: + """TestObjectDerived passes its own schema.""" + obj = TestObjectDerived(v_map={"a": 1}, v_array=[1], v_i64=0, v_f64=0.0, v_str="") + A(TestObjectDerived).check_value(obj) + + def test_cxx_class_base_pass(self) -> None: + """_TestCxxClassBase passes its own schema.""" + obj = _TestCxxClassBase(v_i64=1, v_i32=2) + A(_TestCxxClassBase).check_value(obj) + + def test_cxx_class_derived_pass(self) -> None: + """_TestCxxClassDerived passes its own schema.""" + obj = _TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.0) + A(_TestCxxClassDerived).check_value(obj) + + def test_cxx_class_derived_derived_pass(self) -> None: + """_TestCxxClassDerivedDerived passes its own schema.""" + obj = _TestCxxClassDerivedDerived(v_i64=1, v_i32=2, v_f64=3.0, v_bool=True) + A(_TestCxxClassDerivedDerived).check_value(obj) + + +# --------------------------------------------------------------------------- +# Category 31: Custom object type hierarchy (subclass passes parent schema) +# --------------------------------------------------------------------------- +class TestCustomObjectHierarchy: + def test_derived_passes_base_schema(self) -> None: + """TestObjectDerived passes TypeSchema('testing.TestObjectBase').""" + obj = TestObjectDerived(v_map={"a": 1}, v_array=[1], v_i64=0, v_f64=0.0, v_str="") + A(TestObjectBase).check_value(obj) + + def test_derived_passes_object_schema(self) -> None: + """TestObjectDerived passes TypeSchema('Object').""" + obj = TestObjectDerived(v_map={"a": 1}, v_array=[1], v_i64=0, v_f64=0.0, v_str="") + A(tvm_ffi.core.Object).check_value(obj) + + def test_cxx_derived_passes_base(self) -> None: + """_TestCxxClassDerived passes TestCxxClassBase schema.""" + obj = _TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.0) + A(_TestCxxClassBase).check_value(obj) + + def test_cxx_derived_derived_passes_base(self) -> None: + """_TestCxxClassDerivedDerived passes TestCxxClassBase schema (2-level up).""" + obj = _TestCxxClassDerivedDerived(v_i64=1, v_i32=2, v_f64=3.0, v_bool=True) + A(_TestCxxClassBase).check_value(obj) + + def test_cxx_derived_derived_passes_derived(self) -> None: + """_TestCxxClassDerivedDerived passes TestCxxClassDerived schema (1-level up).""" + obj = _TestCxxClassDerivedDerived(v_i64=1, v_i32=2, v_f64=3.0, v_bool=True) + A(_TestCxxClassDerived).check_value(obj) + + def test_all_custom_objects_pass_object_schema(self) -> None: + """Every custom object passes the generic Object schema.""" + objs = [ + TestIntPair(1, 2), + TestObjectBase(v_i64=10, v_f64=1.5, v_str="hi"), + _TestCxxClassBase(v_i64=1, v_i32=2), + _TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.0), + _TestCxxClassDerivedDerived(v_i64=1, v_i32=2, v_f64=3.0, v_bool=True), + ] + schema = A(tvm_ffi.core.Object) + for obj in objs: + schema.check_value(obj) + + +# --------------------------------------------------------------------------- +# Category 32: Custom object type rejection +# --------------------------------------------------------------------------- +class TestCustomObjectRejection: + def test_wrong_object_type(self) -> None: + """TestIntPair fails TypeSchema('testing.TestObjectBase').""" + obj = TestIntPair(1, 2) + with pytest.raises(TypeError, match=r"testing.TestIntPair"): + A(TestObjectBase).check_value(obj) + + def test_base_fails_derived_schema(self) -> None: + """Parent object fails child schema (TestObjectBase fails TestObjectDerived).""" + obj = TestObjectBase(v_i64=10, v_f64=1.5, v_str="hi") + with pytest.raises(TypeError, match=r"testing.TestObjectBase"): + A(TestObjectDerived).check_value(obj) + + def test_non_object_fails_custom_schema(self) -> None: + """Plain int fails custom object schema.""" + with pytest.raises(TypeError, match=r"expected testing\.TestIntPair.*got int"): + A(TestIntPair).check_value(42) + + def test_none_fails_custom_schema(self) -> None: + """None fails custom object schema.""" + with pytest.raises(TypeError, match="got None"): + A(TestIntPair).check_value(None) + + def test_string_fails_custom_schema(self) -> None: + """String fails custom object schema.""" + with pytest.raises(TypeError, match="got str"): + A(TestIntPair).check_value("hello") + + def test_cxx_base_fails_derived_schema(self) -> None: + """_TestCxxClassBase fails _TestCxxClassDerived schema.""" + obj = _TestCxxClassBase(v_i64=1, v_i32=2) + with pytest.raises(TypeError): + A(_TestCxxClassDerived).check_value(obj) + + def test_sibling_types_reject_each_other(self) -> None: + """TestIntPair and TestCxxClassBase are unrelated -- reject each other.""" + pair = TestIntPair(1, 2) + base = _TestCxxClassBase(v_i64=1, v_i32=2) + with pytest.raises(TypeError): + A(_TestCxxClassBase).check_value(pair) + with pytest.raises(TypeError): + A(TestIntPair).check_value(base) + + +# --------------------------------------------------------------------------- +# Category 33: Custom objects in containers +# --------------------------------------------------------------------------- +class TestCustomObjectInContainers: + @requires_py39 + def test_array_of_custom_objects(self) -> None: + """Array[testing.TestIntPair] with matching elements.""" + objs = [TestIntPair(1, 2), TestIntPair(3, 4)] + A(tuple[TestIntPair, ...]).check_value(objs) + + @requires_py39 + def test_array_of_custom_objects_wrong_type(self) -> None: + """Array[testing.TestIntPair] with wrong element type fails.""" + objs = [TestIntPair(1, 2), _TestCxxClassBase(v_i64=1, v_i32=2)] + with pytest.raises(TypeError, match=r"element \[1\]"): + A(tuple[TestIntPair, ...]).check_value(objs) + + @requires_py39 + def test_array_of_base_with_derived_elements(self) -> None: + """Array[testing.TestObjectBase] accepts derived elements via hierarchy.""" + base = TestObjectBase(v_i64=1, v_f64=1.0, v_str="a") + derived = TestObjectDerived(v_map={"a": 1}, v_array=[1], v_i64=0, v_f64=0.0, v_str="") + A(tuple[TestObjectBase, ...]).check_value([base, derived]) + + @requires_py39 + def test_map_str_to_custom_object(self) -> None: + """Map[str, testing.TestIntPair] pass.""" + objs = {"a": TestIntPair(1, 2), "b": TestIntPair(3, 4)} + A(tvm_ffi.Map[str, TestIntPair]).check_value(objs) + + @requires_py39 + def test_map_str_to_custom_object_wrong_value(self) -> None: + """Map[str, testing.TestIntPair] with int value fails.""" + data = {"a": TestIntPair(1, 2), "b": 42} + with pytest.raises(TypeError, match="value for key 'b'"): + A(tvm_ffi.Map[str, TestIntPair]).check_value(data) + + @requires_py39 + def test_ffi_array_of_custom_objects(self) -> None: + """tvm_ffi.Array of custom objects passes Array[Object].""" + arr = tvm_ffi.Array([TestIntPair(1, 2), TestObjectBase(v_i64=1, v_f64=2.0, v_str="s")]) + A(tuple[tvm_ffi.core.Object, ...]).check_value(arr) + + @requires_py39 + def test_ffi_array_of_custom_objects_specific_type(self) -> None: + """tvm_ffi.Array of TestIntPair passes Array[testing.TestIntPair].""" + arr = tvm_ffi.Array([TestIntPair(1, 2), TestIntPair(3, 4)]) + A(tuple[TestIntPair, ...]).check_value(arr) + + @requires_py39 + def test_ffi_map_with_custom_object_values(self) -> None: + """tvm_ffi.Map with custom object values passes.""" + m = tvm_ffi.Map({"x": TestIntPair(1, 2), "y": TestIntPair(3, 4)}) + A(tvm_ffi.Map[str, TestIntPair]).check_value(m) + + +# --------------------------------------------------------------------------- +# Category 34: Optional/Union with custom objects +# --------------------------------------------------------------------------- +class TestCustomObjectOptionalUnion: + def test_optional_custom_object_with_value(self) -> None: + """Optional[testing.TestIntPair] with actual object.""" + obj = TestIntPair(1, 2) + A(Optional[TestIntPair]).check_value(obj) + + def test_optional_custom_object_with_none(self) -> None: + """Optional[testing.TestIntPair] with None.""" + A(Optional[TestIntPair]).check_value(None) + + def test_optional_custom_object_wrong_type(self) -> None: + """Optional[testing.TestIntPair] with wrong object type.""" + obj = _TestCxxClassBase(v_i64=1, v_i32=2) + with pytest.raises(TypeError): + A(Optional[TestIntPair]).check_value(obj) + + def test_union_custom_object_and_int(self) -> None: + """Union[testing.TestIntPair, int] with object.""" + obj = TestIntPair(1, 2) + A(Union[TestIntPair, int]).check_value(obj) + + def test_union_custom_object_and_int_with_int(self) -> None: + """Union[testing.TestIntPair, int] with int.""" + A(Union[TestIntPair, int]).check_value(42) + + def test_union_custom_object_and_int_with_wrong(self) -> None: + """Union[testing.TestIntPair, int] with str fails.""" + with pytest.raises(TypeError): + A(Union[TestIntPair, int]).check_value("bad") + + def test_union_two_custom_objects(self) -> None: + """Union of two custom types accepts both.""" + pair = TestIntPair(1, 2) + base = _TestCxxClassBase(v_i64=1, v_i32=2) + schema = A(Union[TestIntPair, _TestCxxClassBase]) + schema.check_value(pair) + schema.check_value(base) + + def test_union_two_custom_objects_rejects_third(self) -> None: + """Union of two custom types rejects a third.""" + obj = TestObjectBase(v_i64=1, v_f64=2.0, v_str="s") + with pytest.raises(TypeError): + A(Union[TestIntPair, _TestCxxClassBase]).check_value(obj) + + +# --------------------------------------------------------------------------- +# Category 35: Custom objects with from_type_index +# --------------------------------------------------------------------------- +class TestCustomObjectFromTypeIndex: + def test_from_type_index_custom_object(self) -> None: + """from_type_index resolves a custom object type and validates.""" + obj = TestIntPair(1, 2) + tindex = tvm_ffi.core._object_type_key_to_index("testing.TestIntPair") + assert tindex is not None + schema = TypeSchema.from_type_index(tindex) + assert schema.origin == "testing.TestIntPair" + schema.check_value(obj) + + def test_from_type_index_rejects_wrong_object(self) -> None: + """from_type_index schema rejects wrong object type.""" + tindex = tvm_ffi.core._object_type_key_to_index("testing.TestIntPair") + assert tindex is not None + schema = TypeSchema.from_type_index(tindex) + with pytest.raises(TypeError): + schema.check_value(_TestCxxClassBase(v_i64=1, v_i32=2)) + + def test_from_type_index_hierarchy(self) -> None: + """from_type_index for base type accepts derived objects.""" + tindex = tvm_ffi.core._object_type_key_to_index("testing.TestObjectBase") + assert tindex is not None + schema = TypeSchema.from_type_index(tindex) + derived = TestObjectDerived(v_map={"a": 1}, v_array=[1], v_i64=0, v_f64=0.0, v_str="") + schema.check_value(derived) + + +# --------------------------------------------------------------------------- +# Category 36: Custom objects in nested containers +# --------------------------------------------------------------------------- +class TestCustomObjectNestedContainers: + @requires_py39 + def test_array_of_optional_custom_object(self) -> None: + """Array[Optional[testing.TestIntPair]] with mix of objects and None.""" + data = [TestIntPair(1, 2), None, TestIntPair(3, 4)] + A(tuple[Optional[TestIntPair], ...]).check_value(data) + + @requires_py39 + def test_map_str_to_array_of_custom_objects(self) -> None: + """Map[str, Array[testing.TestIntPair]] with nested objects.""" + data = { + "group1": [TestIntPair(1, 2), TestIntPair(3, 4)], + "group2": [TestIntPair(5, 6)], + } + A(tvm_ffi.Map[str, tuple[TestIntPair, ...]]).check_value(data) + + @requires_py39 + def test_array_of_union_custom_objects(self) -> None: + """Array[Union[testing.TestIntPair, testing.TestCxxClassBase]].""" + data = [TestIntPair(1, 2), _TestCxxClassBase(v_i64=1, v_i32=2), TestIntPair(5, 6)] + A(tuple[Union[TestIntPair, _TestCxxClassBase], ...]).check_value(data) + + @requires_py39 + def test_optional_array_of_custom_objects(self) -> None: + """Optional[Array[testing.TestIntPair]] with array.""" + data = [TestIntPair(1, 2)] + A(Optional[tuple[TestIntPair, ...]]).check_value(data) + + @requires_py39 + def test_optional_array_of_custom_objects_none(self) -> None: + """Optional[Array[testing.TestIntPair]] with None.""" + A(Optional[tuple[TestIntPair, ...]]).check_value(None) + + @requires_py39 + def test_nested_error_with_custom_object(self) -> None: + """Array[testing.TestIntPair] error message includes type keys.""" + data = [TestIntPair(1, 2), _TestCxxClassBase(v_i64=1, v_i32=2)] + with pytest.raises( + TypeError, match=r"element \[1\].*testing.TestIntPair.*testing.TestCxxClassBase" + ): + A(tuple[TestIntPair, ...]).check_value(data) + + @requires_py39 + def test_map_nested_error_with_custom_object(self) -> None: + """Map value error for custom object includes key and type info.""" + data = {"ok": TestIntPair(1, 2), "bad": 42} + with pytest.raises( + TypeError, match=r"value for key 'bad'.*expected testing\.TestIntPair.*got int" + ): + A(tvm_ffi.Map[str, TestIntPair]).check_value(data) + + @requires_py39 + def test_deep_nested_custom_objects(self) -> None: + """Map[str, Array[Optional[testing.TestIntPair]]] deep nesting.""" + data = { + "a": [TestIntPair(1, 2), None], + "b": [None, TestIntPair(3, 4), TestIntPair(5, 6)], + } + A(tvm_ffi.Map[str, tuple[Optional[TestIntPair], ...]]).check_value(data) + + @requires_py39 + def test_deep_nested_custom_objects_error(self) -> None: + """Map[str, Array[testing.TestIntPair]] error at 3 levels.""" + data = {"k": [TestIntPair(1, 2), "bad"]} + with pytest.raises(TypeError, match=r"value for key 'k'.*element .1."): + A(tvm_ffi.Map[str, tuple[TestIntPair, ...]]).check_value(data) + + @requires_py39 + def test_tuple_with_custom_object(self) -> None: + """tuple[testing.TestIntPair, int, str] with custom object.""" + data = (TestIntPair(1, 2), 42, "hello") + A(tuple[TestIntPair, int, str]).check_value(data) + + @requires_py39 + def test_tuple_with_custom_object_wrong(self) -> None: + """tuple[testing.TestIntPair, int] with wrong object in first position.""" + data = (_TestCxxClassBase(v_i64=1, v_i32=2), 42) + with pytest.raises(TypeError, match=r"element \[0\]"): + A(tuple[TestIntPair, int]).check_value(data) + + +# --------------------------------------------------------------------------- +# Category 37: Lowercase Python-native origins ("list", "dict") +# --------------------------------------------------------------------------- +class TestLowercaseOrigins: + def test_list_origin_accepts_python_list(self) -> None: + """TypeSchema("list", ...) should validate elements, not passthrough.""" + S("list", S("int")).check_value([1, 2, 3]) # S: lowercase "list" is an internal origin + + def test_list_origin_rejects_bad_elements(self) -> None: + """TypeSchema("list", (int,)).check_value(["x"]) should fail.""" + with pytest.raises(TypeError, match=r"element \[0\]"): + S("list", S("int")).check_value(["x"]) # S: lowercase "list" is an internal origin + + def test_list_origin_converts_elements(self) -> None: + """TypeSchema("list", (float,)).convert([1, True]) does int->float.""" + # S: lowercase "list" is an internal origin + result = S("list", S("float")).convert([1, True]).to_py() + assert isinstance(result, tvm_ffi.List) + assert list(result) == [1.0, 1.0] + assert all(type(x) is float for x in result) + + def test_dict_origin_accepts_python_dict(self) -> None: + """TypeSchema("dict", ...) should validate key/value types.""" + S("dict", S("str"), S("int")).check_value( + {"a": 1} + ) # S: lowercase "dict" is an internal origin + + def test_dict_origin_rejects_bad_values(self) -> None: + """TypeSchema("dict", (str, int)).check_value({"a": "x"}) should fail.""" + with pytest.raises(TypeError, match="value for key 'a'"): + S("dict", S("str"), S("int")).check_value( + {"a": "x"} + ) # S: lowercase "dict" is an internal origin + + def test_dict_origin_converts_values(self) -> None: + """TypeSchema("dict", (str, float)).convert({"a": 1}) does int->float.""" + # S: lowercase "dict" is an internal origin + result = S("dict", S("str"), S("float")).convert({"a": 1, "b": True}).to_py() + assert isinstance(result, tvm_ffi.Dict) + assert result["a"] == 1.0 + assert result["b"] == 1.0 + assert all(type(v) is float for v in result.values()) + + def test_list_origin_no_args_accepts_anything(self) -> None: + """TypeSchema("list") with no args accepts any list (element type is Any).""" + S("list").check_value([1, "a", None]) # S: lowercase "list" is an internal origin + + def test_dict_origin_no_args_accepts_anything(self) -> None: + """TypeSchema("dict") with no args accepts any dict.""" + S("dict").check_value({"a": 1, 2: "b"}) # S: lowercase "dict" is an internal origin + + def test_list_origin_rejects_non_list(self) -> None: + """TypeSchema("list") rejects non-sequence types.""" + with pytest.raises(TypeError, match="got int"): + S("list").check_value(42) # S: lowercase "list" is an internal origin + + def test_dict_origin_rejects_non_dict(self) -> None: + """TypeSchema("dict") rejects non-dict types.""" + with pytest.raises(TypeError): + S("dict").check_value([1, 2]) # S: lowercase "dict" is an internal origin + + +# --------------------------------------------------------------------------- +# Category 38: Cross-type container conversions (Array<->List, Map<->Dict) +# --------------------------------------------------------------------------- +class TestCrossTypeContainers: + @requires_py39 + def test_array_schema_accepts_ffi_list(self) -> None: + """Array[int] schema accepts tvm_ffi.List (C++ kOtherTypeIndex).""" + lst = tvm_ffi.List([1, 2, 3]) + A(tuple[int, ...]).check_value(lst) + + @requires_py39 + def test_list_schema_accepts_ffi_array(self) -> None: + """List[int] schema accepts tvm_ffi.Array (C++ kOtherTypeIndex).""" + arr = tvm_ffi.Array([1, 2, 3]) + A(list[int]).check_value(arr) + + @requires_py39 + def test_map_schema_accepts_ffi_dict(self) -> None: + """Map[str, int] schema accepts tvm_ffi.Dict (C++ kOtherTypeIndex).""" + d = tvm_ffi.Dict({"a": 1, "b": 2}) + A(tvm_ffi.Map[str, int]).check_value(d) + + @requires_py39 + def test_dict_schema_accepts_ffi_map(self) -> None: + """Dict[str, int] schema accepts tvm_ffi.Map (C++ kOtherTypeIndex).""" + m = tvm_ffi.Map({"a": 1, "b": 2}) + A(dict[str, int]).check_value(m) + + @requires_py39 + def test_array_schema_converts_list_elements(self) -> None: + """Array[float] converts elements from tvm_ffi.List[int].""" + lst = tvm_ffi.List([1, 2, True]) + result = A(tuple[float, ...]).convert(lst).to_py() + assert list(result) == [1.0, 2.0, 1.0] + assert all(type(x) is float for x in result) + + @requires_py39 + def test_list_schema_converts_array_elements(self) -> None: + """List[float] converts elements from tvm_ffi.Array[int].""" + arr = tvm_ffi.Array([1, 2, True]) + result = A(list[float]).convert(arr).to_py() + assert list(result) == [1.0, 2.0, 1.0] + assert all(type(x) is float for x in result) + + @requires_py39 + def test_map_schema_converts_dict_values(self) -> None: + """Map[str, float] converts values from tvm_ffi.Dict.""" + d = tvm_ffi.Dict({"a": 1, "b": True}) + result = A(tvm_ffi.Map[str, float]).convert(d).to_py() + assert result["a"] == 1.0 + assert result["b"] == 1.0 + + @requires_py39 + def test_dict_schema_converts_map_values(self) -> None: + """Dict[str, float] converts values from tvm_ffi.Map.""" + m = tvm_ffi.Map({"a": 1, "b": True}) + result = A(dict[str, float]).convert(m).to_py() + assert result["a"] == 1.0 + assert result["b"] == 1.0 + + @requires_py39 + def test_cross_type_still_rejects_wrong_container(self) -> None: + """Array schema still rejects non-sequence CObjects (e.g. Map).""" + m = tvm_ffi.Map({"a": 1}) + with pytest.raises(TypeError, match="expected Array"): + A(tuple[int, ...]).check_value(m) + + @requires_py39 + def test_cross_type_map_rejects_array(self) -> None: + """Map schema still rejects sequence CObjects (e.g. Array).""" + arr = tvm_ffi.Array([1, 2]) + with pytest.raises(TypeError, match="expected Map"): + A(tvm_ffi.Map[str, int]).check_value(arr) + + +# --------------------------------------------------------------------------- +# Category 39: tuple accepts list and CObject Array +# --------------------------------------------------------------------------- +class TestTupleAcceptsListAndArray: + @requires_py39 + def test_tuple_accepts_python_list(self) -> None: + """tuple[int, str] accepts Python list input.""" + result = A(tuple[int, str]).convert([42, "hello"]).to_py() + assert list(result) == [42, "hello"] + + @requires_py39 + def test_tuple_list_with_conversion(self) -> None: + """tuple[float, int] converts list elements (bool->float, bool->int).""" + result = A(tuple[float, int]).convert([True, False]).to_py() + assert list(result) == [1.0, 0] + assert type(result[0]) is float + assert type(result[1]) is int + + @requires_py39 + def test_tuple_rejects_wrong_length_list(self) -> None: + """tuple[int, str] rejects list of wrong length.""" + with pytest.raises(TypeError, match="length"): + A(tuple[int, str]).check_value([1, "a", "b"]) + + @requires_py39 + def test_tuple_accepts_ffi_array(self) -> None: + """tuple[int, int] accepts tvm_ffi.Array (C++ Tuple accepts kTVMFFIArray).""" + arr = tvm_ffi.Array([1, 2]) + A(tuple[int, int]).check_value(arr) + + @requires_py39 + def test_tuple_ffi_array_with_conversion(self) -> None: + """tuple[float, float] converts tvm_ffi.Array elements.""" + arr = tvm_ffi.Array([1, True]) + result = A(tuple[float, float]).convert(arr).to_py() + assert list(result) == [1.0, 1.0] + assert all(type(x) is float for x in result) + + @requires_py39 + def test_tuple_ffi_array_wrong_length(self) -> None: + """tuple[int, int] rejects tvm_ffi.Array of wrong length.""" + arr = tvm_ffi.Array([1, 2, 3]) + with pytest.raises(TypeError, match="length"): + A(tuple[int, int]).check_value(arr) + + @requires_py39 + def test_tuple_rejects_ffi_map(self) -> None: + """Tuple schema rejects Map CObject.""" + m = tvm_ffi.Map({"a": 1}) + with pytest.raises(TypeError, match="expected tuple"): + A(tuple[int]).check_value(m) + + def test_untyped_tuple_accepts_list(self) -> None: + """Tuple (no args) accepts any list as-is.""" + # Untyped tuple has tuple_len=0, so it just checks the container type + # but doesn't validate elements + A(tuple).check_value([1, "a", None]) + + def test_untyped_tuple_accepts_ffi_array(self) -> None: + """Tuple (no args) accepts tvm_ffi.Array as-is.""" + arr = tvm_ffi.Array([1, 2, 3]) + A(tuple).check_value(arr) + + def test_typed_empty_tuple_rejects_non_empty_list(self) -> None: + """Explicit empty tuple schema enforces length 0.""" + schema = TypeSchema("tuple", ()) + with pytest.raises(TypeError, match="length 0"): + schema.check_value([1]) + + def test_untyped_tuple_converts_ffi_list_to_array(self) -> None: + """Tuple (no args) normalizes tvm_ffi.List input to tvm_ffi.Array.""" + lst = tvm_ffi.List([1, 2, 3]) + result = A(tuple).convert(lst).to_py() + assert isinstance(result, tvm_ffi.Array) + assert list(result) == [1, 2, 3] + assert not result.same_as(lst) + + +# --------------------------------------------------------------------------- +# Category 40: dtype string parse errors +# --------------------------------------------------------------------------- +class TestDtypeParseErrors: + def test_check_value_bad_dtype_raises_error(self) -> None: + """check_value should raise TypeError for invalid dtype.""" + with pytest.raises(TypeError, match="dtype"): + A(tvm_ffi.core.DataType).check_value("not_a_valid_dtype_xyz") + + def test_convert_bad_dtype_raises_type_error_2(self) -> None: + """Convert should raise TypeError for invalid dtype string.""" + with pytest.raises(TypeError, match="dtype"): + A(tvm_ffi.core.DataType).convert("not_a_valid_dtype_xyz") + + def test_convert_bad_dtype_raises_type_error(self) -> None: + """Convert should raise TypeError for invalid dtype string.""" + with pytest.raises(TypeError, match="dtype"): + A(tvm_ffi.core.DataType).convert("not_a_valid_dtype_xyz") + + def test_valid_dtype_string_still_works(self) -> None: + """Valid dtype strings should still convert successfully.""" + result = A(tvm_ffi.core.DataType).convert("float32").to_py() + assert str(result) == "float32" + + def test_convert_valid_dtype(self) -> None: + """Convert with valid dtype returns DataType.""" + result = A(tvm_ffi.core.DataType).convert("int8").to_py() + assert str(result) == "int8" + + +# --------------------------------------------------------------------------- +# Category 41: int64 boundary checking +# --------------------------------------------------------------------------- +class TestInt64Boundaries: + """Verify int converter rejects values outside int64 range. + + The FFI marshals Python int to C++ int64_t. Values outside + [-2^63, 2^63-1] would silently overflow at marshal time, so + the converter rejects them early. + """ + + def test_int64_max_accepted(self) -> None: + """2^63-1 (INT64_MAX) is the largest valid int.""" + A(int).check_value(2**63 - 1) + + def test_int64_min_accepted(self) -> None: + """-2^63 (INT64_MIN) is the smallest valid int.""" + A(int).check_value(-(2**63)) + + def test_int64_max_plus_one_rejected(self) -> None: + """2^63 exceeds int64 range.""" + with pytest.raises(TypeError, match="int64 range"): + A(int).check_value(2**63) + + def test_int64_min_minus_one_rejected(self) -> None: + """-2^63-1 exceeds int64 range.""" + with pytest.raises(TypeError, match="int64 range"): + A(int).check_value(-(2**63) - 1) + + def test_very_large_positive_rejected(self) -> None: + """Very large positive integer rejected.""" + with pytest.raises(TypeError, match="int64 range"): + A(int).check_value(10**100) + + def test_very_large_negative_rejected(self) -> None: + """Very large negative integer rejected.""" + with pytest.raises(TypeError, match="int64 range"): + A(int).check_value(-(10**100)) + + def test_convert_raises_type_error_for_overflow(self) -> None: + """Convert raises TypeError for overflow.""" + with pytest.raises(TypeError, match="int64 range"): + A(int).convert(2**63) + + def test_bool_to_int_no_range_issue(self) -> None: + """Bool -> int conversion (0 or 1) always fits.""" + assert A(int).convert(True).to_py() == 1 + assert A(int).convert(False).to_py() == 0 + + def test_int64_boundaries_in_float_conversion(self) -> None: + """Float schema accepts large ints (float64 has wider range).""" + # float64 can represent integers up to 2^53 exactly, + # and larger values with precision loss (but no range error) + A(float).check_value(2**63) + A(float).check_value(-(2**63)) + + def test_int64_overflow_in_optional_int(self) -> None: + """Optional[int] propagates int64 range check.""" + with pytest.raises(TypeError, match="int64 range"): + A(Optional[int]).check_value(2**63) + + @requires_py39 + def test_int64_overflow_in_array_element(self) -> None: + """Array[int] element overflow is caught with path.""" + with pytest.raises(TypeError, match="int64 range"): + A(tuple[int, ...]).check_value([1, 2**63, 3]) + + +# --------------------------------------------------------------------------- +# Category 42: Unknown origin errors (lazy converter construction) +# --------------------------------------------------------------------------- +class TestUnknownOriginErrors: + """Converter is built lazily via cached_property. Unknown origins + construct fine but raise TypeError on first convert/check_value. + """ + + def test_unknown_origin_constructs_ok(self) -> None: + """TypeSchema with unknown origin can be constructed.""" + schema = S("not_a_real_type") # S: intentionally invalid origin + assert schema.origin == "not_a_real_type" + + def test_unknown_origin_errors_on_check_value(self) -> None: + """Unknown origin raises TypeError on check_value.""" + schema = S("not_a_real_type") # S: intentionally invalid origin + with pytest.raises(TypeError, match="unknown TypeSchema origin"): + schema.check_value(42) + + def test_unknown_origin_errors_on_convert(self) -> None: + """Unknown origin raises TypeError on convert.""" + schema = S("not_a_real_type") # S: intentionally invalid origin + with pytest.raises(TypeError, match="unknown TypeSchema origin"): + schema.convert(42) + + def test_unknown_origin_errors_on_convert_2(self) -> None: + """Unknown origin raises TypeError on convert (duplicate check).""" + schema = S("not_a_real_type") # S: intentionally invalid origin + with pytest.raises(TypeError, match="unknown TypeSchema origin"): + schema.convert(42) + + def test_unknown_origin_errors_on_check_value_2(self) -> None: + """Unknown origin raises TypeError on check_value (duplicate check).""" + schema = S("not_a_real_type") # S: intentionally invalid origin + with pytest.raises(TypeError, match="unknown TypeSchema origin"): + schema.check_value(42) + + def test_typo_origin_errors(self) -> None: + """Common typos are caught, not silently passed through.""" + for typo in ("innt", "floot", "strr", "Int", "Float"): + schema = S(typo) # S: intentionally invalid origin + with pytest.raises(TypeError, match="unknown TypeSchema origin"): + schema.check_value(42) + + def test_unknown_nested_in_optional_errors(self) -> None: + """Unknown origin nested inside Optional errors on use.""" + schema = S("Optional", S("not_a_real_type")) # S: intentionally invalid origin + with pytest.raises(TypeError, match="unknown TypeSchema origin"): + schema.check_value(42) + + +# --------------------------------------------------------------------------- +# Category 43: convert/check_value raise TypeError on errors +# --------------------------------------------------------------------------- +class TestConvertCheckValueErrors: + """Verify convert and check_value raise TypeError on errors.""" + + def test_convert_catches_custom_integral_error(self) -> None: + """Custom Integral whose __int__ raises is caught by convert.""" + + class BadInt: + """Registered as Integral via ABC but __int__ raises.""" + + def __int__(self) -> int: + raise RuntimeError("broken __int__") + + Integral.register(BadInt) + with pytest.raises(TypeError, match="broken __int__"): + A(int).convert(BadInt()) + + def test_check_value_catches_custom_integral_error(self) -> None: + """Custom Integral whose __int__ raises is caught by check_value.""" + + class BadInt2: + def __int__(self) -> int: + raise ValueError("bad int conversion") + + Integral.register(BadInt2) + with pytest.raises(TypeError, match="bad int conversion"): + A(int).check_value(BadInt2()) + + def test_convert_unknown_origin_raises(self) -> None: + """Convert with unknown origin raises TypeError.""" + with pytest.raises(TypeError, match="unknown TypeSchema origin"): + S("bogus_type").convert("anything") # S: intentionally invalid origin + + def test_check_value_unknown_origin_raises(self) -> None: + """check_value with unknown origin raises TypeError.""" + with pytest.raises(TypeError, match="unknown TypeSchema origin"): + S("bogus_type").check_value("anything") # S: intentionally invalid origin + + +# --------------------------------------------------------------------------- +# Category 44: Schema arity validation (ValueError, not assert) +# --------------------------------------------------------------------------- +class TestSchemaArityValidation: + """Verify arity checks use ValueError (not assert) so they work under -O.""" + + def test_union_too_few_args(self) -> None: + """Union with < 2 args raises ValueError.""" + with pytest.raises(ValueError, match="at least two"): + S("Union", A(int)) + + def test_optional_wrong_arity(self) -> None: + """Optional with != 1 arg raises ValueError.""" + with pytest.raises(ValueError, match="exactly one"): + S("Optional") + with pytest.raises(ValueError, match="exactly one"): + S("Optional", A(int), A(str)) + + def test_array_too_many_args(self) -> None: + """Array with > 1 arg raises ValueError.""" + with pytest.raises(ValueError, match="0 or 1"): + S("Array", A(int), A(str)) + + def test_list_too_many_args(self) -> None: + """List with > 1 arg raises ValueError.""" + with pytest.raises(ValueError, match="0 or 1"): + S("List", A(int), A(str)) + + def test_map_wrong_arity(self) -> None: + """Map with 1 or 3 args raises ValueError.""" + with pytest.raises(ValueError, match="0 or 2"): + S("Map", A(str)) + with pytest.raises(ValueError, match="0 or 2"): + S("Map", A(str), A(int), A(float)) + + def test_dict_wrong_arity(self) -> None: + """Dict with 1 or 3 args raises ValueError.""" + with pytest.raises(ValueError, match="0 or 2"): + S("Dict", A(str)) + + def test_lowercase_list_too_many_args(self) -> None: + """Lowercase 'list' with > 1 arg raises ValueError.""" + with pytest.raises(ValueError, match="0 or 1"): + S("list", A(int), A(str)) # S: lowercase "list" is an internal origin + + def test_lowercase_dict_wrong_arity(self) -> None: + """Lowercase 'dict' with 1 arg raises ValueError.""" + with pytest.raises(ValueError, match="0 or 2"): + S("dict", A(str)) # S: lowercase "dict" is an internal origin + + +# --------------------------------------------------------------------------- +# Category 45: from_type_index edge cases +# --------------------------------------------------------------------------- +class TestFromTypeIndexEdgeCases: + """Verify from_type_index behavior for valid indices. + + Note: Unregistered type indices trigger a fatal C++ assertion + (TVMFFIGetTypeInfo CHECK failure) that cannot be caught from Python. + Only valid indices obtained from the type registry should be passed. + """ + + def test_valid_pod_index_roundtrip(self) -> None: + """POD type_index from TypeSchema.origin_type_index round-trips.""" + int_schema = A(int) + schema = TypeSchema.from_type_index(int_schema.origin_type_index) + assert schema.origin == "int" + schema.check_value(42) + + def test_valid_object_index_works(self) -> None: + """Valid registered object type_index constructs fine.""" + tindex = tvm_ffi.core._object_type_key_to_index("testing.TestIntPair") + assert tindex is not None + schema = TypeSchema.from_type_index(tindex) + assert schema.origin == "testing.TestIntPair" + + def test_from_type_index_with_args(self) -> None: + """from_type_index with type arguments creates parameterized schema.""" + arr_schema = A(tvm_ffi.Array) + schema = TypeSchema.from_type_index(arr_schema.origin_type_index, (A(int),)) + assert schema.origin == "Array" + schema.check_value([1, 2, 3]) + + +# =========================================================================== +# Protocol-based conversion tests (matching Python FFI marshal path) +# =========================================================================== + + +# --------------------------------------------------------------------------- +# Category 46: __tvm_ffi_int__ protocol +# --------------------------------------------------------------------------- +class TestIntProtocol: + """int schema accepts values with __tvm_ffi_int__ protocol.""" + + def test_int_protocol_accepted(self) -> None: + """Object with __tvm_ffi_int__ passes int schema.""" + + class IntProto: + def __tvm_ffi_int__(self) -> int: + return 42 + + A(int).check_value(IntProto()) + + def test_int_protocol_check_value(self) -> None: + """check_value succeeds for __tvm_ffi_int__ value.""" + + class IntProto: + def __tvm_ffi_int__(self) -> int: + return 10 + + A(int).check_value(IntProto()) + + def test_int_protocol_convert_returns_value(self) -> None: + """Convert returns the protocol value as-is (marshal handles conversion).""" + + class IntProto: + def __tvm_ffi_int__(self) -> int: + return 99 + + obj = IntProto() + result = A(int).convert(obj).to_py() + assert result is not None + + def test_without_protocol_still_rejected(self) -> None: + """Object without __tvm_ffi_int__ is still rejected by int schema.""" + + class NoProto: + pass + + with pytest.raises(TypeError, match="expected int"): + A(int).check_value(NoProto()) + + +# --------------------------------------------------------------------------- +# Category 47: __tvm_ffi_float__ protocol +# --------------------------------------------------------------------------- +class TestFloatProtocol: + """float schema accepts values with __tvm_ffi_float__ protocol.""" + + def test_float_protocol_accepted(self) -> None: + """Object with __tvm_ffi_float__ passes float schema.""" + + class FloatProto: + def __tvm_ffi_float__(self) -> float: + return 3.14 + + A(float).check_value(FloatProto()) + + def test_float_protocol_convert(self) -> None: + """Convert returns protocol value as-is.""" + + class FloatProto: + def __tvm_ffi_float__(self) -> float: + return 2.0 + + obj = FloatProto() + result = A(float).convert(obj).to_py() + assert result is not None + + def test_without_protocol_still_rejected(self) -> None: + """Object without __tvm_ffi_float__ is still rejected.""" + + class NoProto: + pass + + with pytest.raises(TypeError, match="expected float"): + A(float).check_value(NoProto()) + + +# --------------------------------------------------------------------------- +# Category 48: __tvm_ffi_opaque_ptr__ protocol +# --------------------------------------------------------------------------- +class TestOpaquePtrProtocol: + """ctypes.c_void_p schema accepts __tvm_ffi_opaque_ptr__ protocol.""" + + def test_opaque_ptr_protocol_accepted(self) -> None: + """Object with __tvm_ffi_opaque_ptr__ passes ctypes.c_void_p schema.""" + + class PtrProto: + def __tvm_ffi_opaque_ptr__(self) -> int: + return 0xDEAD + + A(ctypes.c_void_p).check_value(PtrProto()) + + def test_opaque_ptr_protocol_convert(self) -> None: + """Convert returns protocol value as-is.""" + + class PtrProto: + def __tvm_ffi_opaque_ptr__(self) -> int: + return 0 + + obj = PtrProto() + result = A(ctypes.c_void_p).convert(obj).to_py() + assert result is not None + + +# --------------------------------------------------------------------------- +# Category 49: __dlpack_device__ protocol +# --------------------------------------------------------------------------- +class TestDeviceProtocol: + """Device schema accepts __dlpack_device__ protocol.""" + + def test_dlpack_device_protocol_accepted(self) -> None: + """Object with __dlpack_device__ passes Device schema.""" + + class DevProto: + def __dlpack_device__(self) -> tuple[int, int]: + return (1, 0) + + A(tvm_ffi.Device).check_value(DevProto()) + + def test_dlpack_device_protocol_convert(self) -> None: + """Convert returns protocol value as-is.""" + + class DevProto: + def __dlpack_device__(self) -> tuple[int, int]: + return (2, 1) + + obj = DevProto() + result = A(tvm_ffi.Device).convert(obj).to_py() + assert result is not None + + def test_without_protocol_still_rejected(self) -> None: + """Object without __dlpack_device__ is still rejected.""" + + class NoProto: + pass + + with pytest.raises(TypeError, match="expected Device"): + A(tvm_ffi.Device).check_value(NoProto()) + + +# --------------------------------------------------------------------------- +# Category 50: dtype protocols (torch.dtype, numpy.dtype, __dlpack_data_type__) +# --------------------------------------------------------------------------- +class TestDtypeProtocols: + """dtype schema accepts torch.dtype, numpy.dtype, __dlpack_data_type__.""" + + def test_dlpack_data_type_protocol_accepted(self) -> None: + """Object with __dlpack_data_type__ passes dtype schema.""" + + class DTypeProto: + def __dlpack_data_type__(self) -> tuple[int, int, int]: + return (2, 32, 1) # float32 + + A(tvm_ffi.core.DataType).check_value(DTypeProto()) + + def test_dlpack_data_type_protocol_convert(self) -> None: + """Convert returns protocol value as-is.""" + + class DTypeProto: + def __dlpack_data_type__(self) -> tuple[int, int, int]: + return (0, 32, 1) + + obj = DTypeProto() + result = A(tvm_ffi.core.DataType).convert(obj).to_py() + assert result is not None + + def test_numpy_dtype_accepted(self) -> None: + """numpy.dtype passes dtype schema (if numpy installed).""" + numpy = pytest.importorskip("numpy") + A(tvm_ffi.core.DataType).check_value(numpy.dtype("float32")) + + def test_numpy_dtype_convert(self) -> None: + """Convert returns numpy.dtype as-is.""" + numpy = pytest.importorskip("numpy") + dt = numpy.dtype("int32") + result = A(tvm_ffi.core.DataType).convert(dt).to_py() + assert result is not None + + def test_torch_dtype_accepted(self) -> None: + """torch.dtype passes dtype schema (if torch installed).""" + torch = pytest.importorskip("torch") + A(tvm_ffi.core.DataType).check_value(torch.float32) + + def test_torch_dtype_convert(self) -> None: + """Convert returns torch.dtype as-is.""" + torch = pytest.importorskip("torch") + dt = torch.int64 + result = A(tvm_ffi.core.DataType).convert(dt).to_py() + assert result is not None + + +# --------------------------------------------------------------------------- +# Category 51: __dlpack_c_exchange_api__ protocol (Tensor) +# --------------------------------------------------------------------------- +class TestTensorProtocol: + """Tensor schema accepts __dlpack_c_exchange_api__ protocol.""" + + def test_dlpack_c_exchange_api_accepted(self) -> None: + """Object with a valid __dlpack_c_exchange_api__ passes Tensor schema.""" + np = pytest.importorskip("numpy") + tensor = tvm_ffi.from_dlpack(np.arange(4, dtype="int32")) + wrapper = tvm_ffi.core.DLTensorTestWrapper(tensor) + A(tvm_ffi.Tensor).check_value(wrapper) + + def test_dlpack_c_exchange_api_convert(self) -> None: + """Valid exchange-api wrappers can be converted to Tensor.""" + np = pytest.importorskip("numpy") + tensor = tvm_ffi.from_dlpack(np.arange(4, dtype="int32")) + wrapper = tvm_ffi.core.DLTensorTestWrapper(tensor) + result = A(tvm_ffi.Tensor).convert(wrapper).to_py() + assert isinstance(result, tvm_ffi.Tensor) + + def test_dlpack_still_accepted(self) -> None: + """Object with __dlpack__ still accepted (existing behavior).""" + np = pytest.importorskip("numpy") + A(tvm_ffi.Tensor).check_value(np.arange(4, dtype="int32")) + + +# --------------------------------------------------------------------------- +# Category 52: __tvm_ffi_object__ protocol +# --------------------------------------------------------------------------- +class TestObjectProtocol: + """Object schemas accept __tvm_ffi_object__ protocol.""" + + def test_object_protocol_generic_object(self) -> None: + """__tvm_ffi_object__ returning a CObject passes generic Object schema.""" + inner = TestIntPair(1, 2) + + class ObjProto: + def __tvm_ffi_object__(self) -> object: + return inner + + A(tvm_ffi.core.Object).check_value(ObjProto()) + + def test_object_protocol_specific_type(self) -> None: + """__tvm_ffi_object__ returning TestIntPair passes TestIntPair schema.""" + inner = TestIntPair(3, 4) + + class ObjProto: + def __tvm_ffi_object__(self) -> object: + return inner + + A(TestIntPair).check_value(ObjProto()) + + def test_object_protocol_convert_returns_cobject(self) -> None: + """Convert returns the CObject from __tvm_ffi_object__().""" + inner = TestIntPair(5, 6) + + class ObjProto: + def __tvm_ffi_object__(self) -> object: + return inner + + result = A(TestIntPair).convert(ObjProto()).to_py() + assert result.same_as(inner) + + def test_object_protocol_wrong_type_rejected(self) -> None: + """__tvm_ffi_object__ returning wrong type is rejected.""" + inner = TestIntPair(1, 2) + + class ObjProto: + def __tvm_ffi_object__(self) -> object: + return inner + + with pytest.raises( + TypeError, match=r"expected testing\.TestCxxClassBase, got testing\.TestIntPair" + ): + A(_TestCxxClassBase).check_value(ObjProto()) + + def test_object_protocol_raises_caught(self) -> None: + """__tvm_ffi_object__ that raises produces _ConvertError.""" + + class BadProto: + def __tvm_ffi_object__(self) -> object: + raise RuntimeError("broken") + + with pytest.raises(TypeError, match=r"__tvm_ffi_object__\(\) failed"): + A(tvm_ffi.core.Object).check_value(BadProto()) + + def test_object_protocol_hierarchy(self) -> None: + """__tvm_ffi_object__ returning derived passes base schema.""" + derived = _TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.0) + + class ObjProto: + def __tvm_ffi_object__(self) -> object: + return derived + + A(_TestCxxClassBase).check_value(ObjProto()) + + +# --------------------------------------------------------------------------- +# Category 53: ObjectConvertible protocol +# --------------------------------------------------------------------------- +class TestObjectConvertibleProtocol: + """Object schemas accept ObjectConvertible subclass.""" + + def test_object_convertible_accepted(self) -> None: + """ObjectConvertible with asobject() passes Object schema.""" + inner = TestIntPair(10, 20) + + class MyConvertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + return inner + + A(tvm_ffi.core.Object).check_value(MyConvertible()) + + def test_object_convertible_specific_type(self) -> None: + """ObjectConvertible passes specific type schema.""" + inner = TestIntPair(1, 2) + + class MyConvertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + return inner + + A(TestIntPair).check_value(MyConvertible()) + + def test_object_convertible_convert_returns_cobject(self) -> None: + """Convert returns the CObject from asobject().""" + inner = TestIntPair(7, 8) + + class MyConvertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + return inner + + result = A(TestIntPair).convert(MyConvertible()).to_py() + assert result.same_as(inner) + + def test_object_convertible_in_union(self) -> None: + """Union dispatch unwraps ObjectConvertible before trying alternatives.""" + inner = TestIntPair(9, 10) + + class MyConvertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + return inner + + result = A(Union[TestIntPair, int]).convert(MyConvertible()).to_py() + assert result.same_as(inner) + + def test_object_convertible_wrong_type(self) -> None: + """ObjectConvertible returning wrong type is rejected.""" + inner = TestIntPair(1, 2) + + class MyConvertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + return inner + + with pytest.raises( + TypeError, + match=r"type check failed for testing\.TestCxxClassBase: expected testing\.TestCxxClassBase, got testing\.TestIntPair", + ): + A(_TestCxxClassBase).check_value(MyConvertible()) + + def test_object_convertible_raises_caught(self) -> None: + """asobject() that raises produces error, not exception.""" + + class BadConvertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + raise RuntimeError("broken asobject") + + with pytest.raises(TypeError, match=r"asobject\(\) failed"): + A(tvm_ffi.core.Object).check_value(BadConvertible()) + + +# --------------------------------------------------------------------------- +# Category 54: __tvm_ffi_value__ protocol (recursive fallback) +# --------------------------------------------------------------------------- +class TestValueProtocol: + """__tvm_ffi_value__ provides recursive conversion fallback.""" + + def test_value_protocol_int(self) -> None: + """__tvm_ffi_value__ returning int passes int schema.""" + + class ValProto: + def __tvm_ffi_value__(self) -> object: + return 42 + + A(int).check_value(ValProto()) + + def test_value_protocol_float(self) -> None: + """__tvm_ffi_value__ returning float passes float schema.""" + + class ValProto: + def __tvm_ffi_value__(self) -> object: + return 3.14 + + A(float).check_value(ValProto()) + + def test_value_protocol_convert(self) -> None: + """Convert returns the unwrapped value from __tvm_ffi_value__.""" + + class ValProto: + def __tvm_ffi_value__(self) -> object: + return 42 + + result = A(int).convert(ValProto()).to_py() + assert result == 42 + + def test_value_protocol_nested(self) -> None: + """Nested __tvm_ffi_value__ is recursively unwrapped.""" + + class ValProto: + def __init__(self, v: object) -> None: + self.v = v + + def __tvm_ffi_value__(self) -> object: + return self.v + + # ValProto(ValProto(ValProto(10))) should unwrap to 10 + wrapped = ValProto(ValProto(ValProto(10))) + assert A(int).convert(wrapped).to_py() == 10 + + def test_value_protocol_object(self) -> None: + """__tvm_ffi_value__ returning a CObject passes object schema.""" + inner = TestIntPair(1, 2) + + class ValProto: + def __tvm_ffi_value__(self) -> object: + return inner + + A(TestIntPair).check_value(ValProto()) + + def test_value_protocol_still_fails_on_mismatch(self) -> None: + """__tvm_ffi_value__ returning wrong type still fails.""" + + class ValProto: + def __tvm_ffi_value__(self) -> object: + return "not_an_int" + + with pytest.raises(TypeError, match="expected int"): + A(int).check_value(ValProto()) + + def test_value_protocol_raises_uses_original_error(self) -> None: + """If __tvm_ffi_value__ raises, the original error is returned.""" + + class BadValProto: + def __tvm_ffi_value__(self) -> object: + raise RuntimeError("broken") + + with pytest.raises(TypeError, match="expected int"): + A(int).check_value(BadValProto()) + + def test_nested_optional_value_protocol_stall(self) -> None: + """Optional[Optional[float]] reports the unwrapped target type.""" + + class SelfRef: + def __tvm_ffi_value__(self) -> object: + return self + + with pytest.raises(TypeError, match="expected float"): + A(Optional[Optional[float]]).check_value(SelfRef()) + + def test_value_protocol_eventually_resolves(self) -> None: + """Short __tvm_ffi_value__ chains still resolve successfully.""" + + class ChainStep: + def __init__(self, remaining: int) -> None: + self.remaining = remaining + + def __tvm_ffi_value__(self) -> object: + if self.remaining > 0: + return ChainStep(self.remaining - 1) + return 42 + + assert A(int).convert(ChainStep(5)).to_py() == 42 + + +# --------------------------------------------------------------------------- +# Category 55: Protocol values in containers +# --------------------------------------------------------------------------- +class TestProtocolsInContainers: + """Protocol-accepting values work inside containers and composites.""" + + @requires_py39 + def test_int_protocol_in_array(self) -> None: + """Array[int] accepts elements with __tvm_ffi_int__.""" + + class IntProto: + def __tvm_ffi_int__(self) -> int: + return 1 + + A(tuple[int, ...]).check_value([1, IntProto(), 3]) + + def test_float_protocol_in_optional(self) -> None: + """Optional[float] accepts __tvm_ffi_float__ value.""" + + class FloatProto: + def __tvm_ffi_float__(self) -> float: + return 1.0 + + A(Optional[float]).check_value(FloatProto()) + A(Optional[float]).check_value(None) + + def test_object_protocol_in_union(self) -> None: + """Union[testing.TestIntPair, int] accepts __tvm_ffi_object__ value.""" + inner = TestIntPair(1, 2) + + class ObjProto: + def __tvm_ffi_object__(self) -> object: + return inner + + A(Union[TestIntPair, int]).check_value(ObjProto()) + + @requires_py39 + def test_value_protocol_in_array(self) -> None: + """Array[int] elements use __tvm_ffi_value__ fallback (recursive).""" + + class ValProto: + def __tvm_ffi_value__(self) -> object: + return 42 + + # __tvm_ffi_value__ fallback is applied recursively at every level, + # matching the marshal path where TVMFFIPyArgSetterFactory_ is + # called per-element. + A(tuple[int, ...]).check_value([ValProto()]) + + @requires_py39 + def test_object_convertible_in_array(self) -> None: + """Array[Object] elements unwrap ObjectConvertible before element dispatch.""" + inner = TestIntPair(3, 4) + + class Convertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + return inner + + result = A(tuple[tvm_ffi.core.Object, ...]).convert([Convertible()]).to_py() + assert result[0].same_as(inner) + + @requires_py39 + def test_device_protocol_in_map_value(self) -> None: + """Map[str, Device] accepts __dlpack_device__ values.""" + + class DevProto: + def __dlpack_device__(self) -> tuple[int, int]: + return (1, 0) + + A(tvm_ffi.Map[str, tvm_ffi.Device]).check_value({"gpu": DevProto()}) + + +# --------------------------------------------------------------------------- +# Category 56: Nested __tvm_ffi_value__ in containers (recursive fallback) +# --------------------------------------------------------------------------- +class TestNestedValueProtocol: + """__tvm_ffi_value__ fallback works recursively inside containers.""" + + @requires_py39 + def test_value_in_array_elements(self) -> None: + """Array[int] elements with __tvm_ffi_value__ are accepted.""" + + class VP: + def __tvm_ffi_value__(self) -> object: + return 42 + + A(tuple[int, ...]).check_value([1, VP(), 3]) + + @requires_py39 + def test_value_in_map_values(self) -> None: + """Map[str, int] values with __tvm_ffi_value__ are accepted.""" + + class VP: + def __tvm_ffi_value__(self) -> object: + return 99 + + A(tvm_ffi.Map[str, int]).check_value({"a": VP()}) + + @requires_py39 + def test_value_in_map_keys(self) -> None: + """Map[str, int] keys with __tvm_ffi_value__ are accepted.""" + + class VP: + def __tvm_ffi_value__(self) -> object: + return "key" + + A(tvm_ffi.Map[str, int]).check_value({VP(): 1}) + + @requires_py39 + def test_value_in_tuple_positions(self) -> None: + """tuple[int, str] positions with __tvm_ffi_value__ are accepted.""" + + class IntVP: + def __tvm_ffi_value__(self) -> object: + return 42 + + class StrVP: + def __tvm_ffi_value__(self) -> object: + return "hello" + + A(tuple[int, str]).check_value((IntVP(), StrVP())) + + def test_value_in_optional_inner(self) -> None: + """Optional[int] inner with __tvm_ffi_value__ is accepted.""" + + class VP: + def __tvm_ffi_value__(self) -> object: + return 42 + + A(Optional[int]).check_value(VP()) + + def test_value_in_union_alternatives(self) -> None: + """Union[int, str] with __tvm_ffi_value__ is accepted.""" + + class VP: + def __tvm_ffi_value__(self) -> object: + return "hello" + + A(Union[int, str]).check_value(VP()) + + @requires_py39 + def test_multi_hop_value_in_container(self) -> None: + """Nested __tvm_ffi_value__ unwrapping inside containers.""" + + class VP: + def __init__(self, v: object) -> None: + self.v = v + + def __tvm_ffi_value__(self) -> object: + return self.v + + A(tuple[int, ...]).check_value([VP(VP(10))]) + + @requires_py39 + def test_value_convert_in_array(self) -> None: + """Convert returns unwrapped values in container.""" + + class VP: + def __tvm_ffi_value__(self) -> object: + return 42 + + result = A(tuple[int, ...]).convert([VP()]).to_py() + assert list(result) == [42] + + +# --------------------------------------------------------------------------- +# Category 57: __tvm_ffi_value__ cycle protection +# --------------------------------------------------------------------------- +class TestValueProtocolCycles: + """Cycle protection in __tvm_ffi_value__ fallback.""" + + def test_self_cycle_returns_error(self) -> None: + """__tvm_ffi_value__() returning self doesn't infinite-loop.""" + + class SelfCycle: + def __tvm_ffi_value__(self) -> object: + return self + + with pytest.raises(TypeError, match="expected int"): + A(int).check_value(SelfCycle()) + + def test_any_self_cycle_returns_original_error(self) -> None: + """Any also routes __tvm_ffi_value__ through the bounded fallback loop.""" + call_count = 0 + + class SelfCycle: + def __tvm_ffi_value__(self) -> object: + nonlocal call_count + call_count += 1 + return self + + with pytest.raises(TypeError, match=r"failed to convert Any from .*SelfCycle"): + A(typing.Any).convert(SelfCycle()) + assert call_count == 1 + + def test_mutual_cycle_bounded(self) -> None: + """Mutual cycle is bounded by explicit depth limit.""" + + class A: + def __init__(self) -> None: + self.other: object = None + + def __tvm_ffi_value__(self) -> object: + return self.other + + class B: + def __init__(self) -> None: + self.other: object = None + + def __tvm_ffi_value__(self) -> object: + return self.other + + a, b = A(), B() + a.other = b + b.other = a + + # Should not hang — bounded by depth limit in the fallback loop + with pytest.raises(TypeError, match="cycle"): + S("int").check_value(a) + + def test_any_mutual_cycle_bounded(self) -> None: + """Any reports a bounded cycle instead of recursing in raw CAny packing.""" + + class Left: + def __init__(self) -> None: + self.other: object = None + + def __tvm_ffi_value__(self) -> object: + return self.other + + class Right: + def __init__(self) -> None: + self.other: object = None + + def __tvm_ffi_value__(self) -> object: + return self.other + + a, b = Left(), Right() + a.other = b + b.other = a + + with pytest.raises(TypeError, match="cycle"): + A(typing.Any).convert(a) + + def test_value_protocol_deep_chain_hits_cycle_limit(self) -> None: + """Long __tvm_ffi_value__ chains trip the explicit depth guard.""" + + class DeepChain: + def __init__(self, depth: int) -> None: + self.depth = depth + + def __tvm_ffi_value__(self) -> object: + return DeepChain(self.depth + 1) if self.depth < 100 else 42 + + with pytest.raises(TypeError, match="cycle"): + A(int).check_value(DeepChain(0)) + + +# --------------------------------------------------------------------------- +# Category 58: Object marshal fallback +# --------------------------------------------------------------------------- +class TestObjectConvertAttrRegistration: + """Object targets register __ffi_convert__ consistently.""" + + def test_core_object_types_register_convert_attr(self) -> None: + """Core object types register ``__ffi_convert__``.""" + for type_key in ( + "ffi.Object", + "ffi.Function", + "ffi.Error", + "ffi.String", + "ffi.Bytes", + "ffi.Array", + "ffi.List", + "ffi.Map", + "ffi.Dict", + "ffi.Shape", + "ffi.Tensor", + ): + type_index = _object_type_key_to_index(type_key) + assert type_index is not None + assert _lookup_type_attr(type_index, "__ffi_convert__") is not None + + def test_explicit_ref_registration_registers_convert_attr(self) -> None: + """Explicit ``.ref()`` registration adds ``__ffi_convert__``.""" + type_index = _object_type_key_to_index("testing.TestIntPair") + assert type_index is not None + assert _lookup_type_attr(type_index, "__ffi_convert__") is not None + + def test_reflected_object_without_ref_does_not_register_convert_attr(self) -> None: + """Reflected object classes without ``.ref()`` do not auto-register.""" + type_index = _object_type_key_to_index("testing.TestObjectBase") + assert type_index is not None + assert _lookup_type_attr(type_index, "__ffi_convert__") is None + + +class TestObjectMarshalFallback: + """Object schema accepts values that the marshal path converts to Objects.""" + + def test_exception_accepted_by_object_schema(self) -> None: + """TypeSchema('Object') accepts Exception (-> ffi.Error).""" + A(tvm_ffi.core.Object).check_value(RuntimeError("test")) + + def test_exception_accepted_by_error_schema(self) -> None: + """TypeSchema('ffi.Error') accepts Exception.""" + A(tvm_ffi.core.Error).check_value(ValueError("oops")) + + def test_shape_accepted_from_python_list_via_convert(self) -> None: + """TypeSchema('ffi.Shape') converts Python lists via __ffi_convert__.""" + result = S("ffi.Shape").convert([1, 2, 3]).to_py() + assert isinstance(result, tvm_ffi.Shape) + assert tuple(result) == (1, 2, 3) + + def test_shape_accepted_from_ffi_list_via_convert(self) -> None: + """TypeSchema('ffi.Shape') converts ffi.List via __ffi_convert__.""" + result = S("ffi.Shape").convert(tvm_ffi.List([1, 2, 3])).to_py() + assert isinstance(result, tvm_ffi.Shape) + assert tuple(result) == (1, 2, 3) + + def test_shape_accepted_from_ffi_array_via_convert(self) -> None: + """TypeSchema('ffi.Shape') converts ffi.Array via __ffi_convert__.""" + result = S("ffi.Shape").convert(tvm_ffi.Array([1, 2, 3])).to_py() + assert isinstance(result, tvm_ffi.Shape) + assert tuple(result) == (1, 2, 3) + + def test_shape_convert_repeated_conversions(self) -> None: + """Repeated __ffi_convert__ conversions keep returning live owning objects.""" + result1 = S("ffi.Shape").convert([1, 2, 3]).to_py() + result2 = S("ffi.Shape").convert(tvm_ffi.List([1, 2, 3])).to_py() + assert isinstance(result1, tvm_ffi.Shape) + assert isinstance(result2, tvm_ffi.Shape) + assert tuple(result1) == (1, 2, 3) + assert tuple(result2) == (1, 2, 3) + + def test_exception_rejected_by_array_schema(self) -> None: + """Exception is NOT accepted by Array schema (Error !IS-A Array).""" + with pytest.raises(TypeError, match="expected Array"): + A(tvm_ffi.Array).check_value(RuntimeError("x")) + + def test_opaque_object_accepted_by_object_schema(self) -> None: + """TypeSchema('Object') accepts arbitrary Python objects (-> OpaquePyObject).""" + + class Custom: + pass + + A(tvm_ffi.core.Object).check_value(Custom()) + + def test_plain_object_accepted_by_object_schema(self) -> None: + """TypeSchema('Object') accepts object().""" + A(tvm_ffi.core.Object).check_value(object()) + + def test_opaque_rejected_by_specific_schema(self) -> None: + """Specific schema rejects arbitrary Python object.""" + + class Custom: + pass + + with pytest.raises(TypeError, match=r"got .*Custom"): + A(TestIntPair).check_value(Custom()) + + @pytest.mark.xfail(reason="SmallStr -> ObjectRef conversion is not supported yet") + def test_str_accepted_by_object_schema(self) -> None: + """TypeSchema('Object') accepts str (-> ffi.String IS-A Object).""" + A(tvm_ffi.core.Object).check_value("hello") + + @pytest.mark.xfail(reason="SmallBytes -> ObjectRef conversion is not supported yet") + def test_bytes_accepted_by_object_schema(self) -> None: + """TypeSchema('Object') accepts bytes (-> ffi.Bytes IS-A Object).""" + A(tvm_ffi.core.Object).check_value(b"hello") + + def test_list_accepted_by_object_schema(self) -> None: + """TypeSchema('Object') accepts list (-> ffi.Array IS-A Object).""" + A(tvm_ffi.core.Object).check_value([1, 2, 3]) + + def test_dict_accepted_by_object_schema(self) -> None: + """TypeSchema('Object') accepts dict (-> ffi.Map IS-A Object).""" + A(tvm_ffi.core.Object).check_value({"a": 1}) + + def test_callable_accepted_by_object_schema(self) -> None: + """TypeSchema('Object') accepts callable (-> ffi.Function IS-A Object).""" + A(tvm_ffi.core.Object).check_value(lambda: None) + + def test_int_rejected_by_object_schema(self) -> None: + """TypeSchema('Object') rejects int (int is a POD type, not Object).""" + with pytest.raises(TypeError): + A(tvm_ffi.core.Object).check_value(42) + + def test_float_rejected_by_object_schema(self) -> None: + """TypeSchema('Object') rejects float (float is a POD, not Object).""" + with pytest.raises(TypeError): + A(tvm_ffi.core.Object).check_value(3.14) + + def test_none_rejected_by_object_schema(self) -> None: + """TypeSchema('Object') rejects None (None is a POD, not Object).""" + with pytest.raises(TypeError): + A(tvm_ffi.core.Object).check_value(None) + + +# --------------------------------------------------------------------------- +# Category 59: __cuda_stream__ for ctypes.c_void_p +# --------------------------------------------------------------------------- +class TestCudaStreamProtocol: + """ctypes.c_void_p schema accepts __cuda_stream__ protocol.""" + + def test_cuda_stream_accepted(self) -> None: + """Object with __cuda_stream__ passes ctypes.c_void_p schema.""" + + class CUStream: + def __cuda_stream__(self) -> tuple[int, int]: + return (0, 0) + + A(ctypes.c_void_p).check_value(CUStream()) + + def test_cuda_stream_convert(self) -> None: + """Convert returns __cuda_stream__ value as-is.""" + + class CUStream: + def __cuda_stream__(self) -> tuple[int, int]: + return (0, 123) + + obj = CUStream() + result = A(ctypes.c_void_p).convert(obj).to_py() + assert result is not None + + def test_cuda_stream_and_opaque_ptr(self) -> None: + """Object with both __cuda_stream__ and __tvm_ffi_opaque_ptr__ accepted.""" + + class DualProto: + def __cuda_stream__(self) -> tuple[int, int]: + return (0, 0) + + def __tvm_ffi_opaque_ptr__(self) -> int: + return 0 + + A(ctypes.c_void_p).check_value(DualProto()) + + +# --------------------------------------------------------------------------- +# Category 60: Device __dlpack__ guard +# --------------------------------------------------------------------------- +class TestDeviceDlpackGuard: + """Device schema respects __dlpack__ precedence.""" + + def test_both_dlpack_and_device_rejected_by_device(self) -> None: + """Object with both __dlpack__ and __dlpack_device__ rejected by Device.""" + + class TensorLike: + def __dlpack__(self) -> object: + return None + + def __dlpack_device__(self) -> tuple[int, int]: + return (1, 0) + + with pytest.raises(TypeError): + A(tvm_ffi.Device).check_value(TensorLike()) + + def test_both_dlpack_and_device_accepted_by_tensor(self) -> None: + """Object with both __dlpack__ and __dlpack_device__ accepted by Tensor.""" + np = pytest.importorskip("numpy") + + class TensorLike: + def __init__(self) -> None: + self.array = np.arange(4, dtype="int32") + + def __dlpack__(self) -> object: + return self.array.__dlpack__() + + def __dlpack_device__(self) -> tuple[int, int]: + return self.array.__dlpack_device__() + + A(tvm_ffi.Tensor).check_value(TensorLike()) + + def test_device_only_accepted_by_device(self) -> None: + """Object with only __dlpack_device__ still accepted by Device.""" + + class DevOnly: + def __dlpack_device__(self) -> tuple[int, int]: + return (1, 0) + + A(tvm_ffi.Device).check_value(DevOnly()) + + def test_dlpack_only_rejected_by_device(self) -> None: + """Object with only __dlpack__ rejected by Device schema.""" + + class DLPackOnly: + def __dlpack__(self) -> object: + return None + + with pytest.raises(TypeError): + A(tvm_ffi.Device).check_value(DLPackOnly()) + + +# --------------------------------------------------------------------------- +# Category 61: SKIP_DLPACK_C_EXCHANGE_API env gate +# --------------------------------------------------------------------------- +class TestSkipDlpackEnvGate: + """Tensor schema respects TVM_FFI_SKIP_DLPACK_C_EXCHANGE_API.""" + + def test_exchange_api_accepted_by_default(self) -> None: + """__dlpack_c_exchange_api__ accepted when env not set.""" + os.environ.pop("TVM_FFI_SKIP_DLPACK_C_EXCHANGE_API", None) + np = pytest.importorskip("numpy") + tensor = tvm_ffi.from_dlpack(np.arange(4, dtype="int32")) + wrapper = tvm_ffi.core.DLTensorTestWrapper(tensor) + A(tvm_ffi.Tensor).check_value(wrapper) + + def test_exchange_api_rejected_when_skipped(self) -> None: + """__dlpack_c_exchange_api__ rejected when env=1.""" + os.environ["TVM_FFI_SKIP_DLPACK_C_EXCHANGE_API"] = "1" + try: + + class ExchangeAPI: + def __dlpack_c_exchange_api__(self) -> int: + return 0 + + with pytest.raises(TypeError): + A(tvm_ffi.Tensor).check_value(ExchangeAPI()) + finally: + del os.environ["TVM_FFI_SKIP_DLPACK_C_EXCHANGE_API"] + + +# --------------------------------------------------------------------------- +# Category 62: from_type_index low-level indices +# --------------------------------------------------------------------------- +class TestFromTypeIndexLowLevel: + """from_type_index handles all built-in type indices.""" + + def test_dl_tensor_ptr(self) -> None: + """KTVMFFIDLTensorPtr maps to Tensor.""" + s = TypeSchema.from_type_index(7) # kTVMFFIDLTensorPtr + assert s.origin == "Tensor" + + def test_raw_str(self) -> None: + """KTVMFFIRawStr maps to str.""" + s = TypeSchema.from_type_index(8) # kTVMFFIRawStr + assert s.origin == "str" + + def test_byte_array_ptr(self) -> None: + """KTVMFFIByteArrayPtr maps to bytes.""" + s = TypeSchema.from_type_index(9) # kTVMFFIByteArrayPtr + assert s.origin == "bytes" + + def test_object_rvalue_ref(self) -> None: + """KTVMFFIObjectRValueRef maps to Object.""" + s = TypeSchema.from_type_index(10) # kTVMFFIObjectRValueRef + assert s.origin == "Object" + + def test_small_str(self) -> None: + """KTVMFFISmallStr maps to str.""" + s = TypeSchema.from_type_index(11) # kTVMFFISmallStr + assert s.origin == "str" + + def test_small_bytes(self) -> None: + """KTVMFFISmallBytes maps to bytes.""" + s = TypeSchema.from_type_index(12) # kTVMFFISmallBytes + assert s.origin == "bytes" + + def test_all_low_level_schemas_usable(self) -> None: + """Schemas from low-level indices can be used for conversion.""" + for idx in (7, 8, 9, 11, 12): + s = TypeSchema.from_type_index(idx) + # Trigger converter build; some schemas raise TypeError for None + try: + s.convert(None) + except TypeError: + pass + + +# --------------------------------------------------------------------------- +# Category 63: STL origin parsing +# --------------------------------------------------------------------------- +class TestSTLOriginParsing: + """C++ STL schema origins are correctly parsed.""" + + def test_std_vector(self) -> None: + """std::vector maps to Array.""" + s = TypeSchema.from_json_str('{"type":"std::vector","args":[{"type":"int"}]}') + assert s.origin == "Array" + + def test_std_optional(self) -> None: + """std::optional maps to Optional.""" + s = TypeSchema.from_json_str('{"type":"std::optional","args":[{"type":"int"}]}') + assert s.origin == "Optional" + assert repr(s) == "int | None" + + def test_std_variant(self) -> None: + """std::variant maps to Union.""" + s = TypeSchema.from_json_str( + '{"type":"std::variant","args":[{"type":"int"},{"type":"str"}]}' + ) + assert s.origin == "Union" + assert repr(s) == "int | str" + + def test_std_tuple(self) -> None: + """std::tuple maps to tuple.""" + s = TypeSchema.from_json_str('{"type":"std::tuple","args":[{"type":"int"},{"type":"str"}]}') + assert s.origin == "tuple" + + def test_std_map(self) -> None: + """std::map maps to Map.""" + s = TypeSchema.from_json_str('{"type":"std::map","args":[{"type":"str"},{"type":"int"}]}') + assert s.origin == "Map" + + def test_std_unordered_map(self) -> None: + """std::unordered_map maps to Map.""" + s = TypeSchema.from_json_str( + '{"type":"std::unordered_map","args":[{"type":"str"},{"type":"int"}]}' + ) + assert s.origin == "Map" + + def test_std_function(self) -> None: + """std::function maps to Callable.""" + s = TypeSchema.from_json_str( + '{"type":"std::function","args":[{"type":"int"},{"type":"str"}]}' + ) + assert s.origin == "Callable" + + def test_object_rvalue_ref_origin(self) -> None: + """ObjectRValueRef maps to Object.""" + s = TypeSchema.from_json_str('{"type":"ObjectRValueRef","args":[]}') + assert s.origin == "Object" + + +# --------------------------------------------------------------------------- +# Category 64: Zero-copy container conversion +# --------------------------------------------------------------------------- +class TestZeroCopyConversion: + """Typed container conversion preserves identity when no elements change.""" + + @requires_py39 + def test_array_int_exact_list(self) -> None: + """Array[int] on exact Python list converts successfully.""" + original = [1, 2, 3] + result = A(tuple[int, ...]).convert(original).to_py() + assert list(result) == original + + @requires_py39 + def test_array_int_needs_conversion(self) -> None: + """Array[int] on list needing bool->int returns converted list.""" + original = [1, True, 3] + result = A(tuple[int, ...]).convert(original).to_py() + assert list(result) == [1, 1, 3] + + @requires_py39 + def test_map_str_int_exact_dict(self) -> None: + """Map[str, int] on exact dict converts successfully.""" + original = {"a": 1, "b": 2} + result = A(tvm_ffi.Map[str, int]).convert(original).to_py() + assert dict(result) == original + + @requires_py39 + def test_map_str_int_needs_conversion(self) -> None: + """Map[str, int] on dict needing conversion returns converted dict.""" + original = {"a": True, "b": 2} + result = A(tvm_ffi.Map[str, int]).convert(original).to_py() + assert result is not None + + @requires_py39 + def test_tuple_exact_match(self) -> None: + """tuple[int, str] on exact tuple converts successfully.""" + original = (42, "hello") + result = A(tuple[int, str]).convert(original).to_py() + assert tuple(result) == original + + @requires_py39 + def test_tuple_needs_conversion(self) -> None: + """tuple[int, str] on tuple needing conversion returns converted tuple.""" + original = (True, "hello") + result = A(tuple[int, str]).convert(original).to_py() + assert tuple(result) == (1, "hello") + + @requires_py39 + def test_list_int_exact(self) -> None: + """List[int] on exact list converts successfully.""" + original = [10, 20] + result = A(list[int]).convert(original).to_py() + assert list(result) == original + + +# --------------------------------------------------------------------------- +# Category 65: Exception normalization in check_value/convert +# --------------------------------------------------------------------------- +class TestExceptionNormalization: + """check_value/convert normalize custom __int__/__float__ failures.""" + + def test_broken_integral_convert(self) -> None: + """Integral with broken __int__ caught by convert.""" + + class BadIntegral: + def __int__(self) -> int: + raise OverflowError("too big") + + Integral.register(BadIntegral) + + with pytest.raises(TypeError, match="too big"): + A(int).convert(BadIntegral()) + + def test_broken_integral_check_value(self) -> None: + """Integral with broken __int__ handled by check_value.""" + + class BrokenInt: + def __int__(self) -> int: + raise ValueError("broken") + + Integral.register(BrokenInt) + + # check_value should raise TypeError (wrapping the ValueError) + with pytest.raises(TypeError, match="broken"): + A(int).check_value(BrokenInt()) + + def test_broken_integral_bool_check_value(self) -> None: + """Integral with broken __bool__ is normalized to TypeError.""" + + class BrokenBoolInt: + def __int__(self) -> int: + return 1 + + def __bool__(self) -> bool: + raise RuntimeError("broken bool") + + Integral.register(BrokenBoolInt) + + with pytest.raises(TypeError, match="broken bool"): + A(bool).check_value(BrokenBoolInt()) + + def test_union_falls_back_after_broken_bool(self) -> None: + """Union keeps trying alternatives when bool conversion fails.""" + + class BrokenBoolStr(str): + def __bool__(self) -> bool: + raise RuntimeError("broken bool") + + Integral.register(BrokenBoolStr) + + result = A(Union[bool, str]).convert(BrokenBoolStr("hello")).to_py() + assert result == "hello" + + +# --------------------------------------------------------------------------- +# Category 66: __tvm_ffi_value__ eager normalization +# --------------------------------------------------------------------------- +class TestValueProtocolPrecedence: + """__tvm_ffi_value__ runs before schema-specific dispatch.""" + + def test_value_protocol_runs_before_int_protocol(self) -> None: + """__tvm_ffi_value__ is applied before __tvm_ffi_int__.""" + + class Dual: + def __tvm_ffi_int__(self) -> int: + return 42 + + def __tvm_ffi_value__(self) -> object: + return TestIntPair(1, 2) + + with pytest.raises(TypeError): + A(int).check_value(Dual()) + A(tvm_ffi.core.Object).check_value(Dual()) + + def test_value_protocol_runs_before_float_protocol(self) -> None: + """__tvm_ffi_value__ is applied before __tvm_ffi_float__.""" + + class Dual: + def __tvm_ffi_float__(self) -> float: + return 1.0 + + def __tvm_ffi_value__(self) -> object: + return TestIntPair(1, 2) + + with pytest.raises(TypeError): + A(float).check_value(Dual()) + A(tvm_ffi.core.Object).check_value(Dual()) + + def test_pure_value_protocol_still_works(self) -> None: + """Class with ONLY __tvm_ffi_value__ still converts eagerly.""" + + class PureVP: + def __tvm_ffi_value__(self) -> object: + return 42 + + A(int).check_value(PureVP()) + + def test_value_protocol_runs_before_callable_dispatch(self) -> None: + """Callable classes are normalized through __tvm_ffi_value__ first.""" + + class CallableVP: + def __call__(self) -> None: + pass + + def __tvm_ffi_value__(self) -> object: + return 42 + + with pytest.raises(TypeError): + A(Callable).check_value(CallableVP()) + A(int).check_value(CallableVP()) + + def test_object_protocol_precedes_value_and_convertible(self) -> None: + """__tvm_ffi_object__ wins over value and ObjectConvertible hooks.""" + inner = TestIntPair(10, 20) + + class Both(ObjectConvertible): + def __tvm_ffi_object__(self) -> object: + return inner + + def __tvm_ffi_value__(self) -> object: + return 999 + + def asobject(self) -> tvm_ffi.core.Object: + return TestIntPair(99, 99) + + result = A(tvm_ffi.core.Object).convert(Both()).to_py() + assert result.same_as(inner) + + +# --------------------------------------------------------------------------- +# Category 67: Union single-call __tvm_ffi_value__ +# --------------------------------------------------------------------------- +class TestUnionValueProtocol: + """Union dispatches __tvm_ffi_value__ once, not per-alternative.""" + + def test_union_value_protocol_once(self) -> None: + """__tvm_ffi_value__ called once for Union.""" + call_count = 0 + + class CountingVP: + def __tvm_ffi_value__(self) -> object: + nonlocal call_count + call_count += 1 + return 42 + + A(Union[str, int]).check_value(CountingVP()) + assert call_count == 1 + + def test_union_value_protocol_mismatch(self) -> None: + """__tvm_ffi_value__ returning wrong type fails Union.""" + call_count = 0 + + class WrongVP: + def __tvm_ffi_value__(self) -> object: + nonlocal call_count + call_count += 1 + return object() + + with pytest.raises(TypeError): + A(Union[int, str]).check_value(WrongVP()) + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# Category 68: from_json_obj robustness +# --------------------------------------------------------------------------- +class TestFromJsonObjRobustness: + """from_json_obj handles non-dict args and malformed input.""" + + def test_non_dict_args_skipped(self) -> None: + """Non-dict elements in args list are silently skipped.""" + obj = {"type": "std::vector", "args": [{"type": "int"}, 42]} + s = TypeSchema.from_json_obj(obj) + assert s.origin == "Array" + assert len(s.args) == 1 + assert s.args[0].origin == "int" + + def test_malformed_input_raises_type_error(self) -> None: + """Non-dict top-level raises TypeError, not AssertionError.""" + with pytest.raises(TypeError, match="expected schema dict"): + TypeSchema.from_json_obj("not_a_dict") # type: ignore[arg-type] + + def test_missing_type_key_raises_type_error(self) -> None: + """Dict without 'type' key raises TypeError.""" + with pytest.raises(TypeError, match="expected schema dict"): + TypeSchema.from_json_obj({"args": []}) + + +# --------------------------------------------------------------------------- +# Category 69: Mutual-cycle RecursionError normalized +# --------------------------------------------------------------------------- +class TestMutualCycleNormalization: + """Mutual __tvm_ffi_value__ cycles produce TypeError, not RecursionError.""" + + def test_mutual_cycle_check_value(self) -> None: + """check_value normalizes mutual-cycle RecursionError to TypeError.""" + + class A: + def __init__(self) -> None: + self.other: object = None + + def __tvm_ffi_value__(self) -> object: + return self.other + + class B: + def __init__(self) -> None: + self.other: object = None + + def __tvm_ffi_value__(self) -> object: + return self.other + + a, b = A(), B() + a.other = b + b.other = a + + with pytest.raises(TypeError, match="cycle"): + S("int").check_value(a) + + def test_mutual_cycle_convert(self) -> None: + """Convert normalizes mutual-cycle RecursionError to TypeError.""" + + class A: + def __init__(self) -> None: + self.other: object = None + + def __tvm_ffi_value__(self) -> object: + return self.other + + class B: + def __init__(self) -> None: + self.other: object = None + + def __tvm_ffi_value__(self) -> object: + return self.other + + a, b = A(), B() + a.other = b + b.other = a + + with pytest.raises(TypeError, match="cycle"): + S("int").convert(a) + + +# --------------------------------------------------------------------------- +# Category 70: ObjectConvertible vs __tvm_ffi_value__ precedence +# --------------------------------------------------------------------------- +class TestObjectConvertiblePrecedence: + """__tvm_ffi_value__ takes precedence over ObjectConvertible.""" + + def test_value_protocol_wins_over_convertible(self) -> None: + """Class with both __tvm_ffi_value__ and ObjectConvertible uses fallback.""" + pair = TestIntPair(10, 20) + + class DualProtocol(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + return pair + + def __tvm_ffi_value__(self) -> object: + return 42 + + # int schema: __tvm_ffi_value__ returns 42, accepted + A(int).check_value(DualProtocol()) + # Object schema: __tvm_ffi_value__ returns 42 (POD int, not Object), + # should REJECT (not accept via ObjectConvertible) + with pytest.raises(TypeError): + A(tvm_ffi.core.Object).check_value(DualProtocol()) + + def test_pure_convertible_still_works(self) -> None: + """ObjectConvertible without __tvm_ffi_value__ still accepted.""" + pair = TestIntPair(1, 2) + + class PureConvertible(ObjectConvertible): + def asobject(self) -> tvm_ffi.core.Object: + return pair + + A(tvm_ffi.core.Object).check_value(PureConvertible()) + A(TestIntPair).check_value(PureConvertible()) + + +# --------------------------------------------------------------------------- +# Category 71: from_json_obj non-iterable args +# --------------------------------------------------------------------------- +class TestFromJsonObjNonIterableArgs: + """from_json_obj handles non-iterable args values gracefully.""" + + def test_non_iterable_args_treated_as_empty(self) -> None: + """Non-list/tuple args value (e.g., int) treated as empty args.""" + s = TypeSchema.from_json_obj({"type": "int", "args": 42}) + assert s.origin == "int" + assert s.args == () + + def test_string_args_treated_as_empty(self) -> None: + """String args value treated as empty (not iterated char-by-char).""" + s = TypeSchema.from_json_obj({"type": "int", "args": "bad"}) + assert s.origin == "int" + assert s.args == () + + +# --------------------------------------------------------------------------- +# CAny class tests +# --------------------------------------------------------------------------- +class TestCAny: + """Tests for the CAny owned-value container.""" + + def test_cany_from_int(self) -> None: + """convert(int) returns CAny with correct type_index.""" + cany = A(int).convert(42) + assert isinstance(cany, CAny) + assert cany.type_index == 1 # kTVMFFIInt + + def test_cany_from_float(self) -> None: + """convert(float) returns CAny with correct type_index.""" + cany = A(float).convert(3.14) + assert isinstance(cany, CAny) + assert cany.type_index == 3 # kTVMFFIFloat + + def test_cany_from_bool(self) -> None: + """convert(bool) returns CAny with correct type_index.""" + cany = A(bool).convert(True) + assert isinstance(cany, CAny) + assert cany.type_index == 2 # kTVMFFIBool + + def test_cany_from_none(self) -> None: + """convert(None) returns CAny with type_index 0.""" + cany = A(None).convert(None) + assert isinstance(cany, CAny) + assert cany.type_index == 0 # kTVMFFINone + + def test_cany_from_str(self) -> None: + """convert(str) returns CAny.""" + cany = A(str).convert("hello") + assert isinstance(cany, CAny) + # Short strings have type_index=11 (SmallStr), longer ones have 65 (Str) + assert cany.type_index in (11, 65) + + @requires_py39 + def test_cany_from_array(self) -> None: + """convert(Array) returns CAny with array type_index.""" + cany = A(tuple[int, ...]).convert([1, 2, 3]) + assert isinstance(cany, CAny) + assert cany.type_index >= 64 # object type + + def test_to_py_int(self) -> None: + """to_py() round-trips int correctly.""" + result = A(int).convert(42).to_py() + assert result == 42 + assert type(result) is int + + def test_to_py_float(self) -> None: + """to_py() round-trips float correctly.""" + result = A(float).convert(3.14).to_py() + assert result == 3.14 + assert type(result) is float + + def test_to_py_bool(self) -> None: + """to_py() round-trips bool correctly.""" + assert A(bool).convert(True).to_py() is True + assert A(bool).convert(False).to_py() is False + + def test_to_py_none(self) -> None: + """to_py() round-trips None correctly.""" + assert A(None).convert(None).to_py() is None + + def test_to_py_str(self) -> None: + """to_py() round-trips str correctly.""" + assert A(str).convert("hello").to_py() == "hello" + + @requires_py39 + def test_to_py_array(self) -> None: + """to_py() returns ffi.Array for Array convert.""" + result = A(tuple[int, ...]).convert([1, 2, 3]).to_py() + assert isinstance(result, tvm_ffi.Array) + assert list(result) == [1, 2, 3] + + @requires_py39 + def test_to_py_list(self) -> None: + """to_py() returns ffi.List for List convert.""" + result = A(list[int]).convert([1, 2, 3]).to_py() + assert isinstance(result, tvm_ffi.List) + assert list(result) == [1, 2, 3] + + @requires_py39 + def test_to_py_map(self) -> None: + """to_py() returns ffi.Map for Map convert.""" + result = A(tvm_ffi.Map[str, int]).convert({"a": 1}).to_py() + assert isinstance(result, tvm_ffi.Map) + + @requires_py39 + def test_to_py_dict(self) -> None: + """to_py() returns ffi.Dict for Dict convert.""" + result = A(dict[str, int]).convert({"a": 1}).to_py() + assert isinstance(result, tvm_ffi.Dict) + + def test_multiple_to_py_calls(self) -> None: + """to_py() can be called multiple times safely.""" + cany = A(int).convert(42) + assert cany.to_py() == 42 + assert cany.to_py() == 42 + assert cany.to_py() == 42 + + @requires_py39 + def test_object_refcount_safety(self) -> None: + """to_py() for objects properly IncRefs — no double-free.""" + cany = A(tuple[int, ...]).convert([1, 2, 3]) + py1 = cany.to_py() + py2 = cany.to_py() + del cany # CAny.__dealloc__ runs + assert list(py1) == [1, 2, 3] + assert list(py2) == [1, 2, 3] + + def test_repr_int(self) -> None: + """Repr shows type and value for int.""" + cany = A(int).convert(42) + assert "int" in repr(cany) + assert "42" in repr(cany) + + def test_repr_none(self) -> None: + """Repr shows None.""" + cany = A(None).convert(None) + assert "None" in repr(cany) + + def test_repr_float(self) -> None: + """Repr shows float value.""" + cany = A(float).convert(3.14) + assert "float" in repr(cany) + + @requires_py39 + def test_repr_object(self) -> None: + """Repr shows type_index for objects.""" + cany = A(tuple[int, ...]).convert([1, 2, 3]) + assert "type_index" in repr(cany) + + def test_convert_raises_type_error(self) -> None: + """Convert still raises TypeError for incompatible values.""" + with pytest.raises(TypeError): + A(int).convert("hello") + + def test_check_value_does_not_return_cany(self) -> None: + """check_value returns None (not CAny).""" + result = A(int).check_value(42) + assert result is None + + +# --------------------------------------------------------------------------- +# from_annotation structural equality tests +# --------------------------------------------------------------------------- +class TestFromAnnotationScalars: + """Scalar types — from_annotation produces correct TypeSchema.""" + + def test_int(self) -> None: + """Int annotation.""" + assert A(int) == S("int") + + def test_float(self) -> None: + """Float annotation.""" + assert A(float) == S("float") + + def test_bool(self) -> None: + """Bool annotation.""" + assert A(bool) == S("bool") + + def test_str(self) -> None: + """Str annotation.""" + assert A(str) == S("str") + + def test_bytes(self) -> None: + """Bytes annotation.""" + assert A(bytes) == S("bytes") + + def test_none_type(self) -> None: + """type(None) annotation.""" + assert A(type(None)) == S("None") + + def test_none_literal(self) -> None: + """None annotation.""" + assert A(None) == S("None") + + def test_any(self) -> None: + """typing.Any annotation.""" + assert A(typing.Any) == S("Any") + + def test_tvm_ffi_string(self) -> None: + """tvm_ffi.String maps to str schema.""" + assert A(tvm_ffi.core.String) == S("str") + + def test_tvm_ffi_bytes(self) -> None: + """tvm_ffi.Bytes maps to bytes schema.""" + assert A(tvm_ffi.core.Bytes) == S("bytes") + + +class TestFromAnnotationFFITypes: + """FFI container and object types.""" + + def test_array(self) -> None: + """tvm_ffi.Array → canonical origin 'Array'.""" + assert A(tvm_ffi.Array) == S("Array") + + def test_list(self) -> None: + """tvm_ffi.List → same as A(list).""" + assert A(tvm_ffi.List) == A(list) + + def test_map(self) -> None: + """tvm_ffi.Map → canonical origin 'Map'.""" + assert A(tvm_ffi.Map) == S("Map") + + def test_dict(self) -> None: + """tvm_ffi.Dict → same as A(dict).""" + assert A(tvm_ffi.Dict) == A(dict) + + def test_function(self) -> None: + """tvm_ffi.Function → same as A(Callable).""" + assert A(tvm_ffi.core.Function) == A(Callable) + + def test_object(self) -> None: + """tvm_ffi.Object → canonical origin 'Object'.""" + assert A(tvm_ffi.core.Object) == S("Object") + + def test_tensor(self) -> None: + """tvm_ffi.Tensor → canonical origin 'Tensor'.""" + assert A(tvm_ffi.Tensor) == S("Tensor") + + def test_dtype(self) -> None: + """tvm_ffi.core.DataType → canonical origin 'dtype'.""" + assert A(tvm_ffi.core.DataType) == S("dtype") + + def test_device(self) -> None: + """tvm_ffi.Device → canonical origin 'Device'.""" + assert A(tvm_ffi.Device) == S("Device") + + def test_ctypes_c_void_p(self) -> None: + """ctypes.c_void_p → canonical origin 'ctypes.c_void_p'.""" + assert A(ctypes.c_void_p) == S("ctypes.c_void_p") + + @requires_py39 + def test_array_parameterized(self) -> None: + """tvm_ffi.Array[int] cross-equivalent to tuple[int, ...].""" + assert A(tvm_ffi.Array[int]) == A(tuple[int, ...]) + + @requires_py39 + def test_list_parameterized(self) -> None: + """tvm_ffi.List[str] cross-equivalent to list[str].""" + assert A(tvm_ffi.List[str]) == A(list[str]) + + @requires_py39 + def test_map_parameterized(self) -> None: + """tvm_ffi.Map[str, float].""" + assert A(tvm_ffi.Map[str, float]) == S("Map", S("str"), S("float")) + + @requires_py39 + def test_dict_parameterized(self) -> None: + """tvm_ffi.Dict[str, int] cross-equivalent to dict[str, int].""" + assert A(tvm_ffi.Dict[str, int]) == A(dict[str, int]) + + @requires_py39 + def test_array_too_many_args(self) -> None: + """tvm_ffi.Array[int, str] raises TypeError.""" + with pytest.raises(TypeError, match="requires 1"): + A(tvm_ffi.Array[int, str]) # type: ignore[type-arg] + + @requires_py39 + def test_list_too_many_args(self) -> None: + """tvm_ffi.List[int, str] raises TypeError.""" + with pytest.raises(TypeError, match="requires 1"): + A(tvm_ffi.List[int, str]) # type: ignore[type-arg] + + @requires_py39 + def test_dict_one_arg(self) -> None: + """tvm_ffi.Dict[str] raises TypeError.""" + with pytest.raises(TypeError, match="requires 2"): + A(tvm_ffi.Dict[str]) # type: ignore[type-arg] + + @requires_py39 + def test_dict_three_args(self) -> None: + """tvm_ffi.Dict[str, int, float] raises TypeError.""" + with pytest.raises(TypeError, match="requires 2"): + A(tvm_ffi.Dict[str, int, float]) # type: ignore[type-arg] + + @requires_py39 + def test_map_one_arg(self) -> None: + """tvm_ffi.Map[str] raises TypeError.""" + with pytest.raises(TypeError, match="requires 2"): + A(tvm_ffi.Map[str]) # type: ignore[type-arg] + + def test_unregistered_cobject_errors(self) -> None: + """Unregistered CObject subclass raises TypeError.""" + with pytest.raises(TypeError, match="not registered"): + A(tvm_ffi.core.CObject) + + +class TestFromAnnotationCallable: + """Callable annotation tests.""" + + def test_bare(self) -> None: + """Bare Callable.""" + assert A(Callable) == S("Callable") + + def test_bare_collections_abc(self) -> None: + """Bare collections.abc.Callable.""" + assert A(collections.abc.Callable) == S("Callable") + + def test_params(self) -> None: + """Callable[[int, str], bool].""" + assert A(Callable[[int, str], bool]) == S("Callable", S("bool"), S("int"), S("str")) + + def test_ellipsis(self) -> None: + """Callable[..., int].""" + assert A(Callable[..., int]) == S("Callable", S("int")) + + def test_no_params(self) -> None: + """Callable[[], int].""" + assert A(Callable[[], int]) == S("Callable", S("int")) + + +class TestFromAnnotationList: + """list[T] → List tests.""" + + def test_bare(self) -> None: + """Bare list.""" + assert A(list).origin == "List" + + @requires_py39 + def test_int(self) -> None: + """list[int].""" + assert A(list[int]) == S("List", S("int")) + + @requires_py39 + def test_nested(self) -> None: + """list[list[int]].""" + assert A(list[list[int]]) == S("List", S("List", S("int"))) + + +class TestFromAnnotationDict: + """dict[K, V] → Dict tests.""" + + def test_bare(self) -> None: + """Bare dict.""" + assert A(dict).origin == "Dict" + + @requires_py39 + def test_str_int(self) -> None: + """dict[str, int].""" + assert A(dict[str, int]) == S("Dict", S("str"), S("int")) + + +class TestFromAnnotationArray: + """tuple[T, ...] → Array tests.""" + + @requires_py39 + def test_int(self) -> None: + """tuple[int, ...].""" + assert A(tuple[int, ...]) == S("Array", S("int")) + + @requires_py39 + def test_float(self) -> None: + """tuple[float, ...].""" + assert A(tuple[float, ...]) == S("Array", S("float")) + + +class TestFromAnnotationTuple: + """tuple[T1, T2] (fixed) tests.""" + + def test_bare(self) -> None: + """Bare tuple.""" + assert A(tuple).origin == "tuple" + + @requires_py39 + def test_int_str(self) -> None: + """tuple[int, str].""" + assert A(tuple[int, str]) == S("tuple", S("int"), S("str")) + + @requires_py39 + def test_empty(self) -> None: + """tuple[()] stays distinct from bare tuple.""" + assert A(tuple[()]) == TypeSchema("tuple", ()) + + +class TestFromAnnotationOptional: + """Optional[T] tests.""" + + def test_int(self) -> None: + """Optional[int].""" + assert A(Optional[int]) == S("Optional", S("int")) + + def test_union_with_none_becomes_optional(self) -> None: + """Union[int, None] normalizes to Optional[int].""" + assert A(Union[int, None]) == S("Optional", S("int")) + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="X | Y requires 3.10+") + def test_pipe_syntax(self) -> None: + """Int | None.""" + assert A(eval("int | None")) == S("Optional", S("int")) + + +class TestFromAnnotationUnion: + """Union[T1, T2] tests.""" + + def test_int_str(self) -> None: + """Union[int, str].""" + assert A(Union[int, str]) == S("Union", S("int"), S("str")) + + def test_nested_union_flattening(self) -> None: + """Nested unions flatten to a single Union schema.""" + assert A(Union[int, Union[str, float]]) == S("Union", S("int"), S("str"), S("float")) + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="X | Y requires 3.10+") + def test_pipe_syntax(self) -> None: + """Int | str.""" + assert A(eval("int | str")) == S("Union", S("int"), S("str")) + + +class TestFromAnnotationObject: + """Registered CObject subclasses.""" + + def test_test_int_pair(self) -> None: + """TestIntPair annotation.""" + assert A(TestIntPair) == S("testing.TestIntPair") + + def test_cxx_class_base(self) -> None: + """_TestCxxClassBase annotation.""" + assert A(_TestCxxClassBase) == S("testing.TestCxxClassBase") + + +class TestFromAnnotationErrors: + """from_annotation raises TypeError for unsupported annotations.""" + + def test_unsupported_type(self) -> None: + """Complex is not supported.""" + with pytest.raises(TypeError, match="Cannot convert"): + A(complex) + + @requires_py39 + def test_list_too_many_args(self) -> None: + """list[int, int, float] raises.""" + with pytest.raises(TypeError, match="list takes at most 1"): + A(list[int, int, float]) # type: ignore[type-arg] + + @requires_py39 + def test_dict_one_arg(self) -> None: + """dict[str] raises.""" + with pytest.raises(TypeError, match="dict requires 0 or 2"): + A(dict[str]) # type: ignore[type-arg] + + +# --------------------------------------------------------------------------- +# Convert returns FFI containers +# --------------------------------------------------------------------------- +import tvm_ffi as _tvm_ffi + + +class TestConvertReturnFFIContainers: + """convert().to_py() returns ffi.Array/List/Map/Dict.""" + + @requires_py39 + def test_array_from_list(self) -> None: + """Array convert from Python list.""" + result = A(tuple[float, ...]).convert([1, 2, 3]).to_py() + assert isinstance(result, _tvm_ffi.Array) + assert list(result) == [1.0, 2.0, 3.0] + + @requires_py39 + def test_list_from_list(self) -> None: + """List convert from Python list.""" + result = A(list[int]).convert([1, 2, 3]).to_py() + assert isinstance(result, _tvm_ffi.List) + assert list(result) == [1, 2, 3] + + @requires_py39 + def test_dict_from_dict(self) -> None: + """Dict convert from Python dict.""" + result = A(dict[str, int]).convert({"a": 1}).to_py() + assert isinstance(result, _tvm_ffi.Dict) + + @requires_py39 + def test_map_from_dict(self) -> None: + """Map convert from Python dict.""" + result = A(tvm_ffi.Map[str, int]).convert({"a": 1}).to_py() + assert isinstance(result, _tvm_ffi.Map) + + @requires_py39 + def test_array_passthrough(self) -> None: + """ffi.Array input passes through unchanged.""" + arr = _tvm_ffi.Array([1, 2, 3]) + result = A(tuple[int, ...]).convert(arr).to_py() + assert result.same_as(arr) + + @requires_py39 + def test_list_passthrough(self) -> None: + """ffi.List input passes through unchanged.""" + lst = _tvm_ffi.List([1, 2, 3]) + result = A(list[int]).convert(lst).to_py() + assert result.same_as(lst) + + @requires_py39 + def test_array_subclass_passthrough(self) -> None: + """ffi.Array subclasses pass through unchanged.""" + + class MyArray(_tvm_ffi.Array): + pass + + arr = MyArray([1, 2, 3]) + result = A(tuple[int, ...]).convert(arr).to_py() + assert result.same_as(arr) + + @requires_py39 + def test_list_subclass_passthrough(self) -> None: + """ffi.List subclasses pass through unchanged.""" + + class MyList(_tvm_ffi.List): + pass + + lst = MyList([1, 2, 3]) + result = A(list[int]).convert(lst).to_py() + assert result.same_as(lst) + + @requires_py39 + def test_map_subclass_passthrough(self) -> None: + """ffi.Map subclasses pass through unchanged for Map[Any, Any].""" + + class MyMap(_tvm_ffi.Map): + pass + + m = MyMap({"a": 1}) + result = A(tvm_ffi.Map[typing.Any, typing.Any]).convert(m).to_py() + assert result.same_as(m) + + @requires_py39 + def test_dict_subclass_passthrough(self) -> None: + """ffi.Dict subclasses pass through unchanged for Dict[Any, Any].""" + + class MyDict(_tvm_ffi.Dict): + pass + + d = MyDict({"a": 1}) + result = A(dict[typing.Any, typing.Any]).convert(d).to_py() + assert result.same_as(d) + + @requires_py39 + def test_nested_array_convert(self) -> None: + """Nested array conversion.""" + result = A(tuple[tuple[int, ...], ...]).convert([[1, 2], [3, 4]]).to_py() + assert isinstance(result, _tvm_ffi.Array) + assert isinstance(result[0], _tvm_ffi.Array) + + +# --------------------------------------------------------------------------- +# FFI type guarantees: convert().to_py() always returns tvm_ffi types +# --------------------------------------------------------------------------- +class TestConvertToFFITypes: + """convert().to_py() returns canonical FFI types for all value kinds.""" + + def test_short_str_is_string(self) -> None: + """Short str (SmallStr) promotes to tvm_ffi.String.""" + result = A(str).convert("hi").to_py() + assert isinstance(result, tvm_ffi.core.String) + assert result == "hi" + + def test_long_str_is_string(self) -> None: + """Long str (kTVMFFIStr object) is tvm_ffi.String.""" + long_s = "x" * 200 + result = A(str).convert(long_s).to_py() + assert isinstance(result, tvm_ffi.core.String) + assert result == long_s + + def test_empty_str_is_string(self) -> None: + """Empty str is tvm_ffi.String.""" + result = A(str).convert("").to_py() + assert isinstance(result, tvm_ffi.core.String) + assert result == "" + + def test_short_bytes_is_bytes(self) -> None: + """Short bytes (SmallBytes) promotes to tvm_ffi.Bytes.""" + result = A(bytes).convert(b"hi").to_py() + assert isinstance(result, tvm_ffi.core.Bytes) + assert result == b"hi" + + def test_long_bytes_is_bytes(self) -> None: + """Long bytes (kTVMFFIBytes object) is tvm_ffi.Bytes.""" + long_b = b"x" * 200 + result = A(bytes).convert(long_b).to_py() + assert isinstance(result, tvm_ffi.core.Bytes) + assert result == long_b + + def test_empty_bytes_is_bytes(self) -> None: + """Empty bytes is tvm_ffi.Bytes.""" + result = A(bytes).convert(b"").to_py() + assert isinstance(result, tvm_ffi.core.Bytes) + assert result == b"" + + def test_bytearray_converts_to_ffi_bytes(self) -> None: + """Bytearray converts to tvm_ffi.Bytes.""" + result = A(bytes).convert(bytearray(b"hello")).to_py() + assert isinstance(result, tvm_ffi.core.Bytes) + assert result == b"hello" + + def test_callable_is_function(self) -> None: + """Callable converts to tvm_ffi.Function.""" + result = A(Callable).convert(lambda x: x).to_py() + assert isinstance(result, tvm_ffi.core.Function) + + @requires_py39 + def test_array_is_ffi_array(self) -> None: + """Array[int] converts to tvm_ffi.Array.""" + result = A(tuple[int, ...]).convert([1, 2]).to_py() + assert isinstance(result, _tvm_ffi.Array) + + @requires_py39 + def test_list_is_ffi_list(self) -> None: + """List[int] converts to tvm_ffi.List.""" + result = A(list[int]).convert([1, 2]).to_py() + assert isinstance(result, _tvm_ffi.List) + + @requires_py39 + def test_map_is_ffi_map(self) -> None: + """Map[str, int] converts to tvm_ffi.Map.""" + result = A(tvm_ffi.Map[str, int]).convert({"a": 1}).to_py() + assert isinstance(result, _tvm_ffi.Map) + + @requires_py39 + def test_dict_is_ffi_dict(self) -> None: + """Dict[str, int] converts to tvm_ffi.Dict.""" + result = A(dict[str, int]).convert({"a": 1}).to_py() + assert isinstance(result, _tvm_ffi.Dict) + + def test_int_is_int(self) -> None: + """Int stays as int.""" + result = A(int).convert(42).to_py() + assert type(result) is int + assert result == 42 + + def test_float_is_float(self) -> None: + """Float stays as float.""" + result = A(float).convert(3.14).to_py() + assert type(result) is float + assert result == 3.14 + + def test_bool_is_bool(self) -> None: + """Bool stays as bool.""" + result = A(bool).convert(True).to_py() + assert result is True + + def test_none_is_none(self) -> None: + """None stays as None.""" + result = A(None).convert(None).to_py() + assert result is None + + def test_object_is_cobject(self) -> None: + """Object converts to CObject subclass.""" + obj = TestIntPair(1, 2) + result = A(TestIntPair).convert(obj).to_py() + assert isinstance(result, tvm_ffi.core.CObject) + assert result.same_as(obj)