Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 63 additions & 12 deletions include/tvm/ffi/reflection/creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,72 @@
#ifndef TVM_FFI_REFLECTION_CREATOR_H_
#define TVM_FFI_REFLECTION_CREATOR_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/string.h>

namespace tvm {
namespace ffi {
/*!
* \brief Create an empty object via the type's native creator or ``__ffi_new__`` type attr.
*
* Falls back to the ``__ffi_new__`` type attribute (used by Python-defined types)
* when the native ``metadata->creator`` is NULL.
*
* \param type_info The type info for the object to create.
* \return An owned ObjectPtr to the newly allocated (zero-initialized) object.
* \throws RuntimeError if neither creator nor __ffi_new__ is available.
*/
inline ObjectPtr<Object> CreateEmptyObject(const TVMFFITypeInfo* type_info) {
// Fast path: native C++ creator
if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) {
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
return details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
}
// Fallback: __ffi_new__ type attr (Python-defined types)
constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
if (column != nullptr) {
int32_t offset = type_info->type_index - column->begin_index;
if (offset >= 0 && offset < column->size) {
AnyView attr_view = AnyView::CopyFromTVMFFIAny(column->data[offset]);
if (auto opt_func = attr_view.try_cast<Function>()) {
ObjectRef obj_ref = (*opt_func)().cast<ObjectRef>();
return details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(obj_ref));
}
}
}
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index)
<< "` does not support reflection creation"
<< " (no native creator or __ffi_new__ type attr)";
}

/*!
* \brief Check whether a type supports reflection creation.
*
* Returns true if the type has a native creator or a ``__ffi_new__`` type attr.
*
* \param type_info The type info to check.
* \return true if CreateEmptyObject would succeed.
*/
inline bool HasCreator(const TVMFFITypeInfo* type_info) {
if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) {
return true;
}
constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
if (column != nullptr) {
int32_t offset = type_info->type_index - column->begin_index;
if (offset >= 0 && offset < column->size &&
column->data[offset].type_index == TypeIndex::kTVMFFIFunction) {
return true;
}
}
return false;
}

namespace reflection {
/*!
* \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection.
Expand All @@ -48,13 +107,8 @@ class ObjectCreator {
* \param type_info The type info.
*/
explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) {
int32_t type_index = type_info->type_index;
if (type_info->metadata == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
<< "` does not have reflection registered";
}
if (type_info->metadata->creator == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
if (!HasCreator(type_info)) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index)
<< "` does not support default constructor, "
<< "as a result cannot be created via reflection";
}
Expand All @@ -66,10 +120,7 @@ class ObjectCreator {
* \return The created object.
*/
Any operator()(const Map<String, Any>& fields) const {
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle));
ObjectPtr<Object> ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
ObjectPtr<Object> ptr = CreateEmptyObject(type_info_);
size_t match_field_count = 0;
ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) {
String field_name(field_info->name);
Expand Down
20 changes: 7 additions & 13 deletions include/tvm/ffi/reflection/init.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/ffi/function_details.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/creator.h>
#include <tvm/ffi/string.h>

#include <algorithm>
Expand Down Expand Up @@ -69,10 +70,8 @@ inline Function MakeInit(int32_t type_index) {
};
// ---- Pre-compute field analysis (once per type) -------------------------
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
TVM_FFI_ICHECK(type_info->metadata != nullptr)
<< "Type `" << TypeIndexToTypeKey(type_index) << "` has no reflection metadata";
TVM_FFI_ICHECK(type_info->metadata->creator != nullptr)
<< "Type `" << TypeIndexToTypeKey(type_index) << "` has no creator";
TVM_FFI_ICHECK(HasCreator(type_info)) << "Type `" << TypeIndexToTypeKey(type_index)
<< "` has no creator or __ffi_new__ for __ffi_init__";

auto info = std::make_shared<AutoInitInfo>();
info->type_key = std::string_view(type_info->type_key.data, type_info->type_key.size);
Expand Down Expand Up @@ -101,16 +100,11 @@ inline Function MakeInit(int32_t type_index) {
// Eagerly resolve the KWARGS sentinel via global function registry.
ObjectRef kwargs_sentinel =
Function::GetGlobalRequired("ffi.GetKwargsObject")().cast<ObjectRef>();
// Cache pointers for the lambda (avoid repeated lookups).
TVMFFIObjectCreator creator = type_info->metadata->creator;

return Function::FromPacked(
[info, kwargs_sentinel, creator](PackedArgs args, Any* rv) {
// ---- 1. Create object via creator ------------------------------------
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(creator(&handle));
ObjectPtr<Object> obj_ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
[info, kwargs_sentinel, type_info](PackedArgs args, Any* rv) {
// ---- 1. Create object via CreateEmptyObject --------------------------
ObjectPtr<Object> obj_ptr = CreateEmptyObject(type_info);

// ---- 2. Find KWARGS sentinel position --------------------------------
int kwargs_pos = -1;
Expand Down Expand Up @@ -219,7 +213,7 @@ inline void RegisterAutoInit(int32_t type_index) {
info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod;
info.method = AnyView(auto_init_fn).CopyToTVMFFIAny();
static const std::string kMetadata =
"{\"type_schema\":" + std::string(details::TypeSchemaImpl<Function>::v()) +
"{\"type_schema\":" + std::string(::tvm::ffi::details::TypeSchemaImpl<Function>::v()) +
",\"auto_init\":true}";
info.metadata = TVMFFIByteArray{kMetadata.c_str(), kMetadata.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index, &info));
Expand Down
13 changes: 7 additions & 6 deletions include/tvm/ffi/reflection/overload.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,15 @@ class OverloadObjectDef : private ObjectDef<Class> {
template <typename Func>
static auto GetOverloadMethod(std::string name, Func&& func) {
using WrapFn = decltype(WrapFunction(std::forward<Func>(func)));
using OverloadFn = details::OverloadedFunction<std::decay_t<WrapFn>>;
using OverloadFn = ::tvm::ffi::details::OverloadedFunction<std::decay_t<WrapFn>>;
return ffi::Function::FromPackedInplace<OverloadFn>(WrapFunction(std::forward<Func>(func)),
std::move(name));
}

template <typename Func>
static auto NewOverload(std::string name, Func&& func) {
return details::CreateNewOverload(WrapFunction(std::forward<Func>(func)), std::move(name));
return ::tvm::ffi::details::CreateNewOverload(WrapFunction(std::forward<Func>(func)),
std::move(name));
}

template <typename... ExtraArgs>
Expand Down Expand Up @@ -452,7 +453,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
// initialize default value to nullptr
info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
info.doc = TVMFFIByteArray{nullptr, 0};
info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
info.metadata_.emplace_back("type_schema", ::tvm::ffi::details::TypeSchema<T>::v());
// apply field info traits
((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
// call register
Expand All @@ -464,7 +465,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
// register a method
template <typename Func, typename... Extra>
void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) {
using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
MethodInfoBuilder info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.doc = TVMFFIByteArray{nullptr, 0};
Expand All @@ -478,7 +479,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
// if an overload method exists, register to existing overload function
if (const auto overload_it = registered_fields_.find(name);
overload_it != registered_fields_.end()) {
details::OverloadBase* overload_ptr = overload_it->second;
::tvm::ffi::details::OverloadBase* overload_ptr = overload_it->second;
return overload_ptr->Register(NewOverload(std::move(method_name), std::forward<Func>(func)));
}

Expand All @@ -496,7 +497,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
}

std::unordered_map<std::string, details::OverloadBase*> registered_fields_;
std::unordered_map<std::string, ::tvm::ffi::details::OverloadBase*> registered_fields_;
};

} // namespace reflection
Expand Down
21 changes: 11 additions & 10 deletions include/tvm/ffi/reflection/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ template <typename Class, typename T>
TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) {
int64_t field_offset_to_class =
reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
return field_offset_to_class -
::tvm::ffi::details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
}

/// \cond Doxygen_Suppress
Expand All @@ -371,7 +372,7 @@ class ReflectionDefBase {
template <typename T>
static int FieldGetter(void* field, TVMFFIAny* result) {
TVM_FFI_SAFE_CALL_BEGIN();
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
*result = ::tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
TVM_FFI_SAFE_CALL_END();
}

Expand All @@ -390,15 +391,15 @@ class ReflectionDefBase {
static int ObjectCreatorDefault(TVMFFIObjectHandle* result) {
TVM_FFI_SAFE_CALL_BEGIN();
ObjectPtr<T> obj = make_object<T>();
*result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
*result = ::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
TVM_FFI_SAFE_CALL_END();
}

template <typename T>
static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) {
TVM_FFI_SAFE_CALL_BEGIN();
ObjectPtr<T> obj = make_object<T>(UnsafeInit{});
*result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
*result = ::tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
TVM_FFI_SAFE_CALL_END();
}

Expand Down Expand Up @@ -499,7 +500,7 @@ class GlobalDef : public ReflectionDefBase {
*/
template <typename Func, typename... Extra>
GlobalDef& def(const char* name, Func&& func, Extra&&... extra) {
using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
RegisterFunc(name, ffi::Function::FromTyped(std::forward<Func>(func), std::string(name)),
FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
return *this;
Expand All @@ -519,8 +520,8 @@ class GlobalDef : public ReflectionDefBase {
*/
template <typename Func, typename... Extra>
GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) {
RegisterFunc(name, ffi::Function::FromPacked(func), details::TypeSchemaImpl<Function>::v(),
std::forward<Extra>(extra)...);
RegisterFunc(name, ffi::Function::FromPacked(func),
::tvm::ffi::details::TypeSchemaImpl<Function>::v(), std::forward<Extra>(extra)...);
return *this;
}

Expand All @@ -540,7 +541,7 @@ class GlobalDef : public ReflectionDefBase {
*/
template <typename Func, typename... Extra>
GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) {
using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
RegisterFunc(name, GetMethod(std::string(name), std::forward<Func>(func)),
FuncInfo::TypeSchema(), std::forward<Extra>(extra)...);
return *this;
Expand Down Expand Up @@ -915,7 +916,7 @@ class ObjectDef : public ReflectionDefBase {
// initialize default value to nullptr
info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
info.doc = TVMFFIByteArray{nullptr, 0};
info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
info.metadata_.emplace_back("type_schema", ::tvm::ffi::details::TypeSchema<T>::v());
// apply field info traits
((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
// call register
Expand All @@ -927,7 +928,7 @@ class ObjectDef : public ReflectionDefBase {
// register a method
template <typename Func, typename... Extra>
void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) {
using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
MethodInfoBuilder info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.doc = TVMFFIByteArray{nullptr, 0};
Expand Down
10 changes: 1 addition & 9 deletions src/ffi/extra/reflection_extra.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,7 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {

TVM_FFI_ICHECK(args.size() % 2 == 1);
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);

if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
<< "` does not support reflection creation";
}
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
ObjectPtr<Object> ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
ObjectPtr<Object> ptr = CreateEmptyObject(type_info);

std::vector<String> keys;
std::vector<bool> keys_found;
Expand Down
10 changes: 1 addition & 9 deletions src/ffi/extra/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,15 +371,7 @@ class ObjectGraphDeserializer {
}
// otherwise, we go over the fields and create the data.
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
<< "` does not support default constructor"
<< ", so ToJSONGraph is not supported for this type";
}
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
ObjectPtr<Object> ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
ObjectPtr<Object> ptr = CreateEmptyObject(type_info);

auto decode_field_value = [&](const TVMFFIFieldInfo* field_info,
const json::Value& data) -> Any {
Expand Down