diff --git a/docs/index.rst b/docs/index.rst index 62768c74..79da2e87 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -71,6 +71,7 @@ Table of Contents packaging/python_packaging.rst packaging/stubgen.rst + packaging/rust_stubgen.md packaging/cpp_tooling.rst .. toctree:: diff --git a/docs/packaging/rust_stubgen.md b/docs/packaging/rust_stubgen.md new file mode 100644 index 00000000..039488e9 --- /dev/null +++ b/docs/packaging/rust_stubgen.md @@ -0,0 +1,142 @@ + + + + + + + + + + + + + + + + + +# Rust Stubgen Guide + +```{note} +The Rust stub generation flow is currently experimental and may evolve. +``` + +This guide covers practical usage of `tvm-ffi-stubgen`: generation command, output crate, and how to call generated APIs. + +## Generate a Stub Crate + +Run from `3rdparty/tvm/3rdparty/tvm-ffi/rust`: + +```bash +cargo run -p tvm-ffi-stubgen -- \ + --init-prefix testing \ + --init-crate tvm-ffi-testing \ + --dlls /abs/path/to/libtvm_ffi_testing.so \ + --overwrite +``` + +### Arguments + +- `OUT_DIR`: positional output directory +- `--dlls`: one or more dynamic libraries for reflection metadata (`;`-separated) +- `--init-prefix`: registry prefix filter (repeatable; see multi-prefix below) +- `--init-crate`: generated crate name +- `--tvm-ffi-path`: optional local path override for `tvm-ffi` +- `--overwrite`: overwrite non-empty output directory +- `--no-format`: skip the post-generation `cargo fmt` pass + +By default, `tvm-ffi-stubgen` runs `cargo fmt` on the generated crate after emitting +`Cargo.toml`, `build.rs`, and Rust sources. Use `--no-format` only when you need to inspect +the raw generated text before formatting or when debugging generator output itself. + +### Multi-Prefix Mode + +`--init-prefix` can be specified multiple times to generate a single crate covering +several namespaces: + +```bash +cargo run -p tvm-ffi-stubgen -- \ + --dlls "libtilelang_module.so;libtvm.so" \ + --init-prefix tl --init-prefix ir --init-prefix tir --init-prefix script \ + --init-crate tilelang-ffi \ + --overwrite +``` + +With a single prefix the prefix is stripped and items land at the crate root. +With multiple prefixes no stripping occurs; each prefix becomes a top-level module +(`crate::tl::*`, `crate::ir::*`, etc.). + +## Generated Output Layout + +The output is a standalone Rust crate: + +- `Cargo.toml` +- `src/lib.rs` +- `src/_tvm_ffi_stubgen_detail/functions.rs` +- `src/_tvm_ffi_stubgen_detail/types.rs` + +`src/lib.rs` re-exports generated wrappers and provides: + +```rust +pub fn load_library(path: &str) -> tvm_ffi::Result +``` + +## Using Generated Crate + +Using the generated stubs is straightforward—simply load the runtime library, call exported functions, and work with generated object wrappers and subtyping as needed. The full process is shown in the following example, covering typical usage: + +```rust +use tvm_ffi_testing as stub; + +fn main() -> tvm_ffi::Result<()> { + // Load FFI library (required before any calls) + stub::load_library("/abs/path/to/libtvm_ffi_testing.so")?; + + // Call a generated function with typed arguments + let y = stub::add_one(1)?; + assert_eq!(y, 2); + + // Call a function via packed interface for dynamic signature + let _out = stub::echo(&[tvm_ffi::Any::from(1_i64)])?; + + // Object constructor/method wrappers are resolved from type metadata. + let pair_obj = stub::TestIntPair::new(3, 4)?; + let pair: stub::TestIntPair = pair_obj + .try_into() + .map_err(|_| tvm_ffi::Error::new(tvm_ffi::TYPE_ERROR, "downcast failed", ""))?; + let sum_any = pair.sum(&[])?; + let sum: i64 = sum_any.try_into()?; + assert_eq!(sum, 7); + + // Cxx inheritance sample: construct derived, view as base, then convert back. + let derived_obj = stub::TestCxxClassDerived::new(11, 7, 3.5, 1.25)?; + let base: stub::TestCxxClassBase = derived_obj.clone().into(); + let base_obj: tvm_ffi::object::ObjectRef = base.clone().into(); + let derived_again: stub::TestCxxClassDerived = base_obj.into(); + assert_eq!(base.v_i64()?, 11); + assert_eq!(base.v_i32()?, 7); + assert!((derived_again.v_f64()? - 3.5).abs() < 1e-9); + assert!((derived_again.v_f32()? - 1.25).abs() < 1e-6); + + // Use object-returning wrappers and ObjectRef-based APIs + let obj = stub::make_unregistered_object()?; + let count = stub::object_use_count(obj.clone())?; + assert!(count >= 1); + + // Fallback wrapper can be built from ObjectRef directly + let _wrapped: stub::TestUnregisteredObject = obj.into(); + + Ok(()) +} +``` + +- Load the library once before using the APIs. +- Generated functions support typed signatures when possible and fall back to `Any` for dynamic calling. +- Generated object method wrappers (including constructor `new`) are resolved via type metadata rather than global function lookup. +- Generated object-returning wrappers integrate with `ObjectRef` APIs and wrapper conversions. + + +## Related Docs + +- Rust language guide: `guides/rust_lang_guide.md` +- Rust stubgen design details (implementation-oriented): `rust/tvm-ffi-stubgen/README.md` diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml new file mode 100644 index 00000000..15df9a7c --- /dev/null +++ b/rust/.rustfmt.toml @@ -0,0 +1,3 @@ +edition = "2024" +max_width = 100 +use_small_heuristics = "Default" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index d9243768..78899c45 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -16,6 +16,6 @@ # under the License. [workspace] -members = ["tvm-ffi", "tvm-ffi-sys", "tvm-ffi-macros"] +members = ["tvm-ffi", "tvm-ffi-sys", "tvm-ffi-macros", "tvm-ffi-stubgen"] resolver = "2" diff --git a/rust/rust-toolchain.toml b/rust/rust-toolchain.toml new file mode 100644 index 00000000..73cb934d --- /dev/null +++ b/rust/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "stable" +components = ["rustfmt", "clippy"] diff --git a/rust/tvm-ffi-macros/Cargo.toml b/rust/tvm-ffi-macros/Cargo.toml index f8d29d40..4771ce28 100644 --- a/rust/tvm-ffi-macros/Cargo.toml +++ b/rust/tvm-ffi-macros/Cargo.toml @@ -20,7 +20,8 @@ name = "tvm-ffi-macros" description = "Procedural macro crate for tvm-ffi" version = "0.1.0-alpha.0" -edition = "2021" +edition = "2024" +rust-version = "1.85" license = "Apache-2.0" @@ -30,5 +31,4 @@ proc-macro = true [dependencies] proc-macro2 = "^1.0" quote = "^1.0" -syn = { version = "1.0.48", features = ["full", "parsing", "extra-traits"] } -proc-macro-error = "^1.0" +syn = { version = "^2.0", features = ["full"] } diff --git a/rust/tvm-ffi-macros/src/lib.rs b/rust/tvm-ffi-macros/src/lib.rs index 64fe3f18..76497bd1 100644 --- a/rust/tvm-ffi-macros/src/lib.rs +++ b/rust/tvm-ffi-macros/src/lib.rs @@ -18,19 +18,16 @@ */ use proc_macro::TokenStream; -use proc_macro_error::proc_macro_error; mod object_macros; mod utils; -#[proc_macro_error] #[proc_macro_derive(Object, attributes(type_key, type_index))] pub fn derive_object(input: TokenStream) -> TokenStream { - TokenStream::from(object_macros::derive_object(input)) + object_macros::derive_object(input) } -#[proc_macro_error] #[proc_macro_derive(ObjectRef, attributes(type_key, type_index))] pub fn derive_object_ref(input: TokenStream) -> TokenStream { - TokenStream::from(object_macros::derive_object_ref(input)) + object_macros::derive_object_ref(input) } diff --git a/rust/tvm-ffi-macros/src/object_macros.rs b/rust/tvm-ffi-macros/src/object_macros.rs index 8154709d..9e6e1ac5 100644 --- a/rust/tvm-ffi-macros/src/object_macros.rs +++ b/rust/tvm-ffi-macros/src/object_macros.rs @@ -58,7 +58,7 @@ pub fn derive_object(input: proc_macro::TokenStream) -> TokenStream { &type_key_arg, &mut tindex ); if ret != 0 { - proc_macro_error::abort!("Failed to get type index for type key: {}", #type_key); + panic!("Failed to get type index for type key: {}", #type_key); } tindex } @@ -124,6 +124,198 @@ pub fn derive_object_ref(input: proc_macro::TokenStream) -> TokenStream { } .expect("First field must be `data: ObjectArc`"); + let is_object_ref = struct_name == syn::Ident::new("ObjectRef", struct_name.span()); + + let any_compatible_tokens = if is_object_ref { + quote! { + // implement AnyCompatible for #struct_name + unsafe impl #tvm_ffi_crate::type_traits::AnyCompatible for #struct_name { + fn type_str() -> String { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + ::TYPE_KEY.into() + } + + unsafe fn copy_to_any_view( + src: &Self, + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = #tvm_ffi_crate::object::ObjectArc::::as_raw( + &src.data + ) as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; + data.type_index = (*data_ptr).type_index; + data.small_str_len = 0; + data.data_union.v_obj = data_ptr; + } + + unsafe fn check_any_strict(data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny) -> bool { + data.type_index >= #tvm_ffi_crate::TypeIndex::kTVMFFIStaticObjectBegin as i32 + } + + unsafe fn copy_from_any_view_after_check( + data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Self { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = data.data_union.v_obj; + // need to increase ref because original weak ptr + // do not own the code + #tvm_ffi_crate::object::unsafe_::inc_ref( + data_ptr as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject + ); + Self { + data : #tvm_ffi_crate::object::ObjectArc::from_raw( + data_ptr as *mut ContainerType + ) + } + } + + unsafe fn move_to_any( + src: Self, + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = #tvm_ffi_crate::object::ObjectArc::into_raw( + src.data + ) as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; + data.type_index = (*data_ptr).type_index; + data.small_str_len = 0; + data.data_union.v_obj = data_ptr; + } + + unsafe fn move_from_any_after_check( + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Self { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = data.data_union.v_obj as *mut ContainerType; + Self { + data : #tvm_ffi_crate::object::ObjectArc::::from_raw(data_ptr) + } + } + + unsafe fn try_cast_from_any_view( + data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Result { + if Self::check_any_strict(data) { + Ok(Self::copy_from_any_view_after_check(data)) + } else { + Err(()) + } + } + } + } + } else { + quote! { + // implement AnyCompatible for #struct_name + unsafe impl #tvm_ffi_crate::type_traits::AnyCompatible for #struct_name { + fn type_str() -> String { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + ::TYPE_KEY.into() + } + + unsafe fn copy_to_any_view( + src: &Self, + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let type_index = + ::type_index(); + data.type_index = type_index as i32; + data.small_str_len = 0; + let data_ptr = #tvm_ffi_crate::object::ObjectArc::::as_raw( + &src.data + ); + data.data_union.v_obj = + data_ptr as *mut ContainerType as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; + } + + unsafe fn check_any_strict(data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny) -> bool { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let target_index = + ::type_index(); + if data.type_index == target_index as i32 { + return true; + } + let info = #tvm_ffi_crate::tvm_ffi_sys::TVMFFIGetTypeInfo(data.type_index); + if info.is_null() { + return false; + } + let info = &*info; + let ancestors = info.type_acenstors; + if ancestors.is_null() { + return false; + } + for depth in 0..info.type_depth { + let ancestor = *ancestors.add(depth as usize); + if !ancestor.is_null() && (*ancestor).type_index == target_index { + return true; + } + } + false + } + + unsafe fn copy_from_any_view_after_check( + data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Self { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + // Delegate to ObjectRef to handle reference counting + let obj_ref = <#tvm_ffi_crate::object::ObjectRef as #tvm_ffi_crate::type_traits::AnyCompatible>::copy_from_any_view_after_check(data); + // Use public unsafe API to do pointer cast + let arc = <#tvm_ffi_crate::object::ObjectRef as #tvm_ffi_crate::object::ObjectRefCore>::into_data(obj_ref); + let raw = #tvm_ffi_crate::object::ObjectArc::into_raw(arc); + let typed = #tvm_ffi_crate::object::ObjectArc::from_raw(raw as *const ContainerType); + Self { data: typed } + } + + unsafe fn move_to_any( + src: Self, + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let type_index = + ::type_index(); + data.type_index = type_index as i32; + data.small_str_len = 0; + let data_ptr = #tvm_ffi_crate::object::ObjectArc::into_raw( + src.data + ); + data.data_union.v_obj = + data_ptr as *mut ContainerType as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; + } + + unsafe fn move_from_any_after_check( + data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Self { + type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> + ::ContainerType; + let data_ptr = data.data_union.v_obj as *mut ContainerType; + Self { + data : #tvm_ffi_crate::object::ObjectArc::::from_raw(data_ptr) + } + } + + unsafe fn try_cast_from_any_view( + data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny + ) -> Result { + if Self::check_any_strict(data) { + Ok(Self::copy_from_any_view_after_check(data)) + } else { + Err(()) + } + } + } + } + }; + let mut expanded = quote! { unsafe impl #tvm_ffi_crate::object::ObjectRefCore for #struct_name { type ContainerType = <#data_ty as std::ops::Deref>::Target; @@ -141,99 +333,7 @@ pub fn derive_object_ref(input: proc_macro::TokenStream) -> TokenStream { } } - // implement AnyCompatible for #struct_name - unsafe impl #tvm_ffi_crate::type_traits::AnyCompatible for #struct_name { - fn type_str() -> String { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - ::TYPE_KEY.into() - } - - unsafe fn copy_to_any_view( - src: &Self, - data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let type_index = - ::type_index(); - data.type_index = type_index as i32; - data.small_str_len = 0; - let data_ptr = #tvm_ffi_crate::object::ObjectArc::::as_raw( - &src.data - ); - data.data_union.v_obj = - data_ptr as *mut ContainerType as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; - } - - unsafe fn check_any_strict(data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny) -> bool { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let type_index = - ::type_index(); - data.type_index == type_index as i32 - } - - unsafe fn copy_from_any_view_after_check( - data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) -> Self { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let data_ptr = data.data_union.v_obj; - // need to increase ref because original weak ptr - // do not own the code - #tvm_ffi_crate::object::unsafe_::inc_ref( - data_ptr as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject - ); - Self { - data : #tvm_ffi_crate::object::ObjectArc::from_raw( - data_ptr as *mut ContainerType - ) - } - } - - unsafe fn move_to_any( - src: Self, - data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let type_index = - ::type_index(); - data.type_index = type_index as i32; - data.small_str_len = 0; - let data_ptr = #tvm_ffi_crate::object::ObjectArc::into_raw( - src.data - ); - data.data_union.v_obj = - data_ptr as *mut ContainerType as *mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIObject; - } - - unsafe fn move_from_any_after_check( - data: &mut #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) -> Self { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let data_ptr = data.data_union.v_obj as *mut ContainerType; - Self { - data : #tvm_ffi_crate::object::ObjectArc::::from_raw(data_ptr) - } - } - - unsafe fn try_cast_from_any_view( - data: & #tvm_ffi_crate::tvm_ffi_sys::TVMFFIAny - ) -> Result { - type ContainerType = <#struct_name as #tvm_ffi_crate::object::ObjectRefCore> - ::ContainerType; - let type_index = - ::type_index(); - if data.type_index == type_index as i32 { - Ok(Self::copy_from_any_view_after_check(data)) - } else { - Err(()) - } - } - } + #any_compatible_tokens }; // skip ObjectRef since it can create circular dependency with any.rs if struct_name != "ObjectRef" { diff --git a/rust/tvm-ffi-macros/src/utils.rs b/rust/tvm-ffi-macros/src/utils.rs index da86534f..0b3bca62 100644 --- a/rust/tvm-ffi-macros/src/utils.rs +++ b/rust/tvm-ffi-macros/src/utils.rs @@ -20,8 +20,6 @@ use proc_macro2::TokenStream; use quote::quote; use std::env; -/// Get the tvm-rt crate name -/// \return The tvm-rt crate name pub(crate) fn get_tvm_ffi_crate() -> TokenStream { if env::var("CARGO_PKG_NAME").unwrap() == "tvm-ffi" { quote!(crate) @@ -30,49 +28,27 @@ pub(crate) fn get_tvm_ffi_crate() -> TokenStream { } } -/// Get an attribute by name from a derive input -/// -/// # Arguments -/// * `derive_input` - The derive input to get the attribute from -/// * `name` - The name of the attribute to get -/// -/// # Returns -/// * `Option<&syn::Attribute>` - The attribute if it exists pub(crate) fn get_attr<'a>( derive_input: &'a syn::DeriveInput, name: &str, ) -> Option<&'a syn::Attribute> { - derive_input.attrs.iter().find(|a| a.path.is_ident(name)) + derive_input.attrs.iter().find(|a| a.path().is_ident(name)) } -/// Convert an attribute to a string -/// -/// # Arguments -/// * `attr` - The attribute to convert -/// -/// # Returns -/// * `syn::LitStr` - The string value of the attribute pub(crate) fn attr_to_str(attr: &syn::Attribute) -> syn::LitStr { - match attr.parse_meta() { - Ok(syn::Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Str(s), + match &attr.meta { + syn::Meta::NameValue(syn::MetaNameValue { + value: + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(s), + .. + }), .. - })) => s, - Ok(_m) => panic!("Expected a string literal, got"), - Err(e) => panic!("{}", e), + }) => s.clone(), + _ => panic!("Expected #[attr = \"string\"] attribute"), } } -/// Convert an attribute to an integer -/// -/// # Arguments -/// * `attr` - The attribute to convert -/// -/// # Returns -/// * `syn::Result` - The integer value of the attribute pub(crate) fn attr_to_expr(attr: &syn::Attribute) -> syn::Result { - let parser = |input: syn::parse::ParseStream| { - input.parse::() // parse expression after '=' - }; - syn::parse::Parser::parse2(parser, attr.tokens.clone()) + attr.parse_args::() } diff --git a/rust/tvm-ffi-stubgen/Cargo.toml b/rust/tvm-ffi-stubgen/Cargo.toml new file mode 100644 index 00000000..4157065b --- /dev/null +++ b/rust/tvm-ffi-stubgen/Cargo.toml @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-ffi-stubgen" +description = "Rust stub generator for tvm-ffi" +version = "0.1.0" +edition = "2024" +rust-version = "1.85" +license = "Apache-2.0" + +[[bin]] +name = "tvm-ffi-stubgen" +path = "src/main.rs" + +[dependencies] +clap = { version = "4.5", features = ["derive"] } +env_logger = "0.11.9" +libloading = "0.8" +log = "0.4.29" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +toml = "0.8" +tvm-ffi = { version = "0.1.0-alpha.0", path = "../tvm-ffi" } diff --git a/rust/tvm-ffi-stubgen/README.md b/rust/tvm-ffi-stubgen/README.md new file mode 100644 index 00000000..ba668412 --- /dev/null +++ b/rust/tvm-ffi-stubgen/README.md @@ -0,0 +1,363 @@ +# Rust Stubgen Guide + +`tvm-ffi-stubgen` generates Rust stubs from TVM-FFI reflection metadata. +This document is design-oriented and focuses on generated interface forms and implementation choices. + +## Table of Contents + +- [Document Scope](#document-scope) +- [Generated Interface Forms](#generated-interface-forms) +- [Object Model and Inheritance](#object-model-and-inheritance) +- [Field Accessor Style](#field-accessor-style) +- [Subtyping and Cast Rules](#subtyping-and-cast-rules) +- [repr(C) Decision Rules](#reprc-decision-rules) +- [Safety and Fallback Strategy](#safety-and-fallback-strategy) +- [TODO](#todo) +- [Related User Guide](#related-user-guide) + +## Document Scope + +This README intentionally does not duplicate full command-line tutorial content. +For command usage and end-to-end calling examples, see: + +- `docs/packaging/rust_stubgen.md` + +## Generated Interface Forms + +Stubgen emits a public facade (`src/lib.rs`) plus detail modules: + +- `src/_tvm_ffi_stubgen_detail/functions.rs` +- `src/_tvm_ffi_stubgen_detail/types.rs` + +By default the generator runs `cargo fmt` on the emitted crate after writing these files. +Pass `--no-format` to keep the raw generated text when debugging formatting-sensitive output. + +### Function Wrappers + +#### Typed wrapper path + +When type schema is fully known, function wrappers are generated as typed Rust APIs: + +```rust +pub fn add_one(_0: i64) -> Result { ... } +``` + +#### Packed fallback path + +When schema is not fully resolved, wrappers use packed calling style: + +```rust +pub fn echo(args: &[Any]) -> Result { ... } +``` + +### Type Wrappers + +#### repr(C) path (preferred) + +For types with known `total_size`: + +- `#[repr(C)] Obj` with typed fields and `[u8; N]` gaps +- `#[derive(ObjectRef, Clone)] ` +- `impl_object_hierarchy!(...)` +- direct-field `get_` accessors + +Example shape: + +```rust +#[repr(C)] +pub struct PrimExprObj { + __tvm_ffi_object_parent: BaseExprObj, + dtype: tvm_ffi::DLDataType, + _gap0: [u8; 4], // C++ tail padding +} +``` + +Gaps cover C++ tail padding, vtable pointers, and fields whose type schema +is not mappable to Rust. This allows the vast majority of types to use +repr(C) layout even when metadata is incomplete. + +#### fallback wrapper path + +For types without `total_size` metadata (no `ObjectDef` registered): + +- `define_object_wrapper!(Type, "type.key")` +- field access via `FieldGetter` + +### Object Method Lookup Path + +Object methods (including `__ffi_init__`) are generated from type reflection metadata, +not from global function registry names: + +- generated code calls `tvm_ffi::object_wrapper::resolve_type_method(type_key, method_name)` +- runtime lookup path is `TVMFFITypeKeyToIndex -> TVMFFIGetTypeInfo -> methods[]` +- the `method` entry is converted from `AnyView` to owned `Any`, then to `ffi.Function` + +Global wrappers under `functions.rs` still use `Function::get_global`, but type methods in +`types.rs` no longer assume `.` is globally registered. + +For constructor-like methods (`__ffi_init__`), stubgen emits `new(...)` directly as the public +Rust API (or `ffi_init` only when a user-defined `new` method already exists). + +## Object Model and Inheritance + +repr(C) object inheritance is modeled by composition and deref chain: + +### Obj-level layout + +Derived object stores parent object as first field: + +```rust +#[repr(C)] +pub struct DerivedObj { + __tvm_ffi_object_parent: BaseObj, + extra: i64, +} +``` + +### Ref-level inheritance + +Ref wrappers use `impl_object_hierarchy!` to establish: + +- `Deref Base>` +- `From for Base/ObjectRef` (upcast) +- `TryFrom for Derived` (downcast) + +## Field Accessor Style + +All getter variants share a unified calling convention so callers do not need +to know which code path was used: + +- name prefix is always `get_` +- return is infallible (panics internally rather than returning `Result`) +- only fields of the current type are generated; inherited getters are available + via deref auto-coercion + +### Direct struct field getters (repr(C) layout path) + +Fields that map cleanly to the repr(C) struct body are accessed by direct +memory reference: + +- POD field → return by value +- object/container field → clone and return + +```rust +impl PrimExpr { + pub fn get_dtype(&self) -> tvm_ffi::DLDataType { + self.data.dtype // POD: copy + } +} +impl ForFrame { + pub fn get_doms(&self) -> tvm_ffi::Array { + self.data.doms.clone() // object: clone + } +} +``` + +### Non-layout field getters (FieldGetter runtime path) + +Some registered ObjectDef fields cannot be placed in the repr(C) struct body: + +1. **Parent-range fields** — offset falls inside the parent type's address + range. Example: `ForFrame.vars` at offset 56 is within `TIRFrame`'s 0..64 + range. These are fields the child type "fills into" a gap slot of the parent. +2. **Schema-unmappable fields** — the field's type schema has no Rust + representation in the current type_map. + +For both cases stubgen generates a `LazyLock>` static and a +matching `get_xxx` method with the **same signature** as a direct getter: + +- typed (schema mappable): `pub fn get_xxx(&self) -> T` via `FieldGetter::get()` +- untyped (schema unmappable): `pub fn get_xxx(&self) -> tvm_ffi::Any` via `get_any()` + +```rust +// auto-generated: parent-range field, typed +static FIELD_FORFRAME__VARS: LazyLock>> = ...; +impl ForFrame { + pub fn get_vars(&self) -> tvm_ffi::Array { + let __obj: tvm_ffi::object::ObjectRef = self.clone().into(); + FIELD_FORFRAME__VARS.get(&__obj).expect("...") + } +} +``` + +Callers use `frame.get_vars()` exactly as they would `frame.get_doms()`, with +no awareness of the access mechanism. + +### Debugging non-layout fields + +To inspect what fields TVM has registered for a type (including those not in +the struct layout), use the reflection API at runtime: + +```rust +// example: inspect ForFrame field offsets and schemas +let info = unsafe { tvm_ffi_sys::TVMFFIGetTypeInfo(type_index) }; +for i in 0..info.num_fields { + let f = &(*info.fields)[i]; + println!("name={:?} offset={} schema={:?}", f.name, f.offset, f.metadata); +} +``` + +Alternatively, set `RUST_LOG=trace` when running stubgen to see all field +offset and schema decisions in the log output. + +## Subtyping and Cast Rules + +Stubgen-generated repr(C) refs use standard Rust traits as the only user-facing cast API: + +- borrow upcast: `Deref` +- consuming upcast: `From` / `.into()` +- consuming downcast: `TryFrom` / `.try_into()` + +This avoids custom cast traits and keeps compile-time type constraints explicit. + +## repr(C) Decision Rules + +`check_repr_c` gates repr(C) generation using a **gap-filling** strategy. + +### Hard requirements (cause fallback to `define_object_wrapper!`) + +- Type must have `total_size > 0` (i.e. `ObjectDef` was called for it) +- No overlapping fields + +### Soft handling (does NOT cause fallback) + +- **Tail padding / vtable / unregistered fields**: byte ranges between registered + fields (or between the last field and `total_size`) are emitted as `[u8; N]` gap + members in the `#[repr(C)]` struct. +- **Parent type not in type_map or not repr(C)-compatible**: the parent region is + treated as a gap after the `Object` header. The struct uses `tvm_ffi::object::Object` + as the parent field and gap-fills the bytes between Object and the first known field. +- **Parent-range fields** (registered by this type but `offset < parent_total_size`): + cannot be placed in the struct body. Tracked in `non_layout_fields`; a FieldGetter + static and `get_*` accessor are emitted instead. These fields physically reside + in a gap slot of the parent type that the child fills with its own data. +- **Field type schema not mappable to Rust**: the field becomes a gap in the struct + body. Also tracked in `non_layout_fields` and exposed via `get_*` returning + `tvm_ffi::Any` (using `FieldGetter::get_any()`). + +In all non-layout cases the emitted `get_*` method has the same naming and +infallible-return signature as a direct struct-field getter. + +### Schema mapping rules + +Representative mappings include: + +- `Any` / `ffi.Any` -> `tvm_ffi::AnyValue` +- `ffi.Array` -> `tvm_ffi::Array` +- `ffi.Array` (no args) -> `tvm_ffi::Array` +- `ffi.Map` -> `tvm_ffi::Map` +- `Optional` -> `Option` +- `Optional` (no args) -> `Option` + +## Multi-Prefix Generation + +`--init-prefix` accepts multiple values. Behavior depends on the count: + +- **Single prefix** (e.g. `--init-prefix testing`): the prefix is stripped and items land + at the crate root. This is the default backward-compatible mode. +- **Multiple prefixes** (e.g. `--init-prefix tl --init-prefix ir --init-prefix tir`): + no prefix is stripped; each prefix naturally becomes a top-level module. + +Example with multiple prefixes: + +``` +tl.KernelLaunch → crate::tl::KernelLaunch +ir.Span → crate::ir::Span +tir.BufferLoad → crate::tir::BufferLoad +script.ir_builder.* → crate::script::ir_builder::* +``` + +This allows a single generated crate to cover multiple namespaces. Cross-namespace +`repr(C)` inheritance (e.g. `tir.PrimFunc` extending `ir.BaseFunc`) resolves within +the crate without workarounds. + +## Safety and Fallback Strategy + +Generated user-facing code is intended to remain safe Rust. + +### Safety boundary + +- unsafe operations are encapsulated in `tvm-ffi` internals and macros +- generated wrappers and getters are safe APIs + +### Built-in filtering and fallback + +- built-in `ffi.*` primitives are not re-generated as wrapper types +- only types without `total_size` metadata fall back to `define_object_wrapper!` +- types with incomplete field schemas or unmappable parents still get repr(C) layout + via gap-filling + +### Logging + +Stubgen uses the `log` crate. Set `RUST_LOG` to control verbosity: + +- `RUST_LOG=debug` — shows repr(C) pass/fail decisions and field mapping failures +- `RUST_LOG=trace` — additionally shows per-field offset/size/schema details + +## TODO + +Known gaps and design issues that remain open. + +### Ancestor chain is truncated when direct parent is not repr(C)-mappable + +When `check_repr_c` cannot map the direct parent type, `repr_c.rs` falls back to +`tvm_ffi::object::Object` as the layout parent and fills the missing bytes with a gap. +However, the second pass in `generate.rs` that builds `ancestor_chain` only propagates +through types whose `parent_type_key` is set in `ReprCInfo`; when it is `None` the chain +collapses to `[tvm_ffi::object::ObjectRef]`. + +Consequence: if the C++ hierarchy is `Object → A (mappable) → B (not mappable) → C`, +the generated code for `C` emits + +```rust +tvm_ffi::impl_object_hierarchy!(C: tvm_ffi::object::ObjectRef); +``` + +instead of + +```rust +tvm_ffi::impl_object_hierarchy!(C: A, tvm_ffi::object::ObjectRef); +``` + +This means `From for A` and `TryFrom for C` are not generated, and getters +inherited from `A` are inaccessible via deref on `C` even though the layout is correct. + +The ancestor chain logic should be derived from the runtime type ancestry table +(`TVMFFIGetTypeInfo → type_acenstors`) independently of layout mappability, so that +upcast/downcast correctness is preserved regardless of whether every intermediate type +has a usable repr(C) layout. + +### Common interface between fallback and repr(C) paths + +`define_object_wrapper!` types and repr(C) types currently expose different API surfaces: + +- repr(C) types: `Deref` chain, `From`/`TryFrom`, direct `get_*` accessors, + `get_*` FieldGetter accessors for non-layout fields +- fallback types: `from_object` / `as_object_ref` / `into_object_ref`, runtime `FieldGetter` + +Code that depends on a given type must know which generation path was used, and that +path can change between stubgen versions as reflection metadata improves. A type that +was a thin wrapper in version N may become a repr(C) type in version N+1, silently +breaking downstream call sites that relied on `from_object` or `as_object_ref`. + +A stable, version-independent interface layer is needed so that user code does not need +to distinguish between the two paths, and so that crates built against one stubgen +version remain source-compatible with crates built against a later one. + +### Parent-range field layout override + +When a child type registers a field at an offset within the parent's address range +(e.g. `ForFrame.vars` at offset 56 inside `TIRFrame`'s 0..64 gap), the field is +currently excluded from the repr(C) struct body and exposed only via `FieldGetter`. +This is correct for access but sub-optimal: the child field is physically in a gap +slot that the parent never uses, so it could safely be placed in the struct layout. + +Fix: inspect the parent's layout for the specific offset. If the parent has a +`[u8; N]` gap entry covering that offset, allow the child's field to override it +directly in the struct rather than routing through FieldGetter. + +## Related User Guide + +For generation command-line usage and step-by-step invocation examples, see: + +- `docs/packaging/rust_stubgen.md` diff --git a/rust/tvm-ffi-stubgen/build.rs b/rust/tvm-ffi-stubgen/build.rs new file mode 100644 index 00000000..3026cc71 --- /dev/null +++ b/rust/tvm-ffi-stubgen/build.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::env; +use std::process::Command; + +fn main() { + let lib_dir = tvm_ffi_libdir(); + let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default(); + + if target_os == "linux" || target_os == "macos" { + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", lib_dir); + } + + let ld_var = match target_os.as_str() { + "windows" => "PATH", + "macos" => "DYLD_LIBRARY_PATH", + "linux" => "LD_LIBRARY_PATH", + _ => "", + }; + if !ld_var.is_empty() { + let current = env::var(ld_var).unwrap_or_default(); + let separator = if ld_var == "PATH" { ";" } else { ":" }; + let value = if current.is_empty() { + lib_dir.clone() + } else { + format!("{}{}{}", lib_dir, separator, current) + }; + println!("cargo:rustc-env={}={}", ld_var, value); + } +} + +fn tvm_ffi_libdir() -> String { + let output = Command::new("tvm-ffi-config") + .arg("--libdir") + .output() + .expect("tvm-ffi-config --libdir"); + if !output.status.success() { + panic!("tvm-ffi-config --libdir failed"); + } + let lib_dir = String::from_utf8(output.stdout) + .expect("tvm-ffi-config output") + .trim() + .to_string(); + if lib_dir.is_empty() { + panic!("tvm-ffi-config returned empty libdir"); + } + lib_dir +} diff --git a/rust/tvm-ffi-stubgen/src/cli.rs b/rust/tvm-ffi-stubgen/src/cli.rs new file mode 100644 index 00000000..0184ee1c --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/cli.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use clap::Parser; +use std::path::PathBuf; + +#[derive(Parser, Debug)] +#[command( + name = "tvm-ffi-stubgen", + about = "Generate Rust stubs from tvm-ffi metadata" +)] +pub struct Args { + #[arg(value_name = "OUT_DIR")] + pub out_dir: PathBuf, + #[arg(long = "dlls", value_delimiter = ';', num_args = 1..)] + pub dlls: Vec, + #[arg(long = "init-prefix", num_args = 1..)] + pub init_prefix: Vec, + #[arg(long = "init-crate")] + pub init_crate: String, + #[arg(long = "tvm-ffi-path")] + pub tvm_ffi_path: Option, + #[arg(long = "overwrite")] + pub overwrite: bool, + #[arg(long = "no-format")] + pub no_format: bool, +} diff --git a/rust/tvm-ffi-stubgen/src/ffi.rs b/rust/tvm-ffi-stubgen/src/ffi.rs new file mode 100644 index 00000000..4164f6e1 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/ffi.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use libloading::Library; +use std::path::PathBuf; +use tvm_ffi::Array; +use tvm_ffi::tvm_ffi_sys::{ + TVMFFIByteArray, TVMFFIGetTypeInfo, TVMFFITypeInfo, TVMFFITypeKeyToIndex, +}; +use tvm_ffi::{Function, Result as FfiResult, String as FfiString}; + +pub(crate) fn load_dlls(paths: &[PathBuf]) -> Result, Box> { + let mut libs = Vec::new(); + for path in paths { + let lib = unsafe { Library::new(path) }?; + libs.push(lib); + } + Ok(libs) +} + +pub(crate) fn list_global_function_names() -> FfiResult> { + let functor_func = Function::get_global("ffi.FunctionListGlobalNamesFunctor")?; + let functor_any = functor_func.call_tuple_with_len::<0, _>(())?; + let functor: Function = functor_any.try_into()?; + let count_any = functor.call_tuple_with_len::<1, _>((-1i64,))?; + let count: i64 = count_any.try_into()?; + let mut out = Vec::new(); + for idx in 0..count { + let name_any = functor.call_tuple_with_len::<1, _>((idx,))?; + let name: FfiString = name_any.try_into()?; + out.push(name.as_str().to_string()); + } + Ok(out) +} + +pub(crate) fn list_registered_type_keys() -> FfiResult> { + let get_keys = Function::get_global("ffi.GetRegisteredTypeKeys")?; + let keys_any = get_keys.call_tuple_with_len::<0, _>(())?; + let mut out = Vec::new(); + let keys: Array = keys_any.try_into()?; + for key in &keys { + out.push(key.as_str().to_string()); + } + Ok(out) +} + +pub(crate) fn get_type_info(type_key: &str) -> Option<&'static TVMFFITypeInfo> { + unsafe { + let key = TVMFFIByteArray::from_str(type_key); + let mut tindex = 0; + if TVMFFITypeKeyToIndex(&key, &mut tindex) != 0 { + return None; + } + let info = TVMFFIGetTypeInfo(tindex); + if info.is_null() { None } else { Some(&*info) } + } +} + +pub(crate) fn get_global_func_metadata(name: &str) -> FfiResult> { + let func = Function::get_global("ffi.GetGlobalFuncMetadata")?; + let name_arg = FfiString::from(name); + let meta_any = func.call_tuple_with_len::<1, _>((name_arg,))?; + let meta: FfiString = meta_any.try_into()?; + Ok(Some(meta.as_str().to_string())) +} + +pub(crate) fn byte_array_to_string_opt(value: &TVMFFIByteArray) -> Option { + if value.data.is_null() || value.size == 0 { + return None; + } + let slice = unsafe { std::slice::from_raw_parts(value.data, value.size) }; + Some(String::from_utf8_lossy(slice).to_string()) +} diff --git a/rust/tvm-ffi-stubgen/src/generate.rs b/rust/tvm-ffi-stubgen/src/generate.rs new file mode 100644 index 00000000..e9ecbaaa --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/generate.rs @@ -0,0 +1,1309 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cli::Args; +use crate::ffi; +use crate::model::{ + FieldGen, FunctionGen, FunctionSig, GetterSpec, MethodGen, ModuleNode, RustType, TypeGen, +}; +use crate::repr_c; +use crate::schema::{TypeSchema, extract_type_schema, parse_type_schema}; +use crate::utils; +use std::collections::BTreeMap; +use std::fmt::Write as _; +use toml::value::Table; + +const METHOD_FLAG_STATIC: i64 = 1 << 2; + +pub(crate) fn build_type_map(type_keys: &[String], prefix: &str) -> BTreeMap { + let mut map = BTreeMap::new(); + for key in type_keys { + let (mods, name) = split_name(key, prefix); + let rust_name = sanitize_ident(&name, IdentStyle::Type); + let module_path = module_path(&mods); + let path = if module_path.is_empty() { + format!("crate::{}", rust_name) + } else { + format!("crate::{}::{}", module_path, rust_name) + }; + map.insert(key.clone(), path); + } + map +} + +pub(crate) fn build_function_entries( + func_names: &[String], + type_map: &BTreeMap, + prefix: &str, +) -> tvm_ffi::Result, FunctionGen)>> { + let mut out = Vec::new(); + for full_name in func_names { + let metadata = ffi::get_global_func_metadata(full_name)?; + let schema = metadata + .and_then(|meta| extract_type_schema(&meta)) + .and_then(|schema| parse_type_schema(&schema)); + let sig = build_function_sig(schema.as_ref(), type_map, None); + let (mods, name) = split_name(full_name, prefix); + let rust_name = sanitize_ident(&name, IdentStyle::Function); + out.push(( + mods, + FunctionGen { + full_name: full_name.clone(), + rust_name, + sig, + }, + )); + } + Ok(out) +} + +pub(crate) fn build_type_entries( + type_keys: &[String], + type_map: &BTreeMap, + prefix: &str, +) -> tvm_ffi::Result, TypeGen)>> { + let mut out = Vec::new(); + for key in type_keys { + let (mods, name) = split_name(key, prefix); + let rust_name = sanitize_ident(&name, IdentStyle::Type); + let mut methods = Vec::new(); + let mut fields = Vec::new(); + let mut type_depth = 0i32; + let repr_c_info = repr_c::check_repr_c(key, type_map); + if let Some(info) = ffi::get_type_info(key) { + type_depth = info.type_depth; + if info.num_methods > 0 && !info.methods.is_null() { + let method_slice = + unsafe { std::slice::from_raw_parts(info.methods, info.num_methods as usize) }; + let has_user_new = method_slice.iter().any(|method| { + matches!( + ffi::byte_array_to_string_opt(&method.name).as_deref(), + Some("new") + ) + }); + for method in method_slice { + let method_name = match ffi::byte_array_to_string_opt(&method.name) { + Some(name) => name, + None => continue, + }; + let rust_method_name = if method_name == "__ffi_init__" { + if has_user_new { + "ffi_init".to_string() + } else { + "new".to_string() + } + } else { + map_method_name(&method_name) + }; + let is_static = (method.flags & METHOD_FLAG_STATIC) != 0; + let meta = ffi::byte_array_to_string_opt(&method.metadata); + let schema = meta + .as_deref() + .and_then(extract_type_schema) + .and_then(|s| parse_type_schema(&s)); + let sig = + build_method_sig(schema.as_ref(), type_map, Some(key.as_str()), is_static); + methods.push(MethodGen { + source_name: method_name, + rust_name: rust_method_name, + sig, + is_static, + }); + } + } + if info.num_fields > 0 && !info.fields.is_null() { + let field_slice = + unsafe { std::slice::from_raw_parts(info.fields, info.num_fields as usize) }; + for field in field_slice { + let field_name = match ffi::byte_array_to_string_opt(&field.name) { + Some(name) => name, + None => continue, + }; + let rust_field_name = sanitize_ident(&field_name, IdentStyle::Function); + let meta = ffi::byte_array_to_string_opt(&field.metadata); + let schema = meta + .as_deref() + .and_then(extract_type_schema) + .and_then(|s| parse_type_schema(&s)); + let ty = match schema.as_ref() { + Some(schema) => rust_type_for_schema(schema, type_map, Some(key.as_str())), + None => RustType::unsupported("tvm_ffi::Any"), + }; + fields.push(FieldGen { + name: field_name, + rust_name: rust_field_name, + ty, + }); + } + } + } + out.push(( + mods, + TypeGen { + type_key: key.clone(), + rust_name: rust_name.clone(), + methods, + fields, + type_depth, + repr_c_info: repr_c_info.clone(), + getter_specs: Vec::new(), + ancestor_chain: Vec::new(), + }, + )); + } + // Second pass: fill getter_specs and ancestor_chain for repr_c types in dependency order (base before derived). + let mut type_key_to_idx: BTreeMap = BTreeMap::new(); + for (idx, (_, ty)) in out.iter().enumerate() { + type_key_to_idx.insert(ty.type_key.clone(), idx); + } + let mut order: Vec = (0..out.len()).collect(); + order.sort_by_key(|&i| out[i].1.type_depth); + for &idx in &order { + let (_, ref ty) = out[idx]; + let repr_c_info = match &ty.repr_c_info { + Some(r) => r, + None => continue, + }; + let parent_specs: Vec = + if let Some(ref parent_key) = repr_c_info.parent_type_key { + let parent_idx = *type_key_to_idx.get(parent_key).unwrap_or(&idx); + out[parent_idx].1.getter_specs.clone() + } else { + Vec::new() + }; + let getter_specs = build_getter_specs(&ty.type_key, &ty.repr_c_info, &parent_specs); + + // Build ancestor chain: [DirectParent, Grandparent, ..., ObjectRef] + let ancestor_chain = if let Some(ref parent_key) = repr_c_info.parent_type_key { + if parent_key == "ffi.Object" { + vec!["tvm_ffi::object::ObjectRef".to_string()] + } else if let Some(parent_rust) = type_map.get(parent_key) { + let parent_idx = *type_key_to_idx.get(parent_key).unwrap_or(&idx); + let mut chain = vec![parent_rust.clone()]; + // Inherit parent's ancestors + chain.extend(out[parent_idx].1.ancestor_chain.clone()); + chain + } else { + vec!["tvm_ffi::object::ObjectRef".to_string()] + } + } else { + vec!["tvm_ffi::object::ObjectRef".to_string()] + }; + + out[idx].1.getter_specs = getter_specs; + out[idx].1.ancestor_chain = ancestor_chain; + } + Ok(out) +} + +fn build_getter_specs( + _type_key: &str, + repr_c_info: &Option, + parent_specs: &[GetterSpec], +) -> Vec { + let info = match repr_c_info { + Some(i) => i, + None => return Vec::new(), + }; + let mut specs = Vec::new(); + for parent in parent_specs { + let access_expr = if parent.access_expr.starts_with("self.data.") { + format!( + "self.data.__tvm_ffi_object_parent.{}", + &parent.access_expr["self.data.".len()..] + ) + } else { + parent.access_expr.clone() + }; + specs.push(GetterSpec { + method_name: parent.method_name.clone(), + access_expr, + ret_type: parent.ret_type.clone(), + }); + } + for f in info.fields() { + let method_name = format!("get_{}", f.rust_name); + let access_expr = if f.is_pod { + format!("self.data.{}", f.rust_name) + } else { + format!("self.data.{}.clone()", f.rust_name) + }; + specs.push(GetterSpec { + method_name, + access_expr, + ret_type: f.rust_type.clone(), + }); + } + specs +} + +pub(crate) fn build_function_modules( + funcs: Vec<(Vec, FunctionGen)>, + _prefix: &str, +) -> ModuleNode { + let mut root = ModuleNode::default(); + for (mods, func) in funcs { + insert_function(&mut root, &mods, func); + } + root +} + +pub(crate) fn build_type_modules(types: Vec<(Vec, TypeGen)>, _prefix: &str) -> ModuleNode { + let mut root = ModuleNode::default(); + for (mods, ty) in types { + insert_type(&mut root, &mods, ty); + } + root +} + +fn build_function_sig( + schema: Option<&TypeSchema>, + type_map: &BTreeMap, + self_type_key: Option<&str>, +) -> FunctionSig { + match schema { + None => FunctionSig::packed(), + Some(schema) if schema.origin != "ffi.Function" => FunctionSig::packed(), + Some(schema) if schema.args.is_empty() => FunctionSig::packed(), + Some(schema) => { + let ret = rust_type_for_schema(&schema.args[0], type_map, self_type_key); + let args: Vec = schema.args[1..] + .iter() + .map(|arg| rust_type_for_schema(arg, type_map, self_type_key)) + .collect(); + FunctionSig::from_types(args, ret) + } + } +} + +fn build_method_sig( + schema: Option<&TypeSchema>, + type_map: &BTreeMap, + self_type_key: Option<&str>, + is_static: bool, +) -> FunctionSig { + if !is_static { + return FunctionSig::packed(); + } + build_function_sig(schema, type_map, self_type_key) +} + +fn rust_type_for_schema( + schema: &TypeSchema, + type_map: &BTreeMap, + _self_type_key: Option<&str>, +) -> RustType { + match schema.origin.as_str() { + "None" => RustType::supported("()"), + "bool" => RustType::supported("bool"), + "int" => RustType::supported("i64"), + "float" => RustType::supported("f64"), + "Device" => RustType::supported("tvm_ffi::DLDevice"), + "DataType" => RustType::supported("tvm_ffi::DLDataType"), + "ffi.String" | "std::string" | "const char*" | "ffi.SmallStr" => { + RustType::supported("tvm_ffi::String") + } + "ffi.Bytes" | "TVMFFIByteArray*" | "ffi.SmallBytes" => { + RustType::supported("tvm_ffi::Bytes") + } + "ffi.Function" => RustType::supported("tvm_ffi::Function"), + "ffi.Object" => RustType::supported("tvm_ffi::object::ObjectRef"), + "ffi.Tensor" | "DLTensor*" => RustType::supported("tvm_ffi::Tensor"), + "ffi.Shape" => RustType::supported("tvm_ffi::Shape"), + "ffi.Module" => RustType::supported("tvm_ffi::Module"), + "Optional" => match schema.args.as_slice() { + [inner] => { + let inner_ty = rust_type_for_schema(inner, type_map, _self_type_key); + if inner_ty.supported { + RustType::supported(&format!("Option<{}>", inner_ty.name)) + } else { + RustType::unsupported("tvm_ffi::Any") + } + } + _ => RustType::unsupported("tvm_ffi::Any"), + }, + "ffi.Array" => match schema.args.as_slice() { + [inner] => { + let inner_ty = rust_type_for_schema(inner, type_map, _self_type_key); + if inner_ty.supported { + RustType::supported(&format!("tvm_ffi::Array<{}>", inner_ty.name)) + } else { + RustType::unsupported("tvm_ffi::Any") + } + } + _ => RustType::unsupported("tvm_ffi::Any"), + }, + "ffi.Map" => match schema.args.as_slice() { + [key, value] => { + let key_ty = rust_type_for_schema(key, type_map, _self_type_key); + let value_ty = rust_type_for_schema(value, type_map, _self_type_key); + if key_ty.supported && value_ty.supported { + RustType::supported(&format!( + "tvm_ffi::Map<{}, {}>", + key_ty.name, value_ty.name + )) + } else { + RustType::unsupported("tvm_ffi::Any") + } + } + _ => RustType::unsupported("tvm_ffi::Any"), + }, + "Any" | "ffi.Any" => RustType::supported("tvm_ffi::AnyValue"), + "Union" | "Variant" | "tuple" | "list" | "dict" => RustType::unsupported("tvm_ffi::Any"), + other => match type_map.get(other) { + Some(path) => RustType::object_wrapper(path), + None => RustType::unsupported("tvm_ffi::Any"), + }, + } +} + +fn insert_function(root: &mut ModuleNode, mods: &[String], func: FunctionGen) { + let mut node = root; + for module in mods { + node = node + .children + .entry(module.clone()) + .or_insert_with(|| ModuleNode { + name: module.clone(), + ..ModuleNode::default() + }); + } + node.functions.push(func); +} + +fn insert_type(root: &mut ModuleNode, mods: &[String], ty: TypeGen) { + let mut node = root; + for module in mods { + node = node + .children + .entry(module.clone()) + .or_insert_with(|| ModuleNode { + name: module.clone(), + ..ModuleNode::default() + }); + } + node.types.push(ty); +} + +fn split_name(full_name: &str, prefix: &str) -> (Vec, String) { + let remainder = if prefix.is_empty() { + full_name + } else { + full_name.strip_prefix(prefix).unwrap_or(full_name) + }; + let parts: Vec<&str> = remainder.split('.').filter(|p| !p.is_empty()).collect(); + if parts.is_empty() { + return (Vec::new(), "ffi".to_string()); + } + if parts.len() == 1 { + return (Vec::new(), parts[0].to_string()); + } + let mut mods = Vec::new(); + for part in &parts[..parts.len() - 1] { + mods.push(sanitize_ident(part, IdentStyle::Module)); + } + (mods, parts[parts.len() - 1].to_string()) +} + +fn module_path(mods: &[String]) -> String { + if mods.is_empty() { + return String::new(); + } + mods.join("::") +} + +pub(crate) fn render_cargo_toml( + args: &Args, + _type_map: &BTreeMap, +) -> Result> { + let tvm_ffi_path = match &args.tvm_ffi_path { + Some(path) => path.clone(), + None => utils::default_tvm_ffi_path()?, + }; + let tvm_ffi_path = tvm_ffi_path.canonicalize().unwrap_or(tvm_ffi_path); + let tvm_ffi_path_str = tvm_ffi_path.to_string_lossy().to_string(); + + let mut package = Table::new(); + package.insert( + "name".to_string(), + toml::Value::String(args.init_crate.clone()), + ); + package.insert( + "version".to_string(), + toml::Value::String("0.1.0".to_string()), + ); + package.insert( + "edition".to_string(), + toml::Value::String("2024".to_string()), + ); + + let mut tvm_ffi = Table::new(); + tvm_ffi.insert("path".to_string(), toml::Value::String(tvm_ffi_path_str)); + + let mut dependencies = Table::new(); + dependencies.insert("tvm-ffi".to_string(), toml::Value::Table(tvm_ffi)); + + let mut doc = Table::new(); + doc.insert("package".to_string(), toml::Value::Table(package)); + doc.insert("dependencies".to_string(), toml::Value::Table(dependencies)); + + Ok(toml::to_string(&toml::Value::Table(doc))?) +} + +pub(crate) fn render_lib_rs(functions_root: &ModuleNode, types_root: &ModuleNode) -> String { + let mut out = String::new(); + out.push_str( + r#"#![allow( + clippy::needless_question_mark, + clippy::too_many_arguments, + clippy::enum_variant_names, + clippy::manual_div_ceil, + clippy::just_underscores_and_digits, + non_snake_case +)] + +pub mod _tvm_ffi_stubgen_detail { + pub mod functions; + pub mod types; +} + +"#, + ); + render_facade_module( + &mut out, + Some(functions_root), + Some(types_root), + &[], + 0, + true, + ); + out.push_str( + r#" +pub fn load_library(path: &str) -> tvm_ffi::Result { + tvm_ffi::Module::load_from_file(path) +} +"#, + ); + out +} + +pub(crate) fn render_build_rs() -> String { + let mut out = String::new(); + out.push_str( + r#"use std::env; +use std::process::Command; + +fn update_ld_library_path(lib_dir: &str) { + let os_env_var = match env::var("CARGO_CFG_TARGET_OS").as_deref() { + Ok("windows") => "PATH", + Ok("macos") => "DYLD_LIBRARY_PATH", + Ok("linux") => "LD_LIBRARY_PATH", + _ => "", + }; + if os_env_var.is_empty() { + return; + } + let current_val = env::var(os_env_var).unwrap_or_else(|_| String::new()); + let separator = if os_env_var == "PATH" { ";" } else { ":" }; + let new_ld_path = if current_val.is_empty() { + lib_dir.to_string() + } else { + format!("{}{}{}", current_val, separator, lib_dir) + }; + println!("cargo:rustc-env={}={}", os_env_var, new_ld_path); +} + +fn main() { + let output = Command::new("tvm-ffi-config") + .arg("--libdir") + .output() + .expect("Failed to run tvm-ffi-config"); + if !output.status.success() { + panic!("tvm-ffi-config --libdir failed"); + } + let lib_dir = String::from_utf8(output.stdout) + .expect("Invalid UTF-8 output from tvm-ffi-config") + .trim() + .to_string(); + if lib_dir.is_empty() { + panic!("tvm-ffi-config returned empty library path"); + } + println!("cargo:rustc-link-search=native={}", lib_dir); + println!("cargo:rustc-link-lib=dylib=tvm_ffi"); + update_ld_library_path(&lib_dir); +} +"#, + ); + out +} + +pub(crate) fn render_functions_rs(root: &ModuleNode) -> String { + let mut out = String::new(); + out.push_str( + r#"#![allow(unused_imports)] +#![allow(non_snake_case, nonstandard_style)] + +use std::sync::LazyLock; +use tvm_ffi::{Any, AnyView, Function, Result}; + +"#, + ); + render_function_module(&mut out, root, 0); + out +} + +pub(crate) fn render_types_rs(root: &ModuleNode, type_map: &BTreeMap) -> String { + let mut out = String::new(); + out.push_str( + r#"#![allow(unused_imports)] +#![allow(non_snake_case, nonstandard_style)] + +use std::sync::LazyLock; +use tvm_ffi::{Any, AnyView, ObjectArc, Result}; + +"#, + ); + render_type_module(&mut out, root, 0, type_map); + out +} + +fn render_facade_module( + out: &mut String, + functions: Option<&ModuleNode>, + types: Option<&ModuleNode>, + path: &[String], + indent: usize, + is_root: bool, +) { + // Check if this module has any actual content + let has_functions = functions.is_some_and(|node| !node.functions.is_empty()); + let has_types = + types.is_some_and(|node| node.types.iter().any(|ty| !is_builtin_type(&ty.type_key))); + + let mut child_names = std::collections::BTreeSet::new(); + if let Some(node) = functions { + child_names.extend(node.children.keys().cloned()); + } + if let Some(node) = types { + child_names.extend(node.children.keys().cloned()); + } + + // Skip rendering if the module is empty and has no children + if !is_root && !has_functions && !has_types && child_names.is_empty() { + return; + } + + let indent_str = " ".repeat(indent); + if !is_root { + let name = path.last().expect("module path missing"); + writeln!(out, "{}pub mod {} {{", indent_str, name).ok(); + } + + let current_indent = if is_root { + indent_str.clone() + } else { + " ".repeat(indent + 4) + }; + let module_path = if path.is_empty() { + String::new() + } else { + format!("::{}", path.join("::")) + }; + + if let Some(node) = functions { + for func in &node.functions { + writeln!( + out, + "{}pub use crate::_tvm_ffi_stubgen_detail::functions{}::{};", + current_indent, module_path, func.rust_name + ) + .ok(); + } + } + if let Some(node) = types { + for ty in &node.types { + // Skip built-in types that are not generated + if is_builtin_type(&ty.type_key) { + continue; + } + writeln!( + out, + "{}pub use crate::_tvm_ffi_stubgen_detail::types{}::{};", + current_indent, module_path, ty.rust_name + ) + .ok(); + } + } + + for child in child_names { + let mut child_path = path.to_vec(); + child_path.push(child.clone()); + let func_child = functions.and_then(|node| node.children.get(&child)); + let type_child = types.and_then(|node| node.children.get(&child)); + render_facade_module(out, func_child, type_child, &child_path, indent + 4, false); + } + + if !is_root { + writeln!(out, "{}}}", indent_str).ok(); + } +} + +fn render_function_module(out: &mut String, node: &ModuleNode, indent: usize) { + let indent_str = " ".repeat(indent); + if indent > 0 { + writeln!(out, "{}use std::sync::LazyLock;", indent_str).ok(); + writeln!( + out, + "{}use tvm_ffi::{{Any, AnyView, Function, Result}};", + indent_str + ) + .ok(); + writeln!(out).ok(); + } + for func in &node.functions { + render_function(out, func, indent); + } + for child in node.children.values() { + writeln!(out, "{}pub mod {} {{", indent_str, child.name).ok(); + render_function_module(out, child, indent + 4); + writeln!(out, "{}}}", indent_str).ok(); + } +} + +fn render_type_module( + out: &mut String, + node: &ModuleNode, + indent: usize, + type_map: &BTreeMap, +) { + let indent_str = " ".repeat(indent); + if indent > 0 { + writeln!(out, "{}use std::sync::LazyLock;", indent_str).ok(); + writeln!( + out, + "{}use tvm_ffi::{{Any, AnyView, ObjectArc, Result}};", + indent_str + ) + .ok(); + writeln!(out).ok(); + } + for ty in &node.types { + render_type(out, ty, indent, type_map); + } + for child in node.children.values() { + writeln!(out, "{}pub mod {} {{", indent_str, child.name).ok(); + render_type_module(out, child, indent + 4, type_map); + writeln!(out, "{}}}", indent_str).ok(); + } +} + +fn render_function(out: &mut String, func: &FunctionGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("FUNC", &func.full_name); + writeln!( + out, + "{}static {}: LazyLock = LazyLock::new(|| Function::get_global(\"{}\").expect(\"missing global function\"));", + indent_str, static_name, func.full_name + ) + .ok(); + if func.sig.packed { + writeln!( + out, + "{}pub fn {}(args: &[Any]) -> Result {{", + indent_str, func.rust_name + ) + .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); + writeln!( + out, + "{} let views: Vec> = args.iter().map(AnyView::from).collect();", + indent_str + ) + .ok(); + writeln!(out, "{} func.call_packed(&views)", indent_str).ok(); + writeln!(out, "{}}}", indent_str).ok(); + writeln!(out).ok(); + return; + } + let args = render_args(&func.sig.args); + writeln!( + out, + "{}pub fn {}({}) -> Result<{}> {{", + indent_str, func.rust_name, args, func.sig.ret.name + ) + .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); + writeln!( + out, + "{} let typed = tvm_ffi::into_typed_fn!(func.clone(), Fn({}) -> Result<{}>);", + indent_str, + render_type_list(&func.sig.args), + func.sig.ret.typed_ret_name() + ) + .ok(); + let call_expr = format!("typed({})", render_call_args_typed(&func.sig.args)); + writeln!( + out, + "{} {}", + indent_str, + func.sig + .ret + .wrap_typed_return(&call_expr, func.sig.ret.typed_ret_name()) + ) + .ok(); + writeln!(out, "{}}}", indent_str).ok(); + writeln!(out).ok(); +} + +fn render_type(out: &mut String, ty: &TypeGen, indent: usize, type_map: &BTreeMap) { + // Filter out built-in types that are already provided by tvm-ffi + if is_builtin_type(&ty.type_key) { + return; + } + + let _indent_str = " ".repeat(indent); + if let Some(ref info) = ty.repr_c_info { + render_repr_c_type(out, ty, info, indent, type_map); + return; + } + render_fallback_type(out, ty, indent); +} + +fn is_builtin_type(type_key: &str) -> bool { + // Filter ffi.* primitive types and aliases that are provided by tvm-ffi + matches!( + type_key, + "ffi.Object" + | "ffi.String" + | "ffi.Function" + | "ffi.Module" + | "ffi.Tensor" + | "ffi.Shape" + | "ffi.Array" + | "ffi.Map" + | "ffi.Bytes" + | "ffi.SmallStr" + | "ffi.SmallBytes" + | "DLTensor*" + | "DataType" + | "Device" + | "bool" + | "int" + | "float" + | "None" + ) +} + +fn render_repr_c_type( + out: &mut String, + ty: &TypeGen, + info: &repr_c::ReprCInfo, + indent: usize, + _type_map: &BTreeMap, +) { + let indent_str = " ".repeat(indent); + let obj_name = format!("{}Obj", ty.rust_name); + + // Determine parent type for *Obj struct using absolute path into + // the types module so cross-module references resolve correctly. + let parent_ty = match &info.parent_type_key { + None => "tvm_ffi::object::Object".to_string(), + Some(parent_key) if parent_key == "ffi.Object" => "tvm_ffi::object::Object".to_string(), + Some(parent_key) => { + let parent_rust = _type_map + .get(parent_key) + .cloned() + .unwrap_or_else(|| format!("{}Obj", sanitize_ident(parent_key, IdentStyle::Type))); + // Convert crate path to absolute path inside types module and append Obj + let types_path = + parent_rust.replacen("crate::", "crate::_tvm_ffi_stubgen_detail::types::", 1); + format!("{}Obj", types_path) + } + }; + + // Generate *Obj struct with #[repr(C)] + writeln!(out, "{}#[repr(C)]", indent_str).ok(); + writeln!(out, "{}#[derive(tvm_ffi::derive::Object)]", indent_str).ok(); + writeln!(out, "{}#[type_key = \"{}\"]", indent_str, ty.type_key).ok(); + writeln!(out, "{}pub struct {} {{", indent_str, obj_name).ok(); + writeln!( + out, + "{} __tvm_ffi_object_parent: {},", + indent_str, parent_ty + ) + .ok(); + for entry in &info.layout { + match entry { + repr_c::LayoutEntry::Field(f) => { + writeln!(out, "{} {}: {},", indent_str, f.rust_name, f.rust_type).ok(); + } + repr_c::LayoutEntry::Gap { name, size } => { + writeln!(out, "{} {}: [u8; {}],", indent_str, name, size).ok(); + } + } + } + writeln!(out, "{}}}\n", indent_str).ok(); + + // Generate *Ref wrapper with #[repr(C)] + writeln!(out, "{}#[repr(C)]", indent_str).ok(); + writeln!( + out, + "{}#[derive(tvm_ffi::derive::ObjectRef, Clone)]", + indent_str + ) + .ok(); + writeln!(out, "{}pub struct {} {{", indent_str, ty.rust_name).ok(); + writeln!( + out, + "{} data: tvm_ffi::object::ObjectArc<{}>,", + indent_str, obj_name + ) + .ok(); + writeln!(out, "{}}}\n", indent_str).ok(); + + // Generate impl_object_hierarchy! macro call + if !ty.ancestor_chain.is_empty() { + write!( + out, + "{}tvm_ffi::impl_object_hierarchy!({}:", + indent_str, ty.rust_name + ) + .ok(); + for (i, ancestor) in ty.ancestor_chain.iter().enumerate() { + if i == 0 { + write!(out, " {}", ancestor).ok(); + } else { + write!(out, ", {}", ancestor).ok(); + } + } + writeln!(out, ");").ok(); + writeln!(out).ok(); + } + + // Generate getter methods for typed fields + writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); + for f in info.fields() { + let method_name = format!("get_{}", f.rust_name); + let access_expr = if f.is_pod { + format!("self.data.{}", f.rust_name) + } else { + format!("self.data.{}.clone()", f.rust_name) + }; + writeln!( + out, + "{} pub fn {}(&self) -> {} {{", + indent_str, method_name, f.rust_type + ) + .ok(); + writeln!(out, "{} {}", indent_str, access_expr).ok(); + writeln!(out, "{} }}", indent_str).ok(); + } + writeln!(out, "{}}}\n", indent_str).ok(); + + // Generate FieldGetter statics for non-layout fields. + // Each entry is a registered ObjectDef field that couldn't become a direct struct + // member (parent-range offset or unmappable schema). + for nlf in &info.non_layout_fields { + let static_name = static_ident("FIELD", &format!("{}::{}", ty.type_key, nlf.name)); + // Use the concrete mapped type when available; fall back to tvm_ffi::Any so the + // static can still be constructed (get_any() has no type constraints). + let static_ty = nlf.rust_type.as_deref().unwrap_or("tvm_ffi::Any"); + writeln!( + out, + "{}static {}: std::sync::LazyLock> = std::sync::LazyLock::new(|| {{", + indent_str, static_name, static_ty + ) + .ok(); + writeln!( + out, + "{} tvm_ffi::object_wrapper::FieldGetter::new(\"{}\", \"{}\")", + indent_str, ty.type_key, nlf.name + ) + .ok(); + writeln!( + out, + "{} .expect(\"non-layout field {} must be registered in TVM reflection\")", + indent_str, nlf.name + ) + .ok(); + writeln!(out, "{}}});", indent_str).ok(); + } + + // Generate method statics and impls + for method in &ty.methods { + render_method_static(out, ty, method, indent); + } + writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); + for method in &ty.methods { + render_method(out, ty, method, indent + 4); + } + // Generate FieldGetter accessor methods for non-layout fields. + // Interface matches direct struct-field getters: `get_*` naming, infallible return. + // Typed fields (mappable schema) return the concrete type via get(). + // Untyped fields (unmappable schema) return tvm_ffi::Any via get_any(). + for nlf in &info.non_layout_fields { + let static_name = static_ident("FIELD", &format!("{}::{}", ty.type_key, nlf.name)); + let method_name = format!("get_{}", nlf.rust_name); + let (return_type, call_expr) = if let Some(rt) = &nlf.rust_type { + ( + rt.as_str(), + format!( + "{}.get(&__obj).expect(\"non-layout field {} should be accessible\")", + static_name, nlf.name + ), + ) + } else { + ( + "tvm_ffi::Any", + format!( + "{}.get_any(&__obj).expect(\"non-layout field {} should be accessible\")", + static_name, nlf.name + ), + ) + }; + writeln!( + out, + "{} pub fn {}(&self) -> {} {{", + indent_str, method_name, return_type + ) + .ok(); + writeln!( + out, + "{} let __obj: tvm_ffi::object::ObjectRef = self.clone().into();", + indent_str + ) + .ok(); + writeln!(out, "{} {}", indent_str, call_expr).ok(); + writeln!(out, "{} }}", indent_str).ok(); + } + writeln!(out, "{}}}\n", indent_str).ok(); +} + +fn render_fallback_type(out: &mut String, ty: &TypeGen, indent: usize) { + let indent_str = " ".repeat(indent); + writeln!( + out, + "{}tvm_ffi::define_object_wrapper!({}, \"{}\");\n", + indent_str, ty.rust_name, ty.type_key + ) + .ok(); + + for field in &ty.fields { + render_field_static(out, ty, field, indent); + } + for method in &ty.methods { + render_method_static(out, ty, method, indent); + } + + writeln!(out, "{}impl {} {{", indent_str, ty.rust_name).ok(); + for field in &ty.fields { + render_field(out, ty, field, indent + 4); + } + for method in &ty.methods { + render_method(out, ty, method, indent + 4); + } + writeln!(out, "{}}}\n", indent_str).ok(); +} + +fn render_field_static(out: &mut String, ty: &TypeGen, field: &FieldGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("FIELD", &format!("{}::{}", ty.type_key, field.name)); + writeln!( + out, + "{}static {}: LazyLock> = LazyLock::new(|| tvm_ffi::object_wrapper::FieldGetter::new(\"{}\", \"{}\").expect(\"missing field\"));", + indent_str, static_name, field.ty.name, ty.type_key, field.name + ) + .ok(); +} + +fn render_field(out: &mut String, ty: &TypeGen, field: &FieldGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("FIELD", &format!("{}::{}", ty.type_key, field.name)); + writeln!( + out, + "{}pub fn {}(&self) -> Result<{}> {{", + indent_str, field.rust_name, field.ty.name + ) + .ok(); + if field.ty.name == "tvm_ffi::Any" { + writeln!( + out, + "{} {}.get_any(self.as_object_ref())", + indent_str, static_name + ) + .ok(); + } else { + writeln!( + out, + "{} {}.get(self.as_object_ref())", + indent_str, static_name + ) + .ok(); + } + writeln!(out, "{}}}", indent_str).ok(); +} + +fn render_method_static(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("METHOD", &format!("{}::{}", ty.type_key, method.rust_name)); + writeln!( + out, + "{}static {}: LazyLock = LazyLock::new(|| tvm_ffi::object_wrapper::resolve_type_method(\"{}\", \"{}\").expect(\"missing type method\"));", + indent_str, static_name, ty.type_key, method.source_name + ) + .ok(); +} + +fn render_method(out: &mut String, ty: &TypeGen, method: &MethodGen, indent: usize) { + let indent_str = " ".repeat(indent); + let static_name = static_ident("METHOD", &format!("{}::{}", ty.type_key, method.rust_name)); + let self_prefix = if method.is_static { "" } else { "&self" }; + if method.sig.packed { + if method.is_static { + writeln!( + out, + "{}pub fn {}(args: &[Any]) -> Result {{", + indent_str, method.rust_name + ) + .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); + writeln!( + out, + "{} let views: Vec> = args.iter().map(AnyView::from).collect();", + indent_str + ) + .ok(); + writeln!(out, "{} func.call_packed(&views)", indent_str).ok(); + writeln!(out, "{}}}", indent_str).ok(); + return; + } + writeln!( + out, + "{}pub fn {}(&self, args: &[Any]) -> Result {{", + indent_str, method.rust_name + ) + .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); + writeln!( + out, + "{} let mut views: Vec> = Vec::with_capacity(args.len() + 1);", + indent_str + ) + .ok(); + // For repr(C) types, use deref coercion to upcast to ObjectRef + // For ObjectWrapper types, use the as_object_ref() method + if ty.repr_c_info.is_some() { + writeln!( + out, + "{} views.push(AnyView::from(self as &tvm_ffi::object::ObjectRef));", + indent_str + ) + .ok(); + } else { + writeln!( + out, + "{} views.push(AnyView::from(self.as_object_ref()));", + indent_str + ) + .ok(); + } + writeln!( + out, + "{} views.extend(args.iter().map(AnyView::from));", + indent_str + ) + .ok(); + writeln!(out, "{} func.call_packed(&views)", indent_str).ok(); + writeln!(out, "{}}}", indent_str).ok(); + return; + } + + let args = render_args(&method.sig.args); + let signature = if method.is_static { + format!("{}({})", method.rust_name, args) + } else if args.is_empty() { + format!("{}({})", method.rust_name, self_prefix) + } else { + format!("{}({}, {})", method.rust_name, self_prefix, args) + }; + writeln!( + out, + "{}pub fn {} -> Result<{}> {{", + indent_str, signature, method.sig.ret.name + ) + .ok(); + writeln!(out, "{} let func = &*{};", indent_str, static_name).ok(); + let type_list = if method.is_static { + render_type_list(&method.sig.args) + } else { + let mut types = vec!["tvm_ffi::object::ObjectRef".to_string()]; + types.extend( + method + .sig + .args + .iter() + .map(|arg| arg.typed_arg_name().to_string()), + ); + types.join(", ") + }; + writeln!( + out, + "{} let typed = tvm_ffi::into_typed_fn!(func.clone(), Fn({}) -> Result<{}>);", + indent_str, + type_list, + method.sig.ret.typed_ret_name() + ) + .ok(); + let call_expr = format!("typed({})", render_method_call_args(method)); + writeln!( + out, + "{} {}", + indent_str, + method + .sig + .ret + .wrap_typed_return(&call_expr, method.sig.ret.typed_ret_name()) + ) + .ok(); + writeln!(out, "{}}}", indent_str).ok(); +} + +fn render_args(args: &[RustType]) -> String { + let mut out = Vec::new(); + for (i, arg) in args.iter().enumerate() { + out.push(format!("_{}: {}", i, arg.name)); + } + out.join(", ") +} + +fn render_type_list(args: &[RustType]) -> String { + args.iter() + .map(|arg| arg.typed_arg_name().to_string()) + .collect::>() + .join(", ") +} + +fn render_call_args_typed(args: &[RustType]) -> String { + let mut out = Vec::new(); + for (i, arg) in args.iter().enumerate() { + out.push(arg.call_expr(&format!("_{}", i))); + } + out.join(", ") +} + +fn render_method_call_args(method: &MethodGen) -> String { + if method.is_static { + return render_call_args_typed(&method.sig.args); + } + let mut out = Vec::new(); + let self_type = RustType::object_wrapper("Self"); + out.push(self_type.call_expr("self")); + for (i, arg) in method.sig.args.iter().enumerate() { + out.push(arg.call_expr(&format!("_{}", i))); + } + out.join(", ") +} + +fn map_method_name(name: &str) -> String { + sanitize_ident(name, IdentStyle::Function) +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum IdentStyle { + Function, + Module, + Type, +} + +fn sanitize_ident(name: &str, style: IdentStyle) -> String { + let mut out = String::new(); + let mut prev_underscore = false; + for (i, ch) in name.chars().enumerate() { + let mut c = ch; + if style == IdentStyle::Module && ch.is_ascii_uppercase() { + if i > 0 && !prev_underscore { + out.push('_'); + } + c = ch.to_ascii_lowercase(); + } + if c.is_ascii_alphanumeric() || c == '_' { + out.push(c); + prev_underscore = c == '_'; + } else { + out.push('_'); + prev_underscore = true; + } + } + if out.is_empty() { + out.push('_'); + } + if out.chars().next().unwrap().is_ascii_digit() { + out.insert(0, '_'); + } + const KEYWORDS: &[&str] = &[ + "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn", + "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", + "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", + "use", "where", "while", "async", "await", "dyn", + ]; + if KEYWORDS.contains(&out.as_str()) { + out.push('_'); + } + match style { + IdentStyle::Type => to_pascal_case(&out), + _ => out, + } +} + +fn to_pascal_case(name: &str) -> String { + let mut out = String::new(); + let mut uppercase = true; + for ch in name.chars() { + if ch == '_' { + uppercase = true; + continue; + } + if uppercase { + out.extend(ch.to_uppercase()); + uppercase = false; + } else { + out.push(ch); + } + } + if out.is_empty() { + "Type".to_string() + } else { + out + } +} + +fn static_ident(prefix: &str, full_name: &str) -> String { + let mut out = String::new(); + out.push_str(prefix); + out.push('_'); + for ch in full_name.chars() { + if ch.is_ascii_alphanumeric() { + out.push(ch.to_ascii_uppercase()); + } else { + out.push('_'); + } + } + if out.chars().next().unwrap().is_ascii_digit() { + out.insert(0, '_'); + } + out +} diff --git a/rust/tvm-ffi-stubgen/src/lib.rs b/rust/tvm-ffi-stubgen/src/lib.rs new file mode 100644 index 00000000..445c45a3 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/lib.rs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod cli; +mod ffi; +mod generate; +mod model; +mod repr_c; +mod schema; +mod utils; + +use crate::schema::{collect_type_keys, extract_type_schema, parse_type_schema}; +pub use cli::Args; +use std::collections::{BTreeSet, HashSet}; +use std::process::Command; + +fn format_generated_crate(out_dir: &std::path::Path) -> Result<(), Box> { + let manifest_path = out_dir.join("Cargo.toml"); + let output = Command::new("cargo") + .arg("fmt") + .arg("--manifest-path") + .arg(&manifest_path) + .current_dir(out_dir) + .output()?; + if output.status.success() { + return Ok(()); + } + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + Err(format!( + "cargo fmt failed for generated crate {}.\nstdout:\n{}\nstderr:\n{}", + manifest_path.display(), + stdout.trim(), + stderr.trim() + ) + .into()) +} + +pub fn run(args: Args) -> Result<(), Box> { + if args.init_prefix.is_empty() { + return Err("--init-prefix is required".into()); + } + if args.dlls.is_empty() { + return Err("--dlls is required".into()); + } + utils::ensure_out_dir(&args.out_dir, args.overwrite)?; + + let prefixes: Vec = args + .init_prefix + .iter() + .map(|p| utils::normalize_prefix(p)) + .collect(); + // Single prefix: strip it so items land at crate root (backward compat). + // Multiple prefixes: don't strip; each prefix becomes a top-level module. + let effective_prefix = if prefixes.len() == 1 { + prefixes[0].clone() + } else { + String::new() + }; + + let _loaded_libs = ffi::load_dlls(&args.dlls)?; + + let global_funcs = ffi::list_global_function_names()?; + let filtered_funcs: Vec = global_funcs + .into_iter() + .filter(|name| prefixes.iter().any(|p| name.starts_with(p))) + .collect(); + + let type_keys = ffi::list_registered_type_keys()?; + let type_key_set: HashSet = type_keys.iter().cloned().collect(); + let mut filtered_types: Vec = type_keys + .iter() + .filter(|name| prefixes.iter().any(|p| name.starts_with(p))) + .cloned() + .collect(); + + let mut referenced_types: BTreeSet = BTreeSet::new(); + for full_name in &filtered_funcs { + let metadata = ffi::get_global_func_metadata(full_name)?; + let schema = metadata + .and_then(|meta| extract_type_schema(&meta)) + .and_then(|schema| parse_type_schema(&schema)); + if let Some(schema) = schema { + collect_type_keys(&schema, &type_key_set, &mut referenced_types); + } + } + + for ty in referenced_types { + if !filtered_types.contains(&ty) { + filtered_types.push(ty); + } + } + + let type_map = generate::build_type_map(&filtered_types, &effective_prefix); + let functions = + generate::build_function_entries(&filtered_funcs, &type_map, &effective_prefix)?; + let types = generate::build_type_entries(&filtered_types, &type_map, &effective_prefix)?; + + let functions_root = generate::build_function_modules(functions, &effective_prefix); + let types_root = generate::build_type_modules(types, &effective_prefix); + + let cargo_toml = generate::render_cargo_toml(&args, &type_map)?; + let lib_rs = generate::render_lib_rs(&functions_root, &types_root); + let functions_rs = generate::render_functions_rs(&functions_root); + let types_rs = generate::render_types_rs(&types_root, &type_map); + let build_rs = generate::render_build_rs(); + + let src_dir = args.out_dir.join("src"); + let detail_dir = src_dir.join("_tvm_ffi_stubgen_detail"); + std::fs::create_dir_all(&detail_dir)?; + std::fs::write(args.out_dir.join("Cargo.toml"), cargo_toml)?; + std::fs::write(args.out_dir.join("build.rs"), build_rs)?; + std::fs::write(src_dir.join("lib.rs"), lib_rs)?; + std::fs::write(detail_dir.join("functions.rs"), functions_rs)?; + std::fs::write(detail_dir.join("types.rs"), types_rs)?; + if !args.no_format { + format_generated_crate(&args.out_dir)?; + } + + Ok(()) +} diff --git a/rust/tvm-ffi-stubgen/src/main.rs b/rust/tvm-ffi-stubgen/src/main.rs new file mode 100644 index 00000000..6dfa1d09 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/main.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use clap::Parser; +use tvm_ffi_stubgen::{Args, run}; + +fn main() -> Result<(), Box> { + env_logger::init(); + let args = Args::parse(); + run(args) +} diff --git a/rust/tvm-ffi-stubgen/src/model.rs b/rust/tvm-ffi-stubgen/src/model.rs new file mode 100644 index 00000000..5e054bd6 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/model.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::BTreeMap; + +#[derive(Debug, Clone)] +pub(crate) struct RustType { + pub(crate) name: String, + pub(crate) supported: bool, + pub(crate) kind: RustTypeKind, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum RustTypeKind { + Plain, + ObjectWrapper, +} + +#[derive(Debug, Clone)] +pub(crate) struct FunctionSig { + pub(crate) args: Vec, + pub(crate) ret: RustType, + pub(crate) packed: bool, +} + +#[derive(Debug, Clone)] +pub(crate) struct FunctionGen { + pub(crate) full_name: String, + pub(crate) rust_name: String, + pub(crate) sig: FunctionSig, +} + +#[derive(Debug, Clone)] +pub(crate) struct MethodGen { + pub(crate) source_name: String, + pub(crate) rust_name: String, + pub(crate) sig: FunctionSig, + pub(crate) is_static: bool, +} + +#[derive(Debug, Clone)] +pub(crate) struct FieldGen { + pub(crate) name: String, + pub(crate) rust_name: String, + pub(crate) ty: RustType, +} + +/// Spec for a single get_* method on a repr(C) Ref type. +#[derive(Debug, Clone)] +pub(crate) struct GetterSpec { + /// Method name, e.g. "get_first" + pub(crate) method_name: String, + /// Expression to produce the value, e.g. "self.data.first.clone()" + pub(crate) access_expr: String, + /// Return type, e.g. "Shape" or "i64" + pub(crate) ret_type: String, +} + +#[derive(Debug, Clone)] +pub(crate) struct TypeGen { + pub(crate) type_key: String, + pub(crate) rust_name: String, + pub(crate) methods: Vec, + pub(crate) fields: Vec, + /// Depth in inheritance hierarchy (0 = Object, 1 = direct subclass of Object, ...). + pub(crate) type_depth: i32, + /// If Some, type passes check_repr_c and we generate repr(C) *Obj + *Ref. + pub(crate) repr_c_info: Option, + /// Getter specs for repr_c types (get_* methods). Empty for non-repr_c. + pub(crate) getter_specs: Vec, + /// Ancestor chain for repr_c types: [DirectParent, Grandparent, ..., ObjectRef]. + /// Empty for non-repr_c or types without proper hierarchy info. + pub(crate) ancestor_chain: Vec, +} + +#[derive(Debug, Default)] +pub(crate) struct ModuleNode { + pub(crate) name: String, + pub(crate) functions: Vec, + pub(crate) types: Vec, + pub(crate) children: BTreeMap, +} + +impl FunctionSig { + pub(crate) fn packed() -> Self { + Self { + args: Vec::new(), + ret: RustType::unsupported("tvm_ffi::Any"), + packed: true, + } + } + + pub(crate) fn from_types(args: Vec, ret: RustType) -> Self { + let typed = args.len() <= 12 && args.iter().all(|arg| arg.supported) && ret.supported; + Self { + args, + ret, + packed: !typed, + } + } +} + +impl RustType { + pub(crate) fn supported(name: &str) -> Self { + Self { + name: name.to_string(), + supported: true, + kind: RustTypeKind::Plain, + } + } + + pub(crate) fn unsupported(name: &str) -> Self { + Self { + name: name.to_string(), + supported: false, + kind: RustTypeKind::Plain, + } + } + + pub(crate) fn object_wrapper(name: &str) -> Self { + Self { + name: name.to_string(), + supported: true, + kind: RustTypeKind::ObjectWrapper, + } + } + + pub(crate) fn typed_arg_name(&self) -> &str { + match self.kind { + RustTypeKind::Plain => &self.name, + RustTypeKind::ObjectWrapper => "tvm_ffi::object::ObjectRef", + } + } + + pub(crate) fn typed_ret_name(&self) -> &str { + match self.kind { + RustTypeKind::Plain => &self.name, + RustTypeKind::ObjectWrapper => &self.name, + } + } + + pub(crate) fn call_expr(&self, arg_name: &str) -> String { + match self.kind { + RustTypeKind::Plain => arg_name.to_string(), + RustTypeKind::ObjectWrapper => { + // Use Into trait for upcast + format!("{}.into()", arg_name) + } + } + } + + pub(crate) fn wrap_typed_return(&self, expr: &str, typed_ret_name: &str) -> String { + match self.kind { + RustTypeKind::Plain => expr.to_string(), + RustTypeKind::ObjectWrapper => { + if typed_ret_name == "tvm_ffi::object::ObjectRef" { + if self.name == "Self" { + format!("{}.map(Self::from)", expr) + } else { + format!("{}.map({}::from)", expr, self.name) + } + } else { + expr.to_string() + } + } + } + } +} diff --git a/rust/tvm-ffi-stubgen/src/repr_c.rs b/rust/tvm-ffi-stubgen/src/repr_c.rs new file mode 100644 index 00000000..4072bc3b --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/repr_c.rs @@ -0,0 +1,425 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Validation that a type has a compact C-compatible layout (check_repr_c) +//! and extraction of field layout for repr(C) code generation. +//! +//! The strategy is gap-filling: given the parent struct size and the registered +//! field offsets/sizes, any byte range not covered by a known field is emitted +//! as a `[u8; N]` padding member. This handles C++ tail padding, vtable +//! pointers, and unregistered fields uniformly without requiring alignment +//! inference. + +use crate::ffi; +use crate::schema::{TypeSchema, extract_type_schema, parse_type_schema}; +use log::{debug, trace}; +use std::collections::BTreeMap; + +/// Result of check_repr_c: type passes and we have full layout for codegen. +#[derive(Debug, Clone)] +pub(crate) struct ReprCInfo { + /// Type key of the immediate parent (Object or a subclass). None for root types. + pub(crate) parent_type_key: Option, + /// Ordered layout entries (fields and gaps) covering [parent_total_size .. total_size). + pub(crate) layout: Vec, + /// Fields registered in this type's ObjectDef that are NOT part of the repr(C) struct + /// layout. Two causes: (1) offset < parent_total_size — the field occupies a slot + /// within the parent's address range; (2) schema not mappable — the field's type + /// cannot be expressed as a repr(C) Rust type. All such fields can still be read at + /// runtime via FieldGetter. + pub(crate) non_layout_fields: Vec, +} + +/// A registered field that does not appear in the repr(C) struct layout. +#[derive(Debug, Clone)] +pub(crate) struct NonLayoutField { + /// Original C++ field name (used as the FieldGetter key). + pub(crate) name: String, + /// Sanitized Rust identifier (used as the getter method name suffix). + pub(crate) rust_name: String, + /// Mapped Rust type string, or None if the schema could not be mapped. + /// When None, the generated getter returns `tvm_ffi::Any` via get_any(). + pub(crate) rust_type: Option, +} + +/// A single entry in the repr(C) struct body after the parent. +#[derive(Debug, Clone)] +pub(crate) enum LayoutEntry { + /// A known, typed field. + Field(ReprCField), + /// An opaque gap (padding, vtable pointer, or unregistered field). + Gap { name: String, size: i64 }, +} + +#[derive(Debug, Clone)] +pub(crate) struct ReprCField { + pub(crate) rust_name: String, + pub(crate) offset: i64, + pub(crate) size: i64, + /// Rust type name for the field (e.g. "i64", "Shape"). + pub(crate) rust_type: String, + /// True if Copy type (getter returns value); false if Ref (getter returns clone). + pub(crate) is_pod: bool, +} + +impl ReprCInfo { + /// Iterate only the typed fields (skipping gaps). + pub(crate) fn fields(&self) -> impl Iterator { + self.layout.iter().filter_map(|e| match e { + LayoutEntry::Field(f) => Some(f), + LayoutEntry::Gap { .. } => None, + }) + } +} + +/// Returns ReprCInfo if the type can be laid out as repr(C); None otherwise. +/// +/// Failure reasons (all logged at DEBUG level): +/// - No type info registered at all +/// - Metadata missing or total_size unknown +/// - Parent type not in type_map or parent itself fails +/// - A field's type schema cannot be mapped to a Rust type +pub(crate) fn check_repr_c( + type_key: &str, + type_map: &BTreeMap, +) -> Option { + let info = match ffi::get_type_info(type_key) { + Some(i) => i, + None => { + debug!("{}: no type info registered", type_key); + return None; + } + }; + let total_size = match total_size_from_info(info) { + Some(s) if s > 0 => s as i64, + _ => { + debug!("{}: metadata missing or total_size <= 0", type_key); + return None; + } + }; + trace!( + "{}: total_size={}, type_depth={}, num_fields={}, num_methods={}", + type_key, total_size, info.type_depth, info.num_fields, info.num_methods + ); + + // Resolve parent. + // If the direct parent is in type_map and passes check_repr_c, we use it as + // the typed parent field. Otherwise we fall back to ffi.Object as the parent + // and let gap-filling cover the bytes between Object and our first field. + let obj_size = { + let oi = ffi::get_type_info("ffi.Object")?; + total_size_from_info(oi)? as i64 + }; + let (parent_type_key, parent_total_size) = + if info.type_depth > 0 && !info.type_acenstors.is_null() { + let ancestor_ptr = unsafe { *info.type_acenstors.add((info.type_depth - 1) as usize) }; + let direct_parent_key = if !ancestor_ptr.is_null() { + let pi = unsafe { &*ancestor_ptr }; + ffi::byte_array_to_string_opt(&pi.type_key) + } else { + None + }; + match direct_parent_key { + Some(ref key) if key == "ffi.Object" => (None, obj_size), + Some(ref key) + if type_map.contains_key(key) && check_repr_c(key, type_map).is_some() => + { + let pi = ffi::get_type_info(key)?; + let ps = total_size_from_info(pi)? as i64; + trace!("{}: parent='{}' (typed, size={})", type_key, key, ps); + (Some(key.clone()), ps) + } + Some(ref key) => { + // Parent exists but not mappable — use Object as parent, gap covers the rest. + trace!( + "{}: parent='{}' not mappable, falling back to Object", + type_key, key + ); + (None, obj_size) + } + None => (None, obj_size), + } + } else { + (None, obj_size) + }; + trace!( + "{}: parent={:?}, parent_total_size={}", + type_key, parent_type_key, parent_total_size + ); + + // Collect and sort fields that belong to this type (offset >= parent_total_size). + // Any registered field that cannot become a direct struct member is tracked in + // non_layout_fields so it can be exposed via a FieldGetter accessor. + let mut typed_fields: Vec = Vec::new(); + let mut non_layout_fields: Vec = Vec::new(); + if info.num_fields > 0 && !info.fields.is_null() { + let field_slice = + unsafe { std::slice::from_raw_parts(info.fields, info.num_fields as usize) }; + for field in field_slice { + let name = match ffi::byte_array_to_string_opt(&field.name) { + Some(n) => n, + None => { + debug!("{}: a field name is unreadable", type_key); + return None; + } + }; + // Fields whose offset falls inside the parent type's address range cannot be + // part of the repr(C) struct layout (they occupy a slot the parent owns). + if field.offset < parent_total_size { + let rust_type = if field.offset >= 0 && field.size >= 0 { + let meta = ffi::byte_array_to_string_opt(&field.metadata); + let schema = meta + .as_deref() + .and_then(extract_type_schema) + .and_then(|s| parse_type_schema(&s)); + repr_c_field_type(schema.as_ref(), type_map, type_key, field.size) + .map(|(ty, _)| ty) + } else { + None + }; + trace!( + "{}: field '{}' at offset={} is in parent range → non-layout (rust_type={:?})", + type_key, name, field.offset, rust_type + ); + non_layout_fields.push(NonLayoutField { + name: name.clone(), + rust_name: sanitize_ident(&name), + rust_type, + }); + continue; + } + trace!( + "{}: field '{}': offset={}, size={}", + type_key, name, field.offset, field.size + ); + if field.offset < 0 || field.size < 0 { + debug!("{}: field '{}' has invalid offset/size", type_key, name); + return None; + } + let meta = ffi::byte_array_to_string_opt(&field.metadata); + let schema = meta + .as_deref() + .and_then(extract_type_schema) + .and_then(|s| parse_type_schema(&s)); + trace!( + "{}: field '{}' schema origin={:?}", + type_key, + name, + schema.as_ref().map(|s| &s.origin) + ); + let mapped = repr_c_field_type(schema.as_ref(), type_map, type_key, field.size); + let (rust_type, is_pod) = match mapped { + Some(v) => v, + None => { + // Schema not mappable: cannot be a struct field, but still accessible + // at runtime via FieldGetter with an untyped (Any) return. + debug!( + "{}: field '{}' type not mappable, covered by gap + non-layout FieldGetter (schema_origin={:?})", + type_key, + name, + schema.as_ref().map(|s| &s.origin) + ); + non_layout_fields.push(NonLayoutField { + name: name.clone(), + rust_name: sanitize_ident(&name), + rust_type: None, + }); + continue; + } + }; + trace!( + "{}: field '{}' -> rust_type='{}', is_pod={}", + type_key, name, rust_type, is_pod + ); + typed_fields.push(ReprCField { + rust_name: sanitize_ident(&name), + offset: field.offset, + size: field.size, + rust_type, + is_pod, + }); + } + } + typed_fields.sort_by_key(|f| f.offset); + + // Build layout by walking [parent_total_size .. total_size) and inserting + // gaps wherever there is no registered field. + let mut layout = Vec::new(); + let mut pos = parent_total_size; + let mut gap_idx = 0usize; + for f in &typed_fields { + if f.offset > pos { + let gap_size = f.offset - pos; + trace!( + "{}: gap at {}..{} ({} bytes)", + type_key, pos, f.offset, gap_size + ); + layout.push(LayoutEntry::Gap { + name: format!("_gap{}", gap_idx), + size: gap_size, + }); + gap_idx += 1; + pos = f.offset; + } + if f.offset < pos { + // Overlapping fields — shouldn't happen, bail out. + debug!( + "{}: field '{}' at offset={} overlaps pos={}", + type_key, f.rust_name, f.offset, pos + ); + return None; + } + layout.push(LayoutEntry::Field(f.clone())); + pos = f.offset + f.size; + } + // Trailing gap (tail padding, or fields after last registered one) + if pos < total_size { + let gap_size = total_size - pos; + trace!( + "{}: trailing gap at {}..{} ({} bytes)", + type_key, pos, total_size, gap_size + ); + layout.push(LayoutEntry::Gap { + name: format!("_gap{}", gap_idx), + size: gap_size, + }); + } else if pos > total_size { + debug!( + "{}: fields exceed total_size (pos={} > total_size={})", + type_key, pos, total_size + ); + return None; + } + + debug!( + "{}: repr_c OK ({} fields, {} gaps, {} layout entries)", + type_key, + typed_fields.len(), + layout + .iter() + .filter(|e| matches!(e, LayoutEntry::Gap { .. })) + .count(), + layout.len() + ); + Some(ReprCInfo { + parent_type_key, + layout, + non_layout_fields, + }) +} + +fn total_size_from_info(info: &tvm_ffi::tvm_ffi_sys::TVMFFITypeInfo) -> Option { + if info.metadata.is_null() { + return None; + } + let meta = unsafe { &*info.metadata }; + if meta.total_size <= 0 { + return None; + } + Some(meta.total_size) +} + +/// Map schema to (rust_type_name, is_pod). None if not repr_c compatible. +fn repr_c_field_type( + schema: Option<&TypeSchema>, + type_map: &BTreeMap, + _self_type_key: &str, + field_size: i64, +) -> Option<(String, bool)> { + let schema = schema?; + match schema.origin.as_str() { + "Any" | "ffi.Any" => Some(("tvm_ffi::AnyValue".to_string(), false)), + "bool" => Some(("bool".to_string(), true)), + "int" => match field_size { + 1 => Some(("i8".to_string(), true)), + 2 => Some(("i16".to_string(), true)), + 4 => Some(("i32".to_string(), true)), + 8 => Some(("i64".to_string(), true)), + _ => None, + }, + "float" => match field_size { + 4 => Some(("f32".to_string(), true)), + 8 => Some(("f64".to_string(), true)), + _ => None, + }, + "Device" => Some(("tvm_ffi::DLDevice".to_string(), true)), + "DataType" => Some(("tvm_ffi::DLDataType".to_string(), true)), + "ffi.String" | "std::string" | "const char*" | "ffi.SmallStr" => { + Some(("tvm_ffi::String".to_string(), false)) + } + "ffi.Bytes" | "ffi.SmallBytes" => Some(("tvm_ffi::Bytes".to_string(), false)), + "ffi.Function" => Some(("tvm_ffi::Function".to_string(), false)), + "ffi.Object" => Some(("tvm_ffi::object::ObjectRef".to_string(), false)), + "ffi.Shape" => Some(("tvm_ffi::Shape".to_string(), false)), + "ffi.Module" => Some(("tvm_ffi::Module".to_string(), false)), + "ffi.Tensor" | "DLTensor*" => Some(("tvm_ffi::Tensor".to_string(), false)), + "Optional" => match schema.args.as_slice() { + [inner] => repr_c_field_type(Some(inner), type_map, _self_type_key, field_size) + .map(|(inner_ty, pod)| (format!("Option<{}>", inner_ty), pod)), + [] => Some(("Option".to_string(), false)), + _ => None, + }, + "ffi.Array" => match schema.args.as_slice() { + [inner] => { + let (inner_ty, _) = + repr_c_field_type(Some(inner), type_map, _self_type_key, field_size)?; + Some((format!("tvm_ffi::Array<{}>", inner_ty), false)) + } + [] => Some(( + "tvm_ffi::Array".to_string(), + false, + )), + _ => None, + }, + "ffi.Map" => match schema.args.as_slice() { + [k, v] => { + let (k_ty, _) = repr_c_field_type(Some(k), type_map, _self_type_key, field_size)?; + let (v_ty, _) = repr_c_field_type(Some(v), type_map, _self_type_key, field_size)?; + Some((format!("tvm_ffi::Map<{}, {}>", k_ty, v_ty), false)) + } + _ => None, + }, + other => type_map.get(other).map(|path| (path.clone(), false)), + } +} + +fn sanitize_ident(name: &str) -> String { + let mut out = String::new(); + for ch in name.chars() { + if ch.is_ascii_alphanumeric() || ch == '_' { + out.push(ch); + } else { + out.push('_'); + } + } + if out.is_empty() { + out.push('_'); + } + if out.chars().next().unwrap().is_ascii_digit() { + out.insert(0, '_'); + } + const KEYWORDS: &[&str] = &[ + "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn", + "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", + "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", + "use", "where", "while", "async", "await", "dyn", + ]; + if KEYWORDS.contains(&out.as_str()) { + out.push('_'); + } + out +} diff --git a/rust/tvm-ffi-stubgen/src/schema.rs b/rust/tvm-ffi-stubgen/src/schema.rs new file mode 100644 index 00000000..3baf5de3 --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/schema.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use serde::Deserialize; +use std::collections::{BTreeSet, HashSet}; + +#[derive(Debug, Clone)] +pub(crate) struct TypeSchema { + pub(crate) origin: String, + pub(crate) args: Vec, +} + +#[derive(Deserialize)] +struct TypeSchemaJson { + #[serde(rename = "type")] + ty: String, + #[serde(default)] + args: Vec, +} + +pub(crate) fn extract_type_schema(metadata: &str) -> Option { + let value: serde_json::Value = serde_json::from_str(metadata).ok()?; + value + .get("type_schema") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) +} + +pub(crate) fn parse_type_schema(schema: &str) -> Option { + let json: TypeSchemaJson = serde_json::from_str(schema).ok()?; + Some(parse_type_schema_json(&json)) +} + +pub(crate) fn collect_type_keys( + schema: &TypeSchema, + known: &HashSet, + out: &mut BTreeSet, +) { + if known.contains(&schema.origin) { + out.insert(schema.origin.clone()); + } + for arg in &schema.args { + collect_type_keys(arg, known, out); + } +} + +fn parse_type_schema_json(json: &TypeSchemaJson) -> TypeSchema { + TypeSchema { + origin: json.ty.clone(), + args: json.args.iter().map(parse_type_schema_json).collect(), + } +} diff --git a/rust/tvm-ffi-stubgen/src/utils.rs b/rust/tvm-ffi-stubgen/src/utils.rs new file mode 100644 index 00000000..d1ae369e --- /dev/null +++ b/rust/tvm-ffi-stubgen/src/utils.rs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fs; +use std::path::{Path, PathBuf}; + +pub(crate) fn normalize_prefix(prefix: &str) -> String { + if prefix.is_empty() { + return String::new(); + } + if prefix.ends_with('.') { + prefix.to_string() + } else { + format!("{}.", prefix) + } +} + +pub(crate) fn ensure_out_dir( + out_dir: &Path, + overwrite: bool, +) -> Result<(), Box> { + if out_dir.exists() { + let mut has_entries = false; + for entry in fs::read_dir(out_dir)? { + let entry = entry?; + if entry.file_name() != "." && entry.file_name() != ".." { + has_entries = true; + break; + } + } + if has_entries && !overwrite { + return Err("output directory is not empty (use --overwrite to proceed)".into()); + } + } else { + fs::create_dir_all(out_dir)?; + } + Ok(()) +} + +pub(crate) fn default_tvm_ffi_path() -> Result> { + let current = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let candidate = current.join("../tvm-ffi"); + if candidate.exists() { + return Ok(candidate); + } + Err("unable to locate tvm-ffi path (use --tvm-ffi-path)".into()) +} diff --git a/rust/tvm-ffi-stubgen/tests/stubgen.rs b/rust/tvm-ffi-stubgen/tests/stubgen.rs new file mode 100644 index 00000000..2b69de30 --- /dev/null +++ b/rust/tvm-ffi-stubgen/tests/stubgen.rs @@ -0,0 +1,223 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::env; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::Command; +use tvm_ffi_stubgen::{Args, run}; + +#[test] +fn stubgen_tvm_ffi_testing() { + let lib_dir = tvm_ffi_libdir().expect("tvm-ffi-config --libdir"); + let dlls = resolve_testing_dlls(&lib_dir).expect("unable to locate tvm_ffi testing libraries"); + let testing_lib = dlls + .iter() + .find(|path| { + path.file_name() + .map(|name| name.to_string_lossy().contains("tvm_ffi_testing")) + .unwrap_or(false) + }) + .cloned() + .expect("tvm_ffi_testing library"); + let out_dir = unique_temp_dir("tvm_ffi_stubgen_test"); + let args = Args { + out_dir: out_dir.clone(), + dlls: vec![testing_lib.clone()], + init_prefix: vec!["testing".to_string()], + init_crate: "tvm_ffi_testing_stub".to_string(), + tvm_ffi_path: None, + overwrite: true, + no_format: false, + }; + + run(args).expect("stubgen run"); + + let cargo_toml = out_dir.join("Cargo.toml"); + let functions_rs = out_dir + .join("src") + .join("_tvm_ffi_stubgen_detail") + .join("functions.rs"); + let types_rs = out_dir + .join("src") + .join("_tvm_ffi_stubgen_detail") + .join("types.rs"); + assert!(cargo_toml.exists(), "Cargo.toml not generated"); + assert!(functions_rs.exists(), "functions.rs not generated"); + assert!(types_rs.exists(), "types.rs not generated"); + + let functions_body = fs::read_to_string(functions_rs).expect("read functions.rs"); + let types_body = fs::read_to_string(types_rs).expect("read types.rs"); + assert!(functions_body.contains("add_one"), "missing add_one stub"); + assert!( + types_body.contains("resolve_type_method"), + "type method wrappers should resolve from type metadata" + ); + assert!( + types_body.contains("pub fn new("), + "constructor `new` should be generated when available" + ); + assert!( + !types_body.contains("c_ffi_init"), + "legacy constructor name c_ffi_init should not be generated" + ); + assert!( + !types_body.contains("Function::get_global(\"testing.TestIntPair.__ffi_init__\")"), + "type methods should not use global lookup path" + ); + + write_integration_test(&out_dir, &testing_lib).expect("write integration test"); + run_generated_tests(&out_dir, &lib_dir).expect("run generated tests"); +} + +fn resolve_testing_dlls(lib_dir: &Path) -> Result, String> { + if let Some(dlls) = dlls_from_dir(lib_dir) { + return Ok(dlls); + } + Err("tvm-ffi-config --libdir did not contain tvm_ffi libraries".to_string()) +} + +fn dlls_from_dir(dir: &Path) -> Option> { + let tvm_ffi = dir.join(lib_filename("tvm_ffi")); + let tvm_ffi_testing = dir.join(lib_filename("tvm_ffi_testing")); + if tvm_ffi.exists() && tvm_ffi_testing.exists() { + Some(vec![tvm_ffi, tvm_ffi_testing]) + } else { + None + } +} + +fn lib_filename(name: &str) -> String { + if cfg!(target_os = "windows") { + format!("{}.dll", name) + } else if cfg!(target_os = "macos") { + format!("lib{}.dylib", name) + } else { + format!("lib{}.so", name) + } +} + +fn unique_temp_dir(prefix: &str) -> PathBuf { + let base = env::temp_dir(); + let pid = std::process::id(); + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0); + base.join(format!("{}_{}_{}", prefix, pid, nanos)) +} + +fn tvm_ffi_libdir() -> Result> { + let output = Command::new("tvm-ffi-config").arg("--libdir").output()?; + if !output.status.success() { + return Err("tvm-ffi-config --libdir failed".into()); + } + let lib_dir = String::from_utf8(output.stdout)?.trim().to_string(); + if lib_dir.is_empty() { + return Err("tvm-ffi-config returned empty libdir".into()); + } + Ok(PathBuf::from(lib_dir)) +} + +fn write_integration_test( + out_dir: &Path, + testing_lib: &Path, +) -> Result<(), Box> { + let tests_dir = out_dir.join("tests"); + fs::create_dir_all(&tests_dir)?; + let test_body = format!( + r#"use tvm_ffi_testing_stub as stub; + +#[test] +fn generated_usage_roundtrip() {{ + let lib_path = "{lib_path}"; + stub::load_library(lib_path).expect("load tvm_ffi_testing"); + + let value = stub::add_one(1).expect("call add_one"); + assert_eq!(value, 2); + + let _out = stub::echo(&[tvm_ffi::Any::from(1_i64)]).expect("call echo"); + + // Constructor + instance method should resolve from type metadata. + let pair_obj = stub::TestIntPair::new(3, 4).expect("construct TestIntPair"); + let pair: stub::TestIntPair = pair_obj + .try_into() + .unwrap_or_else(|_| panic!("object -> TestIntPair downcast failed")); + let sum_any = pair.sum(&[]).expect("call TestIntPair.sum"); + let sum: i64 = sum_any.try_into().expect("sum any -> i64"); + assert_eq!(sum, 7); + + // Verify upcast/downcast roundtrip on Cxx inheritance chain. + let derived_obj = stub::TestCxxClassDerived::new(11, 7, 3.5, 1.25) + .expect("construct TestCxxClassDerived"); + let _derived: stub::TestCxxClassDerived = derived_obj.clone().into(); + let base: stub::TestCxxClassBase = derived_obj.clone().into(); + let base_obj: tvm_ffi::object::ObjectRef = base.clone().into(); + let roundtrip: stub::TestCxxClassDerived = base_obj.into(); + assert_eq!(base.v_i64().expect("base.v_i64"), 11); + assert_eq!(base.v_i32().expect("base.v_i32"), 7); + assert!((roundtrip.v_f64().expect("derived.v_f64") - 3.5).abs() < 1e-9); + assert!((roundtrip.v_f32().expect("derived.v_f32") - 1.25).abs() < 1e-6); + + let obj = stub::make_unregistered_object().expect("create unregistered object"); + let count = stub::object_use_count(obj.clone()).expect("query object use count"); + assert!(count >= 1); + + // Fallback wrapper can be constructed from ObjectRef directly. + let _wrapped: stub::TestUnregisteredObject = obj.into(); +}} +"#, + lib_path = testing_lib.display() + ); + fs::write(tests_dir.join("integration.rs"), test_body)?; + Ok(()) +} + +fn run_generated_tests(out_dir: &Path, lib_dir: &Path) -> Result<(), Box> { + let mut cmd = Command::new("cargo"); + cmd.arg("test") + .arg("--manifest-path") + .arg(out_dir.join("Cargo.toml")) + .current_dir(out_dir); + + let ld_var = if cfg!(target_os = "windows") { + "PATH" + } else if cfg!(target_os = "macos") { + "DYLD_LIBRARY_PATH" + } else { + "LD_LIBRARY_PATH" + }; + + let current_ld = env::var(ld_var).unwrap_or_default(); + let separator = if ld_var == "PATH" { ";" } else { ":" }; + let lib_dir_str = lib_dir.to_string_lossy(); + let new_ld = if current_ld.is_empty() { + lib_dir_str.to_string() + } else { + format!("{}{}{}", lib_dir_str, separator, current_ld) + }; + cmd.env(ld_var, new_ld); + + let path_value = env::var("PATH").unwrap_or_default(); + cmd.env("PATH", path_value); + + let status = cmd.status()?; + if !status.success() { + return Err("generated crate tests failed".into()); + } + Ok(()) +} diff --git a/rust/tvm-ffi-sys/Cargo.toml b/rust/tvm-ffi-sys/Cargo.toml index ef87038d..b0c10751 100644 --- a/rust/tvm-ffi-sys/Cargo.toml +++ b/rust/tvm-ffi-sys/Cargo.toml @@ -20,7 +20,8 @@ name = "tvm-ffi-sys" description = "Low-level sys crate for tvm-ffi" version = "0.1.0-alpha.0" -edition = "2021" +edition = "2024" +rust-version = "1.85" license = "Apache-2.0" diff --git a/rust/tvm-ffi-sys/src/c_api.rs b/rust/tvm-ffi-sys/src/c_api.rs index e0bf0858..f3c5b8ee 100644 --- a/rust/tvm-ffi-sys/src/c_api.rs +++ b/rust/tvm-ffi-sys/src/c_api.rs @@ -113,6 +113,12 @@ pub struct TVMFFIObject { __padding: u32, } +impl Default for TVMFFIObject { + fn default() -> Self { + Self::new() + } +} + impl TVMFFIObject { pub fn new() -> Self { Self { diff --git a/rust/tvm-ffi-sys/src/dlpack.rs b/rust/tvm-ffi-sys/src/dlpack.rs index e069ea8e..a8eeed96 100644 --- a/rust/tvm-ffi-sys/src/dlpack.rs +++ b/rust/tvm-ffi-sys/src/dlpack.rs @@ -107,8 +107,8 @@ pub struct DLTensor { impl DLDevice { pub fn new(device_type: DLDeviceType, device_id: i32) -> Self { Self { - device_type: device_type, - device_id: device_id, + device_type, + device_id, } } } @@ -117,8 +117,8 @@ impl DLDataType { pub fn new(code: DLDataTypeCode, bits: u8, lanes: u16) -> Self { Self { code: code as u8, - bits: bits, - lanes: lanes, + bits, + lanes, } } } diff --git a/rust/tvm-ffi-sys/src/lib.rs b/rust/tvm-ffi-sys/src/lib.rs index 1530cc71..018421d7 100644 --- a/rust/tvm-ffi-sys/src/lib.rs +++ b/rust/tvm-ffi-sys/src/lib.rs @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#![allow(clippy::new_without_default)] +#![allow(clippy::missing_safety_doc)] pub mod c_api; pub mod c_env_api; pub mod dlpack; diff --git a/rust/tvm-ffi/Cargo.toml b/rust/tvm-ffi/Cargo.toml index b27c8c84..ec73bb1b 100644 --- a/rust/tvm-ffi/Cargo.toml +++ b/rust/tvm-ffi/Cargo.toml @@ -20,7 +20,8 @@ name = "tvm-ffi" description = "tvm-ffi rust support" version = "0.1.0-alpha.0" -edition = "2021" +edition = "2024" +rust-version = "1.85" license = "Apache-2.0" diff --git a/rust/tvm-ffi/src/any.rs b/rust/tvm-ffi/src/any.rs index ecf8b9ea..ade44db6 100644 --- a/rust/tvm-ffi/src/any.rs +++ b/rust/tvm-ffi/src/any.rs @@ -37,6 +37,13 @@ pub struct Any { data: TVMFFIAny, } +/// Managed Any wrapper that participates in typed signatures. +#[repr(transparent)] +#[derive(Clone)] +pub struct AnyValue { + inner: Any, +} + //--------------------- // AnyView //--------------------- @@ -94,9 +101,9 @@ impl<'a, T: AnyCompatible> From<&'a T> for AnyView<'a> { fn from(value: &'a T) -> Self { unsafe { let mut data = TVMFFIAny::new(); - T::copy_to_any_view(&value, &mut data); + T::copy_to_any_view(value, &mut data); Self { - data: data, + data, _phantom: std::marker::PhantomData, } } @@ -144,6 +151,11 @@ impl Any { pub fn type_index(&self) -> i32 { self.data.type_index } + + #[inline] + pub(crate) fn as_raw_ffi_any(&self) -> TVMFFIAny { + self.data + } /// Try to query if stored typed in Any exactly matches the type T /// /// This function is fast in the case of failure and can be used to check @@ -200,6 +212,34 @@ impl Any { } } +impl AnyValue { + pub fn new(value: Any) -> Self { + Self { inner: value } + } + + pub fn as_any(&self) -> &Any { + &self.inner + } + + pub fn into_any(self) -> Any { + self.inner + } +} + +impl From for AnyValue { + fn from(value: Any) -> Self { + Self { inner: value } + } +} + +impl<'a> TryFrom> for AnyValue { + type Error = crate::error::Error; + + fn try_from(value: AnyView<'a>) -> Result { + Ok(AnyValue::from(Any::from(value))) + } +} + impl Default for Any { fn default() -> Self { Self::new() diff --git a/rust/tvm-ffi/src/collections/array.rs b/rust/tvm-ffi/src/collections/array.rs index 6f259ba1..06950192 100644 --- a/rust/tvm-ffi/src/collections/array.rs +++ b/rust/tvm-ffi/src/collections/array.rs @@ -165,7 +165,7 @@ impl Array { #[inline] fn as_container(&self) -> &ArrayObj { unsafe { - let ptr = ObjectArc::as_raw(&self.data) as *const ArrayObj; + let ptr = ObjectArc::as_raw(&self.data); &*ptr } } diff --git a/rust/tvm-ffi/src/collections/map.rs b/rust/tvm-ffi/src/collections/map.rs new file mode 100644 index 00000000..f1b6be5c --- /dev/null +++ b/rust/tvm-ffi/src/collections/map.rs @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use std::marker::PhantomData; +use std::sync::LazyLock; + +use crate::any::TryFromTemp; +use crate::derive::Object; +use crate::error::Result; +use crate::function::Function; +use crate::object::{Object, ObjectArc, ObjectRefCore}; +use crate::type_traits::AnyCompatible; +use crate::{Any, AnyView}; +use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; +use tvm_ffi_sys::{TVMFFIAny, TVMFFIObject}; + +#[repr(C)] +#[derive(Object)] +#[type_key = "ffi.Map"] +#[type_index(TypeIndex::kTVMFFIMap)] +pub struct MapObj { + pub object: Object, +} + +impl Map +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + /// Create a new Map from key/value pairs. + pub fn new>(items: I) -> Result { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.Map").unwrap()); + let items: Vec<(K, V)> = items.into_iter().collect(); + let mut args: Vec = Vec::with_capacity(items.len() * 2); + for (key, value) in items.iter() { + args.push(AnyView::from(key)); + args.push(AnyView::from(value)); + } + (*API_FUNC).call_packed(&args)?.try_into() + } + + /// Return the number of entries in the map. + pub fn len(&self) -> Result { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.MapSize").unwrap()); + let args = [AnyView::from(self)]; + let size_any = (*API_FUNC).call_packed(&args)?; + let temp: TryFromTemp = TryFromTemp::try_from(size_any)?; + let size = TryFromTemp::into_value(temp); + Ok(size as usize) + } + + /// Return true if the map is empty. + pub fn is_empty(&self) -> Result { + Ok(self.len()? == 0) + } + + /// Return true if the map contains the key. + pub fn contains_key(&self, key: &K) -> Result { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.MapCount").unwrap()); + let args = [AnyView::from(self), AnyView::from(key)]; + let count_any = (*API_FUNC).call_packed(&args)?; + let temp: TryFromTemp = TryFromTemp::try_from(count_any)?; + let count = TryFromTemp::into_value(temp); + Ok(count != 0) + } + + /// Return the value for key or raise a KeyError. + pub fn get(&self, key: &K) -> Result { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.MapGetItem").unwrap()); + let args = [AnyView::from(self), AnyView::from(key)]; + let value_any = (*API_FUNC).call_packed(&args)?; + let temp: TryFromTemp = TryFromTemp::try_from(value_any)?; + Ok(TryFromTemp::into_value(temp)) + } + + /// Return the value for key or None if missing. + pub fn get_optional(&self, key: &K) -> Result> { + if !self.contains_key(key)? { + return Ok(None); + } + self.get(key).map(Some) + } + + /// Return the value for key or a default value if missing. + pub fn get_or(&self, key: &K, default: V) -> Result { + match self.get_optional(key)? { + Some(value) => Ok(value), + None => Ok(default), + } + } + + /// Iterate over key/value pairs. + pub fn iter(&self) -> Result> { + static API_FUNC: LazyLock = + LazyLock::new(|| Function::get_global("ffi.MapForwardIterFunctor").unwrap()); + let args = [AnyView::from(self)]; + let functor: Function = (*API_FUNC).call_packed(&args)?.try_into()?; + Ok(MapIterator { + functor, + remaining: self.len()?, + _marker: PhantomData, + }) + } +} + +pub struct MapIterator { + functor: Function, + remaining: usize, + _marker: PhantomData<(K, V)>, +} + +impl Iterator for MapIterator +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + type Item = (K, V); + + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + let key_any = self.functor.call_tuple_with_len::<1, _>((0i64,)).ok()?; + let key_temp: TryFromTemp = TryFromTemp::try_from(key_any).ok()?; + let key = TryFromTemp::into_value(key_temp); + + let value_any = self.functor.call_tuple_with_len::<1, _>((1i64,)).ok()?; + let value_temp: TryFromTemp = TryFromTemp::try_from(value_any).ok()?; + let value = TryFromTemp::into_value(value_temp); + let _ = self.functor.call_tuple_with_len::<1, _>((2i64,)); + self.remaining -= 1; + Some((key, value)) + } +} +#[repr(C)] +#[derive(Clone)] +pub struct Map { + data: ObjectArc, + _marker: PhantomData<(K, V)>, +} + +unsafe impl ObjectRefCore for Map { + type ContainerType = MapObj; + + fn data(this: &Self) -> &ObjectArc { + &this.data + } + + fn into_data(this: Self) -> ObjectArc { + this.data + } + + fn from_data(data: ObjectArc) -> Self { + Self { + data, + _marker: PhantomData, + } + } +} + +// --- Any Type System Conversions --- + +unsafe impl AnyCompatible for Map +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + fn type_str() -> String { + format!("Map<{}, {}>", K::type_str(), V::type_str()) + } + + unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { + data.type_index == TypeIndex::kTVMFFIMap as i32 + } + + unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { + data.type_index = TypeIndex::kTVMFFIMap as i32; + data.data_union.v_obj = ObjectArc::as_raw(Self::data(src)) as *mut TVMFFIObject; + data.small_str_len = 0; + } + + unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { + data.type_index = TypeIndex::kTVMFFIMap as i32; + data.data_union.v_obj = ObjectArc::into_raw(Self::into_data(src)) as *mut TVMFFIObject; + data.small_str_len = 0; + } + + unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj as *const MapObj; + crate::object::unsafe_::inc_ref(ptr as *mut TVMFFIObject); + Self::from_data(ObjectArc::from_raw(ptr)) + } + + unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj as *const MapObj; + let obj = Self::from_data(ObjectArc::from_raw(ptr)); + + data.type_index = TypeIndex::kTVMFFINone as i32; + data.data_union.v_int64 = 0; + + obj + } + + unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { + if data.type_index != TypeIndex::kTVMFFIMap as i32 { + return Err(()); + } + Ok(Self::copy_from_any_view_after_check(data)) + } +} + +impl TryFrom for Map +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + type Error = crate::error::Error; + + fn try_from(value: Any) -> Result { + let temp: TryFromTemp = TryFromTemp::try_from(value)?; + Ok(TryFromTemp::into_value(temp)) + } +} + +impl<'a, K, V> TryFrom> for Map +where + K: AnyCompatible + Clone + 'static, + V: AnyCompatible + Clone + 'static, +{ + type Error = crate::error::Error; + + fn try_from(value: AnyView<'a>) -> Result { + let temp: TryFromTemp = TryFromTemp::try_from(value)?; + Ok(TryFromTemp::into_value(temp)) + } +} diff --git a/rust/tvm-ffi/src/collections/mod.rs b/rust/tvm-ffi/src/collections/mod.rs index ad17dcca..791ff755 100644 --- a/rust/tvm-ffi/src/collections/mod.rs +++ b/rust/tvm-ffi/src/collections/mod.rs @@ -18,5 +18,6 @@ */ /// Collection types pub mod array; +pub mod map; pub mod shape; pub mod tensor; diff --git a/rust/tvm-ffi/src/collections/shape.rs b/rust/tvm-ffi/src/collections/shape.rs index 39d9a1df..16ec4419 100644 --- a/rust/tvm-ffi/src/collections/shape.rs +++ b/rust/tvm-ffi/src/collections/shape.rs @@ -127,7 +127,7 @@ impl Eq for Shape {} impl PartialOrd for Shape { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - self.as_slice().partial_cmp(other.as_slice()) + Some(self.cmp(other)) } } diff --git a/rust/tvm-ffi/src/collections/tensor.rs b/rust/tvm-ffi/src/collections/tensor.rs index 6b34613e..acd4bbe2 100644 --- a/rust/tvm-ffi/src/collections/tensor.rs +++ b/rust/tvm-ffi/src/collections/tensor.rs @@ -22,8 +22,8 @@ use crate::dtype::AsDLDataType; use crate::dtype::DLDataTypeExt; use crate::error::Result; use crate::object::{Object, ObjectArc, ObjectCore, ObjectCoreWithExtraItems}; -use tvm_ffi_sys::dlpack::{DLDataType, DLDevice, DLDeviceType, DLTensor}; use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; +use tvm_ffi_sys::dlpack::{DLDataType, DLDevice, DLDeviceType, DLTensor}; //----------------------------------------------------- // NDAllocator Trait @@ -66,7 +66,7 @@ impl DLTensorExt for DLTensor { } fn item_size(&self) -> usize { - (self.dtype.bits as usize * self.dtype.lanes as usize + 7) / 8 + (self.dtype.bits as usize * self.dtype.lanes as usize).div_ceil(8) } } @@ -267,15 +267,15 @@ impl Tensor { object: Object::new(), dltensor: DLTensor { data: std::ptr::null_mut(), - device: device, + device, ndim: shape.len() as i32, - dtype: dtype, + dtype, shape: std::ptr::null_mut(), strides: std::ptr::null_mut(), byte_offset: 0, }, }, - alloc: alloc, + alloc, }; unsafe { let mut obj_arc = ObjectArc::new_with_extra_items(tensor_obj); @@ -320,16 +320,16 @@ unsafe impl NDAllocator for CPUNDAlloc { const MIN_ALIGN: usize = 64; unsafe fn alloc_data(&mut self, prototype: &DLTensor) -> *mut core::ffi::c_void { - let numel = prototype.numel() as usize; + let numel = prototype.numel(); let item_size = prototype.item_size(); - let size = numel * item_size as usize; + let size = numel * item_size; let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap(); let ptr = std::alloc::alloc(layout); ptr as *mut core::ffi::c_void } unsafe fn free_data(&mut self, tensor: &DLTensor) { - let numel = tensor.numel() as usize; + let numel = tensor.numel(); let item_size = tensor.item_size(); let size = numel * item_size; let layout = std::alloc::Layout::from_size_align(size, Self::MIN_ALIGN).unwrap(); diff --git a/rust/tvm-ffi/src/device.rs b/rust/tvm-ffi/src/device.rs index d6c0418b..885836d9 100644 --- a/rust/tvm-ffi/src/device.rs +++ b/rust/tvm-ffi/src/device.rs @@ -75,7 +75,7 @@ unsafe impl AnyCompatible for DLDevice { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFIDevice as i32; + data.type_index == TypeIndex::kTVMFFIDevice as i32 } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { diff --git a/rust/tvm-ffi/src/dtype.rs b/rust/tvm-ffi/src/dtype.rs index eb471b2b..7081242b 100644 --- a/rust/tvm-ffi/src/dtype.rs +++ b/rust/tvm-ffi/src/dtype.rs @@ -18,9 +18,9 @@ */ use crate::error::Result; use crate::type_traits::AnyCompatible; +use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; /// Data type handling use tvm_ffi_sys::dlpack::{DLDataType, DLDataTypeCode}; -use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; use tvm_ffi_sys::{TVMFFIAny, TVMFFIByteArray, TVMFFIDataTypeFromString, TVMFFIDataTypeToString}; /// Extra methods for DLDataType @@ -53,7 +53,7 @@ impl DLDataTypeExt for DLDataType { fn to_string(&self) -> crate::string::String { unsafe { let mut ffi_any = TVMFFIAny::new(); - crate::check_safe_call!(TVMFFIDataTypeToString(&*self, &mut ffi_any)).unwrap(); + crate::check_safe_call!(TVMFFIDataTypeToString(self, &mut ffi_any)).unwrap(); crate::any::Any::from_raw_ffi_any(ffi_any) .try_into() .unwrap() @@ -120,7 +120,7 @@ unsafe impl AnyCompatible for DLDataType { /// # Returns /// `true` if the Any contains a DLDataType, `false` otherwise unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFIDataType as i32; + data.type_index == TypeIndex::kTVMFFIDataType as i32 } /// Copy a DLDataType from an Any view (after type check) diff --git a/rust/tvm-ffi/src/error.rs b/rust/tvm-ffi/src/error.rs index cb6dabb1..341d9365 100644 --- a/rust/tvm-ffi/src/error.rs +++ b/rust/tvm-ffi/src/error.rs @@ -18,6 +18,7 @@ */ use crate::derive::{Object, ObjectRef}; use crate::object::{Object, ObjectArc}; +use std::convert::Infallible; use std::ffi::c_void; use tvm_ffi_sys::TVMFFIBacktraceUpdateMode::kTVMFFIBacktraceUpdateModeAppend; use tvm_ffi_sys::{ @@ -64,6 +65,12 @@ pub struct Error { data: ObjectArc, } +impl From for Error { + fn from(value: Infallible) -> Self { + match value {} + } +} + /// Default result that uses Error as the error type pub type Result = std::result::Result; @@ -118,7 +125,7 @@ impl Error { /// # Returns /// The kind of the error pub fn kind(&self) -> ErrorKind<'_> { - ErrorKind(&self.data.cell.kind.as_str()) + ErrorKind(self.data.cell.kind.as_str()) } /// Get the message of the error @@ -179,7 +186,7 @@ impl Error { let mut new_backtrace = String::new(); new_backtrace.push_str(this.backtrace()); new_backtrace.push_str(backtrace); - return Error::new(this.kind(), this.message(), &new_backtrace); + Error::new(this.kind(), this.message(), &new_backtrace) } } } diff --git a/rust/tvm-ffi/src/function.rs b/rust/tvm-ffi/src/function.rs index e488983f..e2c74097 100644 --- a/rust/tvm-ffi/src/function.rs +++ b/rust/tvm-ffi/src/function.rs @@ -150,7 +150,7 @@ impl Function { heap_args.resize(args_len, AnyView::new()); &mut heap_args[..args_len] }; - (&tuple_args).fill_any_view(packed_args); + tuple_args.fill_any_view(packed_args); self.call_packed(packed_args) } /// Call function with compile-time known argument count @@ -170,7 +170,7 @@ impl Function { TupleType: TupleAsPackedArgs, { let mut packed_args = [AnyView::new(); LEN]; - (&tuple_args).fill_any_view(&mut packed_args); + tuple_args.fill_any_view(&mut packed_args); self.call_packed(&packed_args) } /// Get global function by name diff --git a/rust/tvm-ffi/src/function_internal.rs b/rust/tvm-ffi/src/function_internal.rs index e059051c..249e62cb 100644 --- a/rust/tvm-ffi/src/function_internal.rs +++ b/rust/tvm-ffi/src/function_internal.rs @@ -107,6 +107,75 @@ impl IntoArgHolder for &[u8] { } } +impl IntoArgHolder for crate::object::ObjectRef { + type Target = crate::object::ObjectRef; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::any::AnyValue { + type Target = crate::any::AnyValue; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::DLDevice { + type Target = crate::DLDevice; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::DLDataType { + type Target = crate::DLDataType; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::Array +where + T: crate::AnyCompatible + Clone + 'static, +{ + type Target = crate::Array; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for crate::Map +where + K: crate::AnyCompatible + Clone + 'static, + V: crate::AnyCompatible + Clone + 'static, +{ + type Target = crate::Map; + fn into_arg_holder(self) -> Self::Target { + self + } +} + +impl IntoArgHolder for Option +where + T: IntoArgHolder, +{ + type Target = Option; + fn into_arg_holder(self) -> Self::Target { + self.map(IntoArgHolder::into_arg_holder) + } +} + +impl IntoArgHolder for T +where + T: crate::object_wrapper::ObjectWrapper, +{ + type Target = T; + fn into_arg_holder(self) -> Self::Target { + self + } +} + // helper trait to implement IntoArgHolderTuple to apply into_arg_holder to each element pub trait IntoArgHolderTuple { type Target; @@ -115,6 +184,7 @@ pub trait IntoArgHolderTuple { macro_rules! impl_into_arg_holder_tuple { ( $($T:ident),* ; $($idx:tt),* ) => { + #[allow(clippy::unused_unit)] impl<$($T),*> $crate::function_internal::IntoArgHolderTuple for ($($T,)*) where $($T: IntoArgHolder),* { @@ -136,6 +206,10 @@ impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4; 0, 1, 2, 3, 4); impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5; 0, 1, 2, 3, 4, 5); impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6; 0, 1, 2, 3, 4, 5, 6); impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7; 0, 1, 2, 3, 4, 5, 6, 7); +impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8; 0, 1, 2, 3, 4, 5, 6, 7, 8); +impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); +impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); +impl_into_arg_holder_tuple!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); //------------------------------------------------------------ // ArgIntoRef @@ -152,6 +226,65 @@ crate::impl_arg_into_ref!( bool, i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, f32, f64, String, Bytes ); +impl ArgIntoRef for crate::object::ObjectRef { + type Target = crate::object::ObjectRef; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::any::AnyValue { + type Target = crate::any::AnyValue; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::DLDevice { + type Target = crate::DLDevice; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::DLDataType { + type Target = crate::DLDataType; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for Option +where + T: AnyCompatible, +{ + type Target = Option; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::Array +where + T: AnyCompatible + Clone, +{ + type Target = crate::Array; + fn to_ref(&self) -> &Self::Target { + self + } +} + +impl ArgIntoRef for crate::Map +where + K: AnyCompatible + Clone, + V: AnyCompatible + Clone, +{ + type Target = crate::Map; + fn to_ref(&self) -> &Self::Target { + self + } +} + //----------------------------------------------------------- // TupleAsPackedArgs // @@ -191,3 +324,7 @@ impl_tuple_as_packed_args!(5; T0, T1, T2, T3, T4; 0, 1, 2, 3, 4); impl_tuple_as_packed_args!(6; T0, T1, T2, T3, T4, T5; 0, 1, 2, 3, 4, 5); impl_tuple_as_packed_args!(7; T0, T1, T2, T3, T4, T5, T6; 0, 1, 2, 3, 4, 5, 6); impl_tuple_as_packed_args!(8; T0, T1, T2, T3, T4, T5, T6, T7; 0, 1, 2, 3, 4, 5, 6, 7); +impl_tuple_as_packed_args!(9; T0, T1, T2, T3, T4, T5, T6, T7, T8; 0, 1, 2, 3, 4, 5, 6, 7, 8); +impl_tuple_as_packed_args!(10; T0, T1, T2, T3, T4, T5, T6, T7, T8, T9; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); +impl_tuple_as_packed_args!(11; T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); +impl_tuple_as_packed_args!(12; T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11; 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); diff --git a/rust/tvm-ffi/src/lib.rs b/rust/tvm-ffi/src/lib.rs index fad82601..e6ba5a9a 100644 --- a/rust/tvm-ffi/src/lib.rs +++ b/rust/tvm-ffi/src/lib.rs @@ -16,6 +16,17 @@ * specific language governing permissions and limitations * under the License. */ +// TODO: incrementally migrate unsafe fn bodies to use explicit unsafe blocks +#![allow(unsafe_op_in_unsafe_fn)] +#![allow( + clippy::mut_from_ref, + clippy::not_unsafe_ptr_arg_deref, + clippy::missing_safety_doc, + clippy::new_without_default, + clippy::len_without_is_empty, + clippy::result_unit_err +)] + pub mod any; pub mod collections; pub mod derive; @@ -27,23 +38,27 @@ pub mod function; pub mod function_internal; pub mod macros; pub mod object; +pub mod object_wrapper; pub mod string; +pub mod subtyping; pub mod type_traits; pub use tvm_ffi_sys; -pub use crate::any::{Any, AnyView}; +pub use crate::any::{Any, AnyValue, AnyView}; pub use crate::collections::array::Array; +pub use crate::collections::map::Map; pub use crate::collections::shape::Shape; pub use crate::collections::tensor::{CPUNDAlloc, NDAllocator, Tensor}; pub use crate::device::{current_stream, with_stream}; pub use crate::dtype::DLDataTypeExt; -pub use crate::error::{Error, ErrorKind, Result}; pub use crate::error::{ ATTRIBUTE_ERROR, INDEX_ERROR, KEY_ERROR, RUNTIME_ERROR, TYPE_ERROR, VALUE_ERROR, }; +pub use crate::error::{Error, ErrorKind, Result}; pub use crate::extra::module::Module; pub use crate::function::Function; pub use crate::object::{Object, ObjectArc, ObjectCore, ObjectCoreWithExtraItems, ObjectRefCore}; +pub use crate::object_wrapper::ObjectWrapper; pub use crate::string::{Bytes, String}; pub use crate::type_traits::AnyCompatible; diff --git a/rust/tvm-ffi/src/macros.rs b/rust/tvm-ffi/src/macros.rs index 95c2b5ba..e2c8eb3d 100644 --- a/rust/tvm-ffi/src/macros.rs +++ b/rust/tvm-ffi/src/macros.rs @@ -47,6 +47,7 @@ macro_rules! function_name { /// /// # Returns /// * `Result<(), Error>` - The result of the safe call +/// /// Macro to check safe calls and automatically update traceback with file/line info /// /// Usage: check_safe_call!(function(args))?; @@ -102,7 +103,7 @@ macro_rules! bail { macro_rules! ensure { ($cond:expr, $error_kind:expr, $fmt:expr $(, $args:expr)* $(,)?) => {{ if !$cond { - crate::bail!($error_kind, $fmt $(, $args)*); + $crate::bail!($error_kind, $fmt $(, $args)*); } }}; } @@ -238,6 +239,135 @@ macro_rules! impl_arg_into_ref { } } +/// Define a stubgen-oriented object wrapper type. +/// +/// This macro is intended for code emitted by the Rust stub generator. +/// It is not meant as a general-purpose user-facing API. +#[macro_export] +macro_rules! define_object_wrapper { + ($name:ident, $type_key:expr) => { + #[derive(Clone)] + pub struct $name { + inner: $crate::object::ObjectRef, + } + + impl $name { + pub fn from_object(inner: $crate::object::ObjectRef) -> Self { + Self { inner } + } + + pub fn as_object_ref(&self) -> &$crate::object::ObjectRef { + &self.inner + } + + pub fn into_object_ref(self) -> $crate::object::ObjectRef { + self.inner + } + } + + impl From<$crate::object::ObjectRef> for $name { + fn from(inner: $crate::object::ObjectRef) -> Self { + Self::from_object(inner) + } + } + + impl From<$name> for $crate::object::ObjectRef { + fn from(wrapper: $name) -> Self { + wrapper.into_object_ref() + } + } + + impl $crate::object_wrapper::ObjectWrapper for $name { + const TYPE_KEY: &'static str = $type_key; + + fn from_object(inner: $crate::object::ObjectRef) -> Self { + Self::from_object(inner) + } + + fn as_object_ref(&self) -> &$crate::object::ObjectRef { + self.as_object_ref() + } + + fn into_object_ref(self) -> $crate::object::ObjectRef { + self.into_object_ref() + } + } + + $crate::impl_try_from_any!($name); + }; +} + +/// Implement object hierarchy relationships (Deref, From, TryFrom). +/// +/// This macro is intended for code emitted by the Rust stub generator to +/// establish parent-child relationships in the object hierarchy. +/// +/// # Syntax +/// ```ignore +/// impl_object_hierarchy!(Self: DirectParent, Grandparent, ..., ObjectRef); +/// ``` +/// +/// # Generated implementations +/// - `Deref` for ergonomic field access +/// - `From for DirectParent` (and all ancestors) for upcasts +/// - `TryFrom for Self` for downcasts +/// +/// # Example +/// ```ignore +/// // Given: Node -> BaseExpr -> Expr -> ObjectRef +/// impl_object_hierarchy!(Node: BaseExpr, Expr, ObjectRef); +/// ``` +#[macro_export] +macro_rules! impl_object_hierarchy { + ($self_ty:ty: $direct_parent:ty $(, $ancestor:ty)* $(,)?) => { + // Implement Deref to the direct parent for ergonomic access + impl std::ops::Deref for $self_ty { + type Target = $direct_parent; + + fn deref(&self) -> &Self::Target { + // Safety: All ObjectRef types are repr(C) with a single pointer field (ObjectArc). + // Self and DirectParent have identical memory layout. + // This is a zero-cost, lifetime-preserving reference cast. + unsafe { &*(self as *const $self_ty as *const $direct_parent) } + } + } + + // Implement From for DirectParent (upcast) + impl From<$self_ty> for $direct_parent { + fn from(value: $self_ty) -> Self { + $crate::subtyping::upcast(value) + } + } + + // Implement TryFrom for Self (downcast) + impl TryFrom<$direct_parent> for $self_ty { + type Error = $direct_parent; + + fn try_from(value: $direct_parent) -> Result { + $crate::subtyping::try_downcast(value) + } + } + + // Implement From for each ancestor (transitive upcast) + $( + impl From<$self_ty> for $ancestor { + fn from(value: $self_ty) -> Self { + $crate::subtyping::upcast(value) + } + } + + // Implement TryFrom for Self (downcast) + impl TryFrom<$ancestor> for $self_ty { + type Error = $ancestor; + + fn try_from(value: $ancestor) -> Result { + $crate::subtyping::try_downcast(value) + } + } + )* + }; +} + // ---------------------------------------------------------------------------- // Macros for function definitions // ---------------------------------------------------------------------------- @@ -312,7 +442,7 @@ macro_rules! tvm_ffi_dll_export_typed_func { /// Since the ffi mechanism requires us to pass arguments by reference. /// /// # Supported Argument Counts -/// This macro supports functions with 0 to 8 arguments. +/// This macro supports functions with 0 to 12 arguments. ///----------------------------------------------------------- #[macro_export] macro_rules! into_typed_fn { @@ -322,7 +452,7 @@ macro_rules! into_typed_fn { move || -> $ret_ty { Ok(_f.call_tuple_with_len::<0, _>(())?.try_into()?) } }}; // Case for 1 argument - ($f:expr, $trait:ident($t0:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -331,7 +461,7 @@ macro_rules! into_typed_fn { } }}; // Case for 2 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -340,7 +470,7 @@ macro_rules! into_typed_fn { } }}; // Case for 3 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -349,7 +479,7 @@ macro_rules! into_typed_fn { } }}; // Case for 4 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -358,7 +488,7 @@ macro_rules! into_typed_fn { } }}; // Case for 5 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3, a4: $t4| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -367,7 +497,7 @@ macro_rules! into_typed_fn { } }}; // Case for 6 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty) -> $ret_ty:ty) => {{ + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3, a4: $t4, a5: $t5| -> $ret_ty { use $crate::function_internal::IntoArgHolderTuple; @@ -376,7 +506,7 @@ macro_rules! into_typed_fn { } }}; // Case for 7 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty) + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3, a4: $t4, a5: $t5, a6: $t6| -> $ret_ty { @@ -386,7 +516,7 @@ macro_rules! into_typed_fn { } }}; // Case for 8 arguments - ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty) + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty $(,)?) -> $ret_ty:ty) => {{ let _f = $f; move |a0: $t0, a1: $t1, a2: $t2, a3: $t3, a4: $t4, a5: $t5, a6: $t6, a7: $t7| -> $ret_ty { @@ -395,4 +525,87 @@ macro_rules! into_typed_fn { Ok(_f.call_tuple_with_len::<8, _>(tuple_args)?.try_into()?) } }}; + // Case for 9 arguments + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty $(,)?) + -> $ret_ty:ty) => {{ + let _f = $f; + move |a0: $t0, + a1: $t1, + a2: $t2, + a3: $t3, + a4: $t4, + a5: $t5, + a6: $t6, + a7: $t7, + a8: $t8| + -> $ret_ty { + use $crate::function_internal::IntoArgHolderTuple; + let tuple_args = (a0, a1, a2, a3, a4, a5, a6, a7, a8).into_arg_holder_tuple(); + Ok(_f.call_tuple_with_len::<9, _>(tuple_args)?.try_into()?) + } + }}; + // Case for 10 arguments + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty $(,)?) + -> $ret_ty:ty) => {{ + let _f = $f; + move |a0: $t0, + a1: $t1, + a2: $t2, + a3: $t3, + a4: $t4, + a5: $t5, + a6: $t6, + a7: $t7, + a8: $t8, + a9: $t9| + -> $ret_ty { + use $crate::function_internal::IntoArgHolderTuple; + let tuple_args = (a0, a1, a2, a3, a4, a5, a6, a7, a8, a9).into_arg_holder_tuple(); + Ok(_f.call_tuple_with_len::<10, _>(tuple_args)?.try_into()?) + } + }}; + // Case for 11 arguments + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty, $t10:ty $(,)?) + -> $ret_ty:ty) => {{ + let _f = $f; + move |a0: $t0, + a1: $t1, + a2: $t2, + a3: $t3, + a4: $t4, + a5: $t5, + a6: $t6, + a7: $t7, + a8: $t8, + a9: $t9, + a10: $t10| + -> $ret_ty { + use $crate::function_internal::IntoArgHolderTuple; + let tuple_args = (a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).into_arg_holder_tuple(); + Ok(_f.call_tuple_with_len::<11, _>(tuple_args)?.try_into()?) + } + }}; + // Case for 12 arguments + ($f:expr, $trait:ident($t0:ty, $t1:ty, $t2:ty, $t3:ty, $t4:ty, $t5:ty, $t6:ty, $t7:ty, $t8:ty, $t9:ty, $t10:ty, $t11:ty $(,)?) + -> $ret_ty:ty) => {{ + let _f = $f; + move |a0: $t0, + a1: $t1, + a2: $t2, + a3: $t3, + a4: $t4, + a5: $t5, + a6: $t6, + a7: $t7, + a8: $t8, + a9: $t9, + a10: $t10, + a11: $t11| + -> $ret_ty { + use $crate::function_internal::IntoArgHolderTuple; + let tuple_args = + (a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).into_arg_holder_tuple(); + Ok(_f.call_tuple_with_len::<12, _>(tuple_args)?.try_into()?) + } + }}; } diff --git a/rust/tvm-ffi/src/object.rs b/rust/tvm-ffi/src/object.rs index dc1970a4..00754aae 100644 --- a/rust/tvm-ffi/src/object.rs +++ b/rust/tvm-ffi/src/object.rs @@ -22,7 +22,7 @@ use std::sync::atomic::AtomicU64; use crate::derive::ObjectRef; pub use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; /// Object related ABI handling -use tvm_ffi_sys::{TVMFFIObject, COMBINED_REF_COUNT_BOTH_ONE}; +use tvm_ffi_sys::{COMBINED_REF_COUNT_BOTH_ONE, TVMFFIObject}; /// Object type is by default the TVMFFIObject #[repr(C)] @@ -60,7 +60,6 @@ pub unsafe trait ObjectCore: Sized + 'static { /// /// # Returns /// * `&mut TVMFFIObject` - The object header - /// \return The object header unsafe fn object_header_mut(this: &mut Self) -> &mut TVMFFIObject; } @@ -119,7 +118,7 @@ pub(crate) mod unsafe_ { }; use std::ffi::c_void; - use std::sync::atomic::{fence, Ordering}; + use std::sync::atomic::{Ordering, fence}; use tvm_ffi_sys::TVMFFIObject; use tvm_ffi_sys::TVMFFIObjectDeleterFlagBitMask::{ kTVMFFIObjectDeleterFlagBitMaskBoth, kTVMFFIObjectDeleterFlagBitMaskStrong, @@ -307,7 +306,7 @@ impl ObjectArc { ); // move into the object arc ptr Self { - ptr: std::ptr::NonNull::new_unchecked(ptr as *mut T), + ptr: std::ptr::NonNull::new_unchecked(ptr), _phantom: std::marker::PhantomData, } } @@ -345,7 +344,7 @@ impl ObjectArc { ); // move into the object arc ptr Self { - ptr: std::ptr::NonNull::new_unchecked(ptr as *mut T), + ptr: std::ptr::NonNull::new_unchecked(ptr), _phantom: std::marker::PhantomData, } } @@ -358,7 +357,6 @@ impl ObjectArc { /// /// # Returns /// * `ObjectArc` - The ObjectArc - /// \return The ObjectArc #[inline] pub unsafe fn from_raw(ptr: *const T) -> Self { Self { @@ -389,7 +387,6 @@ impl ObjectArc { /// /// # Returns /// * `*const T` - The raw pointer - /// \return The raw pointer #[inline] pub unsafe fn as_raw(this: &Self) -> *const T { this.ptr.as_ptr() as *const T diff --git a/rust/tvm-ffi/src/object_wrapper.rs b/rust/tvm-ffi/src/object_wrapper.rs new file mode 100644 index 00000000..4533d5be --- /dev/null +++ b/rust/tvm-ffi/src/object_wrapper.rs @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::any::Any; +use crate::object::{Object, ObjectArc, ObjectRef, ObjectRefCore}; +use crate::type_traits::AnyCompatible; +use std::marker::PhantomData; +use tvm_ffi_sys::{ + TVMFFIAny, TVMFFIAnyViewToOwnedAny, TVMFFIByteArray, TVMFFIFieldGetter, TVMFFIGetTypeInfo, + TVMFFIObject, TVMFFITypeKeyToIndex, +}; + +/// Runtime support for stubgen-generated object wrappers. +/// +/// This module is intended for code emitted by the Rust stub generator and is +/// not meant as a general-purpose user-facing API. +pub trait ObjectWrapper: Clone { + const TYPE_KEY: &'static str; + fn from_object(inner: ObjectRef) -> Self; + fn as_object_ref(&self) -> &ObjectRef; + fn into_object_ref(self) -> ObjectRef; +} + +/// Resolve an object type method from runtime reflection metadata. +/// +/// Unlike `Function::get_global`, this lookup walks `TVMFFITypeInfo.methods` +/// for the given type key and converts the method entry to a callable +/// `ffi.Function`. +pub fn resolve_type_method(type_key: &str, method_name: &str) -> crate::Result { + unsafe { + let key = TVMFFIByteArray::from_str(type_key); + let mut type_index = 0i32; + crate::check_safe_call!(TVMFFITypeKeyToIndex(&key, &mut type_index))?; + resolve_type_method_by_type_index(type_index, type_key, method_name) + } +} + +struct FieldGetterInner { + offset: usize, + getter: TVMFFIFieldGetter, +} + +impl FieldGetterInner { + fn get_any(&self, obj: &ObjectRef) -> crate::Result { + unsafe { + let arc = ::data(obj); + let raw = ObjectArc::as_raw(arc) as *mut TVMFFIObject; + if raw.is_null() { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Null object for field access" + ); + } + let field_ptr = (raw as *mut u8).add(self.offset) as *mut std::ffi::c_void; + let mut out = TVMFFIAny::new(); + crate::check_safe_call!((self.getter)(field_ptr, &mut out))?; + Ok(Any::from_raw_ffi_any(out)) + } + } +} + +pub struct FieldGetter { + inner: FieldGetterInner, + _marker: PhantomData, +} + +// FieldGetter stores only reflection metadata, not values of T. +// It is safe to share across threads regardless of T's Send/Sync. +unsafe impl Send for FieldGetter {} +unsafe impl Sync for FieldGetter {} + +impl FieldGetter { + pub fn new(type_key: &'static str, field_name: &'static str) -> crate::Result { + let inner = resolve_field_by_type_key(type_key, field_name)?; + Ok(Self { + inner, + _marker: PhantomData, + }) + } + + pub fn get_any(&self, obj: &ObjectRef) -> crate::Result { + self.inner.get_any(obj) + } +} + +impl FieldGetter +where + T: TryFrom, + T::Error: Into, +{ + pub fn get(&self, obj: &ObjectRef) -> crate::Result { + self.inner.get_any(obj)?.try_into().map_err(Into::into) + } +} + +fn resolve_field_by_type_key( + type_key: &'static str, + field_name: &'static str, +) -> crate::Result { + unsafe { + let key = TVMFFIByteArray::from_str(type_key); + let mut type_index = 0i32; + crate::check_safe_call!(TVMFFITypeKeyToIndex(&key, &mut type_index))?; + resolve_field_by_type_index(type_index, field_name) + } +} + +fn resolve_type_method_by_type_index( + type_index: i32, + type_key: &str, + method_name: &str, +) -> crate::Result { + unsafe { + let info = TVMFFIGetTypeInfo(type_index); + if info.is_null() { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Type info missing for type {}", + type_key + ); + } + let info = &*info; + if info.methods.is_null() || info.num_methods <= 0 { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Type {} has no methods", + type_key + ); + } + let methods = std::slice::from_raw_parts(info.methods, info.num_methods as usize); + for method in methods { + if method.name.as_str() != method_name { + continue; + } + let mut owned = TVMFFIAny::new(); + crate::check_safe_call!(TVMFFIAnyViewToOwnedAny(&method.method, &mut owned))?; + let method_any = Any::from_raw_ffi_any(owned); + return method_any.try_into().map_err(|_err: crate::Error| { + crate::Error::new( + crate::TYPE_ERROR, + &format!( + "Method {}.{} is not callable as ffi.Function", + type_key, method_name + ), + "", + ) + }); + } + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Method {}.{} not found in reflection metadata", + type_key, + method_name + ); + } +} + +fn resolve_field_by_type_index( + type_index: i32, + field_name: &'static str, +) -> crate::Result { + unsafe { + let info = TVMFFIGetTypeInfo(type_index); + if info.is_null() { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Type info missing for field {}", + field_name + ); + } + let info = &*info; + if info.fields.is_null() || info.num_fields <= 0 { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Type {} has no fields", + info.type_key.as_str() + ); + } + let fields = std::slice::from_raw_parts(info.fields, info.num_fields as usize); + for field in fields { + if field.name.as_str() != field_name { + continue; + } + let getter = match field.getter { + Some(getter) => getter, + None => { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Field {} has no getter", + field_name + ); + } + }; + if field.offset < 0 { + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Field {} has invalid offset", + field_name + ); + } + return Ok(FieldGetterInner { + offset: field.offset as usize, + getter, + }); + } + crate::bail!( + crate::error::ATTRIBUTE_ERROR, + "Field {} not found", + field_name + ); + } +} + +fn type_index_for_key(type_key: &'static str) -> Option { + let key = unsafe { TVMFFIByteArray::from_str(type_key) }; + let mut index = 0i32; + let code = unsafe { TVMFFITypeKeyToIndex(&key, &mut index) }; + if code == 0 { Some(index) } else { None } +} + +unsafe impl AnyCompatible for T { + fn type_str() -> String { + T::TYPE_KEY.to_string() + } + + unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { + let obj = src.as_object_ref(); + let arc = ::data(obj); + let raw = ObjectArc::as_raw(arc) as *mut TVMFFIObject; + data.type_index = (*raw).type_index; + data.small_str_len = 0; + data.data_union.v_obj = raw; + } + + unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { + let obj = src.into_object_ref(); + let arc = ::into_data(obj); + let raw = ObjectArc::into_raw(arc) as *mut TVMFFIObject; + data.type_index = (*raw).type_index; + data.small_str_len = 0; + data.data_union.v_obj = raw; + } + + unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { + let Some(target_index) = type_index_for_key(T::TYPE_KEY) else { + return false; + }; + crate::subtyping::is_instance_of(data.type_index, target_index) + } + + unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj; + crate::object::unsafe_::inc_ref(ptr); + let arc = ObjectArc::from_raw(ptr as *mut Object); + let obj = ::from_data(arc); + T::from_object(obj) + } + + unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj; + let arc = ObjectArc::from_raw(ptr as *mut Object); + data.type_index = crate::TypeIndex::kTVMFFINone as i32; + data.data_union.v_int64 = 0; + let obj = ::from_data(arc); + T::from_object(obj) + } + + unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { + if Self::check_any_strict(data) { + Ok(Self::copy_from_any_view_after_check(data)) + } else { + Err(()) + } + } +} diff --git a/rust/tvm-ffi/src/string.rs b/rust/tvm-ffi/src/string.rs index 94739a4c..2e7a0b12 100644 --- a/rust/tvm-ffi/src/string.rs +++ b/rust/tvm-ffi/src/string.rs @@ -17,7 +17,7 @@ * under the License. */ use crate::derive::Object; -use crate::object::{unsafe_, Object, ObjectArc, ObjectCoreWithExtraItems}; +use crate::object::{Object, ObjectArc, ObjectCoreWithExtraItems, unsafe_}; use crate::type_traits::AnyCompatible; use std::cmp::Ordering; use std::fmt::{Debug, Display}; @@ -91,7 +91,7 @@ unsafe impl ObjectCoreWithExtraItems for BytesObj { #[inline] /// Get the count of extra items (trailing null byte for FFI compatibility) fn extra_items_count(this: &Self) -> usize { - return this.data.size + 1; + this.data.size + 1 } } @@ -114,7 +114,7 @@ where data: TVMFFIAny { type_index: TypeIndex::kTVMFFISmallBytes as i32, small_str_len: value.len() as u32, - data_union: data_union, + data_union, }, } } else { @@ -186,7 +186,7 @@ impl Eq for Bytes {} impl PartialOrd for Bytes { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - self.as_slice().partial_cmp(other.as_slice()) + Some(self.cmp(other)) } } @@ -244,7 +244,7 @@ unsafe impl ObjectCoreWithExtraItems for StringObj { /// Get the count of extra items (trailing null byte for FFI compatibility) fn extra_items_count(this: &Self) -> usize { // extra item is the trailing \0 for ffi compatibility - return this.data.size + 1; + this.data.size + 1 } } @@ -304,7 +304,7 @@ where data: TVMFFIAny { type_index: TypeIndex::kTVMFFISmallStr as i32, small_str_len: bytes.len() as u32, - data_union: data_union, + data_union, }, } } else { @@ -402,7 +402,7 @@ where impl PartialOrd for String { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - self.as_str().partial_cmp(other.as_str()) + Some(self.cmp(other)) } } @@ -452,8 +452,8 @@ unsafe impl AnyCompatible for Bytes { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFISmallBytes as i32 - || data.type_index == TypeIndex::kTVMFFIBytes as i32; + data.type_index == TypeIndex::kTVMFFISmallBytes as i32 + || data.type_index == TypeIndex::kTVMFFIBytes as i32 } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { @@ -500,8 +500,8 @@ unsafe impl AnyCompatible for String { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFISmallStr as i32 - || data.type_index == TypeIndex::kTVMFFIStr as i32; + data.type_index == TypeIndex::kTVMFFISmallStr as i32 + || data.type_index == TypeIndex::kTVMFFIStr as i32 } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { diff --git a/rust/tvm-ffi/src/subtyping.rs b/rust/tvm-ffi/src/subtyping.rs new file mode 100644 index 00000000..18c42703 --- /dev/null +++ b/rust/tvm-ffi/src/subtyping.rs @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Subtyping infrastructure for object hierarchy conversions. +//! +//! This module provides type-safe upcast and downcast operations for objects +//! that follow the TVM FFI object hierarchy. + +use crate::object::{Object, ObjectArc, ObjectCore, ObjectRefCore}; +use tvm_ffi_sys::TVMFFIGetTypeInfo; + +/// Check if a type_index is an instance of target_index (including inheritance). +/// +/// # Safety +/// This function accesses the type info table via FFI and follows ancestor pointers. +#[doc(hidden)] +pub unsafe fn is_instance_of(type_index: i32, target_index: i32) -> bool { + if type_index == target_index { + return true; + } + let info = TVMFFIGetTypeInfo(type_index); + if info.is_null() { + return false; + } + let info = &*info; + let ancestors = info.type_acenstors; + if ancestors.is_null() { + return false; + } + for depth in 0..info.type_depth { + let ancestor = *ancestors.add(depth as usize); + if !ancestor.is_null() && (*ancestor).type_index == target_index { + return true; + } + } + false +} + +/// Upcast an object reference from a subtype to a supertype. +/// +/// This is a consuming operation that transfers ownership. +/// +/// # Type Parameters +/// * `From` - The source type (subtype) +/// * `To` - The target type (supertype) +/// +/// # Internal Implementation Detail +/// This function is public for macro expansion but should not be called directly. +/// Use `From::from()` or `.into()` for upcasting instead. +#[doc(hidden)] +pub fn upcast(value: From) -> To { + unsafe { + let arc = ::into_data(value); + let raw = ObjectArc::into_raw(arc); + let casted = ObjectArc::from_raw(raw as *const ::ContainerType); + ::from_data(casted) + } +} + +/// Try to downcast an object reference from a supertype to a subtype. +/// +/// This is a consuming operation that transfers ownership on success. +/// +/// # Type Parameters +/// * `From` - The source type (supertype) +/// * `To` - The target type (subtype) +/// +/// # Returns +/// * `Ok(To)` - If the runtime type check succeeds +/// * `Err(From)` - If the runtime type check fails, returns the original value +/// +/// # Internal Implementation Detail +/// This function is public for macro expansion but should not be called directly. +/// Use `TryFrom::try_from()` or `.try_into()` for downcasting instead. +#[doc(hidden)] +pub fn try_downcast(value: From) -> Result { + unsafe { + let arc = ::data(&value); + let raw = ObjectArc::as_raw(arc) as *const Object as *const tvm_ffi_sys::TVMFFIObject; + let type_index = (*raw).type_index; + let target_index = ::ContainerType::type_index(); + + if is_instance_of(type_index, target_index) { + // Type check passed, perform the downcast + let arc = ::into_data(value); + let raw = ObjectArc::into_raw(arc); + let casted = ObjectArc::from_raw(raw as *const ::ContainerType); + Ok(::from_data(casted)) + } else { + // Type check failed, return the original value + Err(value) + } + } +} diff --git a/rust/tvm-ffi/src/type_traits.rs b/rust/tvm-ffi/src/type_traits.rs index d39da4b4..eb96bc99 100644 --- a/rust/tvm-ffi/src/type_traits.rs +++ b/rust/tvm-ffi/src/type_traits.rs @@ -16,8 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +use crate::any::{Any, AnyValue}; use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; -use tvm_ffi_sys::{TVMFFIAny, TVMFFIGetTypeInfo}; +use tvm_ffi_sys::{TVMFFIAny, TVMFFIAnyViewToOwnedAny, TVMFFIGetTypeInfo}; //----------------------------------------------------- // AnyCompatible @@ -150,7 +151,7 @@ unsafe impl AnyCompatible for Option { } unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { - if let Some(ref value) = src { + if let Some(value) = src { T::copy_to_any_view(value, data); } else { data.type_index = TypeIndex::kTVMFFINone as i32; @@ -170,7 +171,7 @@ unsafe impl AnyCompatible for Option { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return T::check_any_strict(data) || data.type_index == TypeIndex::kTVMFFINone as i32; + T::check_any_strict(data) || data.type_index == TypeIndex::kTVMFFINone as i32 } unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { @@ -198,6 +199,41 @@ unsafe impl AnyCompatible for Option { } } +/// AnyCompatible for AnyValue +unsafe impl AnyCompatible for AnyValue { + fn type_str() -> String { + "Any".to_string() + } + + unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { + *data = src.as_any().as_raw_ffi_any(); + } + + unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { + *data = Any::into_raw_ffi_any(src.into_any()); + } + + unsafe fn check_any_strict(_data: &TVMFFIAny) -> bool { + true + } + + unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { + let mut owned = TVMFFIAny::new(); + crate::check_safe_call!(TVMFFIAnyViewToOwnedAny(data, &mut owned)).unwrap(); + AnyValue::from(Any::from_raw_ffi_any(owned)) + } + + unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { + let raw = *data; + *data = TVMFFIAny::new(); + AnyValue::from(Any::from_raw_ffi_any(raw)) + } + + unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { + Ok(Self::copy_from_any_view_after_check(data)) + } +} + /// AnyCompatible for void* unsafe impl AnyCompatible for *mut core::ffi::c_void { unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { @@ -312,16 +348,12 @@ unsafe impl AnyCompatible for () { } unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { - return data.type_index == TypeIndex::kTVMFFINone as i32; + data.type_index == TypeIndex::kTVMFFINone as i32 } - unsafe fn copy_from_any_view_after_check(_data: &TVMFFIAny) -> Self { - () - } + unsafe fn copy_from_any_view_after_check(_data: &TVMFFIAny) -> Self {} - unsafe fn move_from_any_after_check(_data: &mut TVMFFIAny) -> Self { - () - } + unsafe fn move_from_any_after_check(_data: &mut TVMFFIAny) -> Self {} unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { if data.type_index == TypeIndex::kTVMFFINone as i32 { diff --git a/rust/tvm-ffi/tests/test_object.rs b/rust/tvm-ffi/tests/test_object.rs index 60378c2a..650f0b94 100644 --- a/rust/tvm-ffi/tests/test_object.rs +++ b/rust/tvm-ffi/tests/test_object.rs @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicU32, Ordering}; use tvm_ffi::*; // must have repr(C) for the object header stays in the same position