From 8c2dd20184dd355cb997864c670d5fa037bb2316 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 10 Mar 2026 10:43:26 -0700 Subject: [PATCH] feat(ffi)\!: centralize object creation with CreateEmptyObject/HasCreator Add `CreateEmptyObject()` and `HasCreator()` in `function.h` that unify the two-step pattern (check creator, call creator) used across reflection code. These functions first try the native `metadata->creator` fast path, then fall back to the `__ffi_new__` type attribute for Python-defined types. Deduplicate four call sites in `creator.h`, `init.h`, `reflection_extra.cc`, and `serialization.cc` to use the new centralized helpers. Also qualify bare `details::` references to `::tvm::ffi::details::` in reflection headers (`overload.h`, `registry.h`, `init.h`) to prevent lookup ambiguity when included from other namespaces. --- include/tvm/ffi/reflection/creator.h | 75 ++++++++++++++++++++++----- include/tvm/ffi/reflection/init.h | 20 +++---- include/tvm/ffi/reflection/overload.h | 13 ++--- include/tvm/ffi/reflection/registry.h | 21 ++++---- src/ffi/extra/reflection_extra.cc | 10 +--- src/ffi/extra/serialization.cc | 10 +--- 6 files changed, 90 insertions(+), 59 deletions(-) diff --git a/include/tvm/ffi/reflection/creator.h b/include/tvm/ffi/reflection/creator.h index 300ad512d..977e6dedd 100644 --- a/include/tvm/ffi/reflection/creator.h +++ b/include/tvm/ffi/reflection/creator.h @@ -23,13 +23,72 @@ #ifndef TVM_FFI_REFLECTION_CREATOR_H_ #define TVM_FFI_REFLECTION_CREATOR_H_ -#include #include +#include #include #include namespace tvm { namespace ffi { +/*! + * \brief Create an empty object via the type's native creator or ``__ffi_new__`` type attr. + * + * Falls back to the ``__ffi_new__`` type attribute (used by Python-defined types) + * when the native ``metadata->creator`` is NULL. + * + * \param type_info The type info for the object to create. + * \return An owned ObjectPtr to the newly allocated (zero-initialized) object. + * \throws RuntimeError if neither creator nor __ffi_new__ is available. + */ +inline ObjectPtr CreateEmptyObject(const TVMFFITypeInfo* type_info) { + // Fast path: native C++ creator + if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) { + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); + return details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + } + // Fallback: __ffi_new__ type attr (Python-defined types) + constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11}; + const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName); + if (column != nullptr) { + int32_t offset = type_info->type_index - column->begin_index; + if (offset >= 0 && offset < column->size) { + AnyView attr_view = AnyView::CopyFromTVMFFIAny(column->data[offset]); + if (auto opt_func = attr_view.try_cast()) { + ObjectRef obj_ref = (*opt_func)().cast(); + return details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(obj_ref)); + } + } + } + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index) + << "` does not support reflection creation" + << " (no native creator or __ffi_new__ type attr)"; +} + +/*! + * \brief Check whether a type supports reflection creation. + * + * Returns true if the type has a native creator or a ``__ffi_new__`` type attr. + * + * \param type_info The type info to check. + * \return true if CreateEmptyObject would succeed. + */ +inline bool HasCreator(const TVMFFITypeInfo* type_info) { + if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) { + return true; + } + constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11}; + const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName); + if (column != nullptr) { + int32_t offset = type_info->type_index - column->begin_index; + if (offset >= 0 && offset < column->size && + column->data[offset].type_index == TypeIndex::kTVMFFIFunction) { + return true; + } + } + return false; +} + namespace reflection { /*! * \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection. @@ -48,13 +107,8 @@ class ObjectCreator { * \param type_info The type info. */ explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) { - int32_t type_index = type_info->type_index; - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not have reflection registered"; - } - if (type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + if (!HasCreator(type_info)) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index) << "` does not support default constructor, " << "as a result cannot be created via reflection"; } @@ -66,10 +120,7 @@ class ObjectCreator { * \return The created object. */ Any operator()(const Map& fields) const { - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + ObjectPtr ptr = CreateEmptyObject(type_info_); size_t match_field_count = 0; ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) { String field_name(field_info->name); diff --git a/include/tvm/ffi/reflection/init.h b/include/tvm/ffi/reflection/init.h index 337753cea..8853f054d 100644 --- a/include/tvm/ffi/reflection/init.h +++ b/include/tvm/ffi/reflection/init.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -69,10 +70,8 @@ inline Function MakeInit(int32_t type_index) { }; // ---- Pre-compute field analysis (once per type) ------------------------- const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - TVM_FFI_ICHECK(type_info->metadata != nullptr) - << "Type `" << TypeIndexToTypeKey(type_index) << "` has no reflection metadata"; - TVM_FFI_ICHECK(type_info->metadata->creator != nullptr) - << "Type `" << TypeIndexToTypeKey(type_index) << "` has no creator"; + TVM_FFI_ICHECK(HasCreator(type_info)) << "Type `" << TypeIndexToTypeKey(type_index) + << "` has no creator or __ffi_new__ for __ffi_init__"; auto info = std::make_shared(); info->type_key = std::string_view(type_info->type_key.data, type_info->type_key.size); @@ -101,16 +100,11 @@ inline Function MakeInit(int32_t type_index) { // Eagerly resolve the KWARGS sentinel via global function registry. ObjectRef kwargs_sentinel = Function::GetGlobalRequired("ffi.GetKwargsObject")().cast(); - // Cache pointers for the lambda (avoid repeated lookups). - TVMFFIObjectCreator creator = type_info->metadata->creator; return Function::FromPacked( - [info, kwargs_sentinel, creator](PackedArgs args, Any* rv) { - // ---- 1. Create object via creator ------------------------------------ - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(creator(&handle)); - ObjectPtr obj_ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + [info, kwargs_sentinel, type_info](PackedArgs args, Any* rv) { + // ---- 1. Create object via CreateEmptyObject -------------------------- + ObjectPtr obj_ptr = CreateEmptyObject(type_info); // ---- 2. Find KWARGS sentinel position -------------------------------- int kwargs_pos = -1; @@ -219,7 +213,7 @@ inline void RegisterAutoInit(int32_t type_index) { info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod; info.method = AnyView(auto_init_fn).CopyToTVMFFIAny(); static const std::string kMetadata = - "{\"type_schema\":" + std::string(details::TypeSchemaImpl::v()) + + "{\"type_schema\":" + std::string(::tvm::ffi::details::TypeSchemaImpl::v()) + ",\"auto_init\":true}"; info.metadata = TVMFFIByteArray{kMetadata.c_str(), kMetadata.size()}; TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index, &info)); diff --git a/include/tvm/ffi/reflection/overload.h b/include/tvm/ffi/reflection/overload.h index 8647c3c64..80e9bce90 100644 --- a/include/tvm/ffi/reflection/overload.h +++ b/include/tvm/ffi/reflection/overload.h @@ -404,14 +404,15 @@ class OverloadObjectDef : private ObjectDef { template static auto GetOverloadMethod(std::string name, Func&& func) { using WrapFn = decltype(WrapFunction(std::forward(func))); - using OverloadFn = details::OverloadedFunction>; + using OverloadFn = ::tvm::ffi::details::OverloadedFunction>; return ffi::Function::FromPackedInplace(WrapFunction(std::forward(func)), std::move(name)); } template static auto NewOverload(std::string name, Func&& func) { - return details::CreateNewOverload(WrapFunction(std::forward(func)), std::move(name)); + return ::tvm::ffi::details::CreateNewOverload(WrapFunction(std::forward(func)), + std::move(name)); } template @@ -452,7 +453,7 @@ class OverloadObjectDef : private ObjectDef { // initialize default value to nullptr info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny(); info.doc = TVMFFIByteArray{nullptr, 0}; - info.metadata_.emplace_back("type_schema", details::TypeSchema::v()); + info.metadata_.emplace_back("type_schema", ::tvm::ffi::details::TypeSchema::v()); // apply field info traits ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); // call register @@ -464,7 +465,7 @@ class OverloadObjectDef : private ObjectDef { // register a method template void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; + using FuncInfo = ::tvm::ffi::details::FunctionInfo>; MethodInfoBuilder info; info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; info.doc = TVMFFIByteArray{nullptr, 0}; @@ -478,7 +479,7 @@ class OverloadObjectDef : private ObjectDef { // if an overload method exists, register to existing overload function if (const auto overload_it = registered_fields_.find(name); overload_it != registered_fields_.end()) { - details::OverloadBase* overload_ptr = overload_it->second; + ::tvm::ffi::details::OverloadBase* overload_ptr = overload_it->second; return overload_ptr->Register(NewOverload(std::move(method_name), std::forward(func))); } @@ -496,7 +497,7 @@ class OverloadObjectDef : private ObjectDef { TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); } - std::unordered_map registered_fields_; + std::unordered_map registered_fields_; }; } // namespace reflection diff --git a/include/tvm/ffi/reflection/registry.h b/include/tvm/ffi/reflection/registry.h index 719d61219..17c0078b9 100644 --- a/include/tvm/ffi/reflection/registry.h +++ b/include/tvm/ffi/reflection/registry.h @@ -362,7 +362,8 @@ template TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) { int64_t field_offset_to_class = reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); - return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); + return field_offset_to_class - + ::tvm::ffi::details::ObjectUnsafe::GetObjectOffsetToSubclass(); } /// \cond Doxygen_Suppress @@ -371,7 +372,7 @@ class ReflectionDefBase { template static int FieldGetter(void* field, TVMFFIAny* result) { TVM_FFI_SAFE_CALL_BEGIN(); - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); + *result = ::tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); TVM_FFI_SAFE_CALL_END(); } @@ -390,7 +391,7 @@ class ReflectionDefBase { static int ObjectCreatorDefault(TVMFFIObjectHandle* result) { TVM_FFI_SAFE_CALL_BEGIN(); ObjectPtr obj = make_object(); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + *result = ::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); TVM_FFI_SAFE_CALL_END(); } @@ -398,7 +399,7 @@ class ReflectionDefBase { static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) { TVM_FFI_SAFE_CALL_BEGIN(); ObjectPtr obj = make_object(UnsafeInit{}); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + *result = ::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); TVM_FFI_SAFE_CALL_END(); } @@ -499,7 +500,7 @@ class GlobalDef : public ReflectionDefBase { */ template GlobalDef& def(const char* name, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; + using FuncInfo = ::tvm::ffi::details::FunctionInfo>; RegisterFunc(name, ffi::Function::FromTyped(std::forward(func), std::string(name)), FuncInfo::TypeSchema(), std::forward(extra)...); return *this; @@ -519,8 +520,8 @@ class GlobalDef : public ReflectionDefBase { */ template GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) { - RegisterFunc(name, ffi::Function::FromPacked(func), details::TypeSchemaImpl::v(), - std::forward(extra)...); + RegisterFunc(name, ffi::Function::FromPacked(func), + ::tvm::ffi::details::TypeSchemaImpl::v(), std::forward(extra)...); return *this; } @@ -540,7 +541,7 @@ class GlobalDef : public ReflectionDefBase { */ template GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; + using FuncInfo = ::tvm::ffi::details::FunctionInfo>; RegisterFunc(name, GetMethod(std::string(name), std::forward(func)), FuncInfo::TypeSchema(), std::forward(extra)...); return *this; @@ -915,7 +916,7 @@ class ObjectDef : public ReflectionDefBase { // initialize default value to nullptr info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny(); info.doc = TVMFFIByteArray{nullptr, 0}; - info.metadata_.emplace_back("type_schema", details::TypeSchema::v()); + info.metadata_.emplace_back("type_schema", ::tvm::ffi::details::TypeSchema::v()); // apply field info traits ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); // call register @@ -927,7 +928,7 @@ class ObjectDef : public ReflectionDefBase { // register a method template void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; + using FuncInfo = ::tvm::ffi::details::FunctionInfo>; MethodInfoBuilder info; info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; info.doc = TVMFFIByteArray{nullptr, 0}; diff --git a/src/ffi/extra/reflection_extra.cc b/src/ffi/extra/reflection_extra.cc index b5ced5c2f..44e8ac3c5 100644 --- a/src/ffi/extra/reflection_extra.cc +++ b/src/ffi/extra/reflection_extra.cc @@ -42,15 +42,7 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { TVM_FFI_ICHECK(args.size() % 2 == 1); const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - - if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support reflection creation"; - } - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + ObjectPtr ptr = CreateEmptyObject(type_info); std::vector keys; std::vector keys_found; diff --git a/src/ffi/extra/serialization.cc b/src/ffi/extra/serialization.cc index 80b96ec7f..c1fb62114 100644 --- a/src/ffi/extra/serialization.cc +++ b/src/ffi/extra/serialization.cc @@ -371,15 +371,7 @@ class ObjectGraphDeserializer { } // otherwise, we go over the fields and create the data. const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support default constructor" - << ", so ToJSONGraph is not supported for this type"; - } - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + ObjectPtr ptr = CreateEmptyObject(type_info); auto decode_field_value = [&](const TVMFFIFieldInfo* field_info, const json::Value& data) -> Any {