Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,15 @@ typedef enum {
* By default this flag is off (meaning the field accepts positional arguments).
*/
kTVMFFIFieldFlagBitMaskKwOnly = 1 << 10,
/*!
* \brief The setter field is a TVMFFIObjectHandle pointing to a FunctionObj.
*
* When this flag is set, the ``setter`` member of TVMFFIFieldInfo is not a
* TVMFFIFieldSetter function pointer but instead a TVMFFIObjectHandle
* pointing to a FunctionObj. The FunctionObj is called with two arguments:
* ``(field_addr_as_OpaquePtr, value_as_AnyView)``.
*/
kTVMFFIFieldFlagBitSetterIsFunctionObj = 1 << 11,
#ifdef __cplusplus
};
#else
Expand Down Expand Up @@ -1008,9 +1017,15 @@ typedef struct {
TVMFFIFieldGetter getter;
/*!
* \brief The setter to access the field.
*
* When kTVMFFIFieldFlagBitSetterIsFunctionObj is NOT set (default),
* this is a TVMFFIFieldSetter function pointer (cast to void*).
* When kTVMFFIFieldFlagBitSetterIsFunctionObj IS set,
* this is a TVMFFIObjectHandle pointing to a FunctionObj.
*
* \note The setter is set even if the field is readonly for serialization.
*/
TVMFFIFieldSetter setter;
void* setter;
/*!
* \brief The default value or default factory of the field.
*
Expand Down
60 changes: 60 additions & 0 deletions include/tvm/ffi/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,66 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) {
#else
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString)
#endif

/*!
* \brief Create an empty object via the type's native creator or ``__ffi_new__`` type attr.
*
* Falls back to the ``__ffi_new__`` type attribute (used by Python-defined types)
* when the native ``metadata->creator`` is NULL.
*
* \param type_info The type info for the object to create.
* \return An owned ObjectPtr to the newly allocated (zero-initialized) object.
* \throws RuntimeError if neither creator nor __ffi_new__ is available.
*/
inline ObjectPtr<Object> CreateEmptyObject(const TVMFFITypeInfo* type_info) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should go into reflection/Creator

// Fast path: native C++ creator
if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) {
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
return details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
}
// Fallback: __ffi_new__ type attr (Python-defined types)
constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
if (column != nullptr) {
int32_t offset = type_info->type_index - column->begin_index;
if (offset >= 0 && offset < column->size) {
AnyView attr_view = AnyView::CopyFromTVMFFIAny(column->data[offset]);
if (auto opt_func = attr_view.try_cast<Function>()) {
ObjectRef obj_ref = (*opt_func)().cast<ObjectRef>();
return details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(obj_ref));
}
}
}
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index)
<< "` does not support reflection creation"
<< " (no native creator or __ffi_new__ type attr)";
}

/*!
* \brief Check whether a type supports reflection creation.
*
* Returns true if the type has a native creator or a ``__ffi_new__`` type attr.
*
* \param type_info The type info to check.
* \return true if CreateEmptyObject would succeed.
*/
inline bool HasCreator(const TVMFFITypeInfo* type_info) {
if (type_info->metadata != nullptr && type_info->metadata->creator != nullptr) {
return true;
}
constexpr TVMFFIByteArray kFFINewAttrName = {"__ffi_new__", 11};
const TVMFFITypeAttrColumn* column = TVMFFIGetTypeAttrColumn(&kFFINewAttrName);
if (column != nullptr) {
int32_t offset = type_info->type_index - column->begin_index;
if (offset >= 0 && offset < column->size &&
column->data[offset].type_index >= kTVMFFIStaticObjectBegin) {
return true;
}
}
return false;
}

} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_FUNCTION_H_
39 changes: 35 additions & 4 deletions include/tvm/ffi/reflection/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,37 @@ inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const char
TVM_FFI_UNREACHABLE();
}

/*!
* \brief Call the field setter, dispatching between function pointer and FunctionObj.
*
* When kTVMFFIFieldFlagBitSetterIsFunctionObj is off, invokes the setter as a
* TVMFFIFieldSetter function pointer. When on, calls via TVMFFIFunctionCall
* with arguments (field_addr as OpaquePtr, value).
*
* \param field_info The field info containing the setter.
* \param field_addr The address of the field in the object.
* \param value The value to set (as a TVMFFIAny pointer).
* \return 0 on success, nonzero on failure.
*/
inline int CallFieldSetter(const TVMFFIFieldInfo* field_info, void* field_addr,
const TVMFFIAny* value) {
if (!(field_info->flags & kTVMFFIFieldFlagBitSetterIsFunctionObj)) {
auto setter = reinterpret_cast<TVMFFIFieldSetter>(field_info->setter);
return setter(field_addr, value);
} else {
TVMFFIAny args[2];
args[0].type_index = kTVMFFIOpaquePtr;
args[0].zero_padding = 0;
args[0].v_ptr = field_addr;
args[1] = *value;
TVMFFIAny result;
result.type_index = kTVMFFINone;
result.v_int64 = 0;
return TVMFFIFunctionCall(static_cast<TVMFFIObjectHandle>(field_info->setter), args, 2,
&result);
}
}

/*!
* \brief helper wrapper class to obtain a getter.
*/
Expand Down Expand Up @@ -118,8 +149,8 @@ class FieldSetter {
*/
void operator()(const Object* obj_ptr, AnyView value) const {
const void* addr = reinterpret_cast<const char*>(obj_ptr) + field_info_->offset;
TVM_FFI_CHECK_SAFE_CALL(
field_info_->setter(const_cast<void*>(addr), reinterpret_cast<const TVMFFIAny*>(&value)));
TVM_FFI_CHECK_SAFE_CALL(CallFieldSetter(field_info_, const_cast<void*>(addr),
reinterpret_cast<const TVMFFIAny*>(&value)));
}

void operator()(const ObjectPtr<Object>& obj_ptr, AnyView value) const {
Expand Down Expand Up @@ -215,9 +246,9 @@ inline void SetFieldToDefault(const TVMFFIFieldInfo* field_info, void* field_add
Function factory =
AnyView::CopyFromTVMFFIAny(field_info->default_value_or_factory).cast<Function>();
Any default_val = factory();
field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&default_val));
CallFieldSetter(field_info, field_addr, reinterpret_cast<const TVMFFIAny*>(&default_val));
} else {
field_info->setter(field_addr, &(field_info->default_value_or_factory));
CallFieldSetter(field_info, field_addr, &(field_info->default_value_or_factory));
}
}

Expand Down
16 changes: 4 additions & 12 deletions include/tvm/ffi/reflection/creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,8 @@ class ObjectCreator {
* \param type_info The type info.
*/
explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) {
int32_t type_index = type_info->type_index;
if (type_info->metadata == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
<< "` does not have reflection registered";
}
if (type_info->metadata->creator == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
if (!HasCreator(type_info)) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_info->type_index)
<< "` does not support default constructor, "
<< "as a result cannot be created via reflection";
}
Expand All @@ -66,17 +61,14 @@ class ObjectCreator {
* \return The created object.
*/
Any operator()(const Map<String, Any>& fields) const {
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle));
ObjectPtr<Object> ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
ObjectPtr<Object> ptr = CreateEmptyObject(type_info_);
size_t match_field_count = 0;
ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) {
String field_name(field_info->name);
void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
if (fields.count(field_name) != 0) {
Any field_value = fields[field_name];
field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
CallFieldSetter(field_info, field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
++match_field_count;
} else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
SetFieldToDefault(field_info, field_addr);
Expand Down
36 changes: 22 additions & 14 deletions include/tvm/ffi/reflection/init.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/function_details.h>
#include <tvm/ffi/object.h>
Expand All @@ -42,6 +43,20 @@ namespace tvm {
namespace ffi {
namespace reflection {

namespace details {

template <typename TObjectRef>
TObjectRef CastFromAny(AnyView input) {
TVMFFIAny input_pod = input.CopyToTVMFFIAny();
if (auto opt = TypeTraits<TObjectRef>::TryCastFromAnyView(&input_pod)) {
return *std::move(opt);
}
TVM_FFI_THROW(TypeError) << "Cannot cast from `" << TypeIndexToTypeKey(input_pod.type_index)
<< "` to `" << TypeTraits<TObjectRef>::TypeStr() << "`";
}

} // namespace details

/*!
* \brief Create a packed ``__ffi_init__`` constructor for the given type.
*
Expand Down Expand Up @@ -69,10 +84,8 @@ inline Function MakeInit(int32_t type_index) {
};
// ---- Pre-compute field analysis (once per type) -------------------------
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
TVM_FFI_ICHECK(type_info->metadata != nullptr)
<< "Type `" << TypeIndexToTypeKey(type_index) << "` has no reflection metadata";
TVM_FFI_ICHECK(type_info->metadata->creator != nullptr)
<< "Type `" << TypeIndexToTypeKey(type_index) << "` has no creator";
TVM_FFI_ICHECK(HasCreator(type_info)) << "Type `" << TypeIndexToTypeKey(type_index)
<< "` has no creator or __ffi_new__ for __ffi_init__";

auto info = std::make_shared<AutoInitInfo>();
info->type_key = std::string_view(type_info->type_key.data, type_info->type_key.size);
Expand Down Expand Up @@ -101,16 +114,11 @@ inline Function MakeInit(int32_t type_index) {
// Eagerly resolve the KWARGS sentinel via global function registry.
ObjectRef kwargs_sentinel =
Function::GetGlobalRequired("ffi.GetKwargsObject")().cast<ObjectRef>();
// Cache pointers for the lambda (avoid repeated lookups).
TVMFFIObjectCreator creator = type_info->metadata->creator;

return Function::FromPacked(
[info, kwargs_sentinel, creator](PackedArgs args, Any* rv) {
// ---- 1. Create object via creator ------------------------------------
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(creator(&handle));
ObjectPtr<Object> obj_ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
[info, kwargs_sentinel, type_info](PackedArgs args, Any* rv) {
// ---- 1. Create object via CreateEmptyObject --------------------------
ObjectPtr<Object> obj_ptr = CreateEmptyObject(type_info);

// ---- 2. Find KWARGS sentinel position --------------------------------
int kwargs_pos = -1;
Expand All @@ -128,7 +136,7 @@ inline Function MakeInit(int32_t type_index) {

auto set_field = [&](size_t fi, const TVMFFIAny* value) {
void* addr = reinterpret_cast<char*>(obj_ptr.get()) + info->all_fields[fi].info->offset;
TVM_FFI_CHECK_SAFE_CALL(info->all_fields[fi].info->setter(addr, value));
TVM_FFI_CHECK_SAFE_CALL(CallFieldSetter(info->all_fields[fi].info, addr, value));
field_set[fi] = true;
};

Expand Down Expand Up @@ -219,7 +227,7 @@ inline void RegisterAutoInit(int32_t type_index) {
info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod;
info.method = AnyView(auto_init_fn).CopyToTVMFFIAny();
static const std::string kMetadata =
"{\"type_schema\":" + std::string(details::TypeSchemaImpl<Function>::v()) +
"{\"type_schema\":" + std::string(::tvm::ffi::details::TypeSchemaImpl<Function>::v()) +
",\"auto_init\":true}";
info.metadata = TVMFFIByteArray{kMetadata.c_str(), kMetadata.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index, &info));
Expand Down
15 changes: 8 additions & 7 deletions include/tvm/ffi/reflection/overload.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,15 @@ class OverloadObjectDef : private ObjectDef<Class> {
template <typename Func>
static auto GetOverloadMethod(std::string name, Func&& func) {
using WrapFn = decltype(WrapFunction(std::forward<Func>(func)));
using OverloadFn = details::OverloadedFunction<std::decay_t<WrapFn>>;
using OverloadFn = ::tvm::ffi::details::OverloadedFunction<std::decay_t<WrapFn>>;
return ffi::Function::FromPackedInplace<OverloadFn>(WrapFunction(std::forward<Func>(func)),
std::move(name));
}

template <typename Func>
static auto NewOverload(std::string name, Func&& func) {
return details::CreateNewOverload(WrapFunction(std::forward<Func>(func)), std::move(name));
return ::tvm::ffi::details::CreateNewOverload(WrapFunction(std::forward<Func>(func)),
std::move(name));
}

template <typename... ExtraArgs>
Expand Down Expand Up @@ -448,11 +449,11 @@ class OverloadObjectDef : private ObjectDef<Class> {
info.flags |= kTVMFFIFieldFlagBitMaskWritable;
}
info.getter = ReflectionDefBase::FieldGetter<T>;
info.setter = ReflectionDefBase::FieldSetter<T>;
info.setter = reinterpret_cast<void*>(ReflectionDefBase::FieldSetter<T>);
// initialize default value to nullptr
info.default_value_or_factory = AnyView(nullptr).CopyToTVMFFIAny();
info.doc = TVMFFIByteArray{nullptr, 0};
info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
info.metadata_.emplace_back("type_schema", ::tvm::ffi::details::TypeSchema<T>::v());
// apply field info traits
((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
// call register
Expand All @@ -464,7 +465,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
// register a method
template <typename Func, typename... Extra>
void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) {
using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
using FuncInfo = ::tvm::ffi::details::FunctionInfo<std::decay_t<Func>>;
MethodInfoBuilder info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.doc = TVMFFIByteArray{nullptr, 0};
Expand All @@ -478,7 +479,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
// if an overload method exists, register to existing overload function
if (const auto overload_it = registered_fields_.find(name);
overload_it != registered_fields_.end()) {
details::OverloadBase* overload_ptr = overload_it->second;
::tvm::ffi::details::OverloadBase* overload_ptr = overload_it->second;
return overload_ptr->Register(NewOverload(std::move(method_name), std::forward<Func>(func)));
}

Expand All @@ -496,7 +497,7 @@ class OverloadObjectDef : private ObjectDef<Class> {
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
}

std::unordered_map<std::string, details::OverloadBase*> registered_fields_;
std::unordered_map<std::string, ::tvm::ffi::details::OverloadBase*> registered_fields_;
};

} // namespace reflection
Expand Down
Loading