From 108d008289f76c328eef334acdd513c6784c9026 Mon Sep 17 00:00:00 2001 From: Igor Malovitsa Date: Sat, 4 Apr 2026 16:59:19 +0000 Subject: [PATCH 1/6] Add sparse tensor primitives with einsum contraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integrate N-dimensional sparse tensors into MORK expressions: - SparseTensorF64: PathMap-backed sparse tensor with BOB encoding, arbitrary rank, element-wise add/mul via lattice operations - einsum-dyn: copied into workspace as local crate, runtime Einstein summation with VM compiler supporting arbitrary N-D specs - Tensor sinks (tensor_collect, tensor_einsum, tensor_add, tensor_mul, tensor_free) using WriteResource::TensorStore through the standard resource infrastructure — no thread-locals - tensor_get/tensor_nnz pure functions via ExprSource.context pointer propagated through EvalScope - End-to-end test: load matrix data as S-expressions, collect into tensors via exec, einsum multiply, verify result MeTTa usage: (exec P1 (, (a $r $c $v)) (O (tensor_collect A $r $c $v))) (exec P2 (,) (O (tensor_einsum "ab,bc->ac" A B C))) Co-Authored-By: Claude Opus 4.6 (1M context) --- Cargo.toml | 1 + einsum-dyn/Cargo.toml | 4 + einsum-dyn/src/lib.rs | 1400 ++++++++++++++++++++++++++ einsum-dyn/src/sparse.rs | 1494 ++++++++++++++++++++++++++++ experiments/eval-ffi/src/source.rs | 5 +- experiments/eval/src/lib.rs | 7 +- kernel/Cargo.toml | 2 + kernel/src/lib.rs | 1 + kernel/src/sinks.rs | 326 +++++- kernel/src/space.rs | 20 +- kernel/src/sparse.rs | 361 +++++++ 11 files changed, 3613 insertions(+), 8 deletions(-) create mode 100644 einsum-dyn/Cargo.toml create mode 100644 einsum-dyn/src/lib.rs create mode 100644 einsum-dyn/src/sparse.rs create mode 100644 kernel/src/sparse.rs diff --git a/Cargo.toml b/Cargo.toml index 0ce14db..dba88d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "experiments/eval", "experiments/eval-ffi", "experiments/eval-examples", + "einsum-dyn/", ] default-members = ["kernel/"] diff --git a/einsum-dyn/Cargo.toml b/einsum-dyn/Cargo.toml new file mode 100644 index 0000000..6832d86 --- /dev/null +++ b/einsum-dyn/Cargo.toml @@ -0,0 +1,4 @@ +[package] +name = "einsum-dyn" +version = "0.1.0" +edition = "2024" diff --git a/einsum-dyn/src/lib.rs b/einsum-dyn/src/lib.rs new file mode 100644 index 0000000..cc6cbde --- /dev/null +++ b/einsum-dyn/src/lib.rs @@ -0,0 +1,1400 @@ +//! Runtime dynamic [Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation) +//! for arbitrary N-dimensional arrays. +//! +//! # Functions +//! +//! | Function | Inputs | Output | +//! |---|---|---| +//! | [`einsum_ary`] | N arrays (`&[&In]`) | tensor (`&mut Out`) | +//! | [`einsum_binary`] | two arrays | tensor (`&mut Out`) | +//! | [`einsum_unary`] | one array | tensor (`&mut Out`) | +//! | [`einsum_binary_scalar`] | two arrays | scalar (returned) | +//! | [`einsum_unary_scalar`] | one array | scalar (returned) | +//! +//! [`einsum_ary`] is the general-purpose entry point — it accepts any number +//! of inputs and subsumes both [`einsum_binary`] and [`einsum_unary`]. The +//! specialised functions remain for ergonomics and because they are ~1.5× +//! faster (stack-only buffers vs heap `Vec`s for patterns/indices). +//! +//! For scalar output with `einsum_ary`, pass a 0-dimensional output tensor +//! (`ndim() == 0`, single element at `&[]`). +//! +//! # Spec format +//! +//! Specs use lowercase letters `a`–`z` as index names, with `->` separating +//! inputs from output: +//! +//! - `"ab,bc->ac"` — matrix multiply (contract over `b`) +//! - `"ab->ba"` — transpose (no contraction) +//! - `"i,i->"` — dot product (scalar output, empty after `->`) +//! - `"aa->"` — trace (scalar output) +//! - `"ab,bc,cd->ad"` — 3-input chain contraction (N-ary) +//! +//! Indices present in inputs but absent from the output are contracted +//! (summed over). All output indices must appear in at least one input. +//! +//! # Implementing `NDIndex` +//! +//! Any type can be used with these functions by implementing [`NDIndex`]: +//! +//! ``` +//! use einsum_dyn::{NDIndex, einsum_ary, einsum_binary, einsum_unary, einsum_binary_scalar, einsum_unary_scalar}; +//! +//! struct MyTensor { +//! data: Vec, +//! shape: Vec, +//! } +//! +//! impl MyTensor { +//! fn new(shape: Vec) -> Self { +//! let n = shape.iter().product(); +//! Self { data: vec![0.0; n], shape } +//! } +//! fn linear_index(&self, ix: &[usize]) -> usize { +//! let mut idx = 0; +//! let mut stride = 1; +//! for (&k, &dim) in ix.iter().rev().zip(self.shape.iter().rev()) { +//! idx += k * stride; +//! stride *= dim; +//! } +//! idx +//! } +//! } +//! +//! impl NDIndex for MyTensor { +//! fn ndim(&self) -> usize { self.shape.len() } +//! fn dim(&self, axis: usize) -> usize { self.shape[axis] } +//! fn get(&self, ix: &[usize]) -> f64 { self.data[self.linear_index(ix)] } +//! fn set(&mut self, ix: &[usize], v: f64) { +//! let i = self.linear_index(ix); +//! self.data[i] = v; +//! } +//! } +//! +//! // Matrix multiply: C = A × B +//! let mut a = MyTensor::new(vec![2, 3]); +//! a.data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; +//! let mut b = MyTensor::new(vec![3, 2]); +//! b.data = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; +//! let mut c = MyTensor::new(vec![2, 2]); +//! einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap(); +//! assert_eq!(c.data, vec![58.0, 64.0, 139.0, 154.0]); +//! +//! // Transpose +//! let mut t = MyTensor::new(vec![3, 2]); +//! einsum_unary("ab->ba", &a, &mut t).unwrap(); +//! assert_eq!(t.data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]); +//! +//! // Dot product (scalar) +//! let mut x = MyTensor::new(vec![3]); +//! x.data = vec![1.0, 2.0, 3.0]; +//! let mut y = MyTensor::new(vec![3]); +//! y.data = vec![4.0, 5.0, 6.0]; +//! let dot: f64 = einsum_binary_scalar("i,i->", &x, &y).unwrap(); +//! assert_eq!(dot, 32.0); +//! +//! // Trace (scalar) +//! let mut m = MyTensor::new(vec![2, 2]); +//! m.data = vec![1.0, 2.0, 3.0, 4.0]; +//! let tr: f64 = einsum_unary_scalar("aa->", &m).unwrap(); +//! assert_eq!(tr, 5.0); +//! +//! // N-ary: 3-input chain A(2×3) × B(3×2) × C(2×2) via einsum_ary +//! let mut d = MyTensor::new(vec![2, 2]); +//! einsum_ary("ab,bc,cd->ad", &[&a, &b, &c], &mut d).unwrap(); +//! +//! // N-ary with scalar output (0-dim tensor) +//! let mut scalar = MyTensor::new(vec![]); +//! einsum_ary("i,i->", &[&x, &y], &mut scalar).unwrap(); +//! assert_eq!(scalar.data, vec![32.0]); +//! ``` +//! +//! # Feature flags +//! +//! - **`ndarray`** — implements `NDIndex` for [`ndarray::ArrayD`], so you +//! can pass dynamic-dimension ndarray arrays directly. + +pub mod sparse; + +use std::fmt; +use std::ops::{AddAssign, Mul}; + +/// Trait for N-dimensional array access. +/// +/// Implement this for your tensor type to use the `einsum_*` functions. +/// All index slices are ordered left-to-right (outermost dimension first). +pub trait NDIndex { + fn ndim(&self) -> usize; + fn dim(&self, axis: usize) -> usize; + fn get(&self, indices: &[usize]) -> T; + fn set(&mut self, indices: &[usize], val: T); + + /// Returns `None` for structurally absent (zero) entries. + /// + /// Dense implementations use the default, which wraps `get` in `Some`. + /// Sparse implementations should override this to return `None` for + /// entries not present in the sparse structure. + fn get_opt(&self, indices: &[usize]) -> Option { + Some(self.get(indices)) + } + + /// Whether this array is a 2D sparse matrix supporting row iteration + /// via `sparse_row_nnz` and `sparse_row_entry`. Default: false. + fn is_sparse_2d(&self) -> bool { false } + + /// Number of non-zero entries in the given row. + /// Only meaningful when `is_sparse_2d()` returns true. + fn sparse_row_nnz(&self, _row: usize) -> usize { 0 } + + /// Get the `idx`-th non-zero entry in the given row as `(col, value)`. + /// Only meaningful when `is_sparse_2d()` returns true. + fn sparse_row_entry(&self, _row: usize, _idx: usize) -> (usize, T) { + panic!("sparse_row_entry called on non-sparse array") + } +} + +#[cfg(feature = "ndarray")] +impl NDIndex for ndarray::ArrayD { + fn ndim(&self) -> usize { + self.ndim() + } + fn dim(&self, axis: usize) -> usize { + self.shape()[axis] + } + fn get(&self, ix: &[usize]) -> T { + self[ndarray::IxDyn(ix)] + } + fn set(&mut self, ix: &[usize], val: T) { + self[ndarray::IxDyn(ix)] = val; + } +} + +/// Error returned when an einsum spec string is invalid. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InvalidSpec { + MissingArrow, + InvalidIndex { ch: char }, + WrongInputCount { expected: usize, got: usize }, + EmptyInput { input: usize }, + UnboundOutputIndex { index: char }, + InputNdimMismatch { input: usize, array_ndim: usize, spec_ndim: usize }, + DimensionMismatch { index: char, expected: usize, got: usize }, + OutputNdimMismatch { array_ndim: usize, spec_ndim: usize }, + OutputDimMismatch { axis: usize, expected: usize, got: usize }, + NonEmptyScalarOutput, +} + +/// Convert a slot index back to its letter for error messages. +fn slot_to_char(s: u8) -> char { + (s + b'a') as char +} + +impl fmt::Display for InvalidSpec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MissingArrow => write!(f, "missing '->'"), + Self::InvalidIndex { ch } => write!(f, "index '{ch}' is not a lowercase letter"), + Self::WrongInputCount { expected, got } => { + write!(f, "expected {expected} input(s), got {got}") + } + Self::EmptyInput { input } => write!(f, "input {input} has no indices"), + Self::UnboundOutputIndex { index } => { + write!(f, "output index '{index}' does not appear in any input") + } + Self::InputNdimMismatch { input, array_ndim, spec_ndim } => { + write!(f, "input {input} has {array_ndim} dimensions but spec has {spec_ndim} indices") + } + Self::DimensionMismatch { index, expected, got } => { + write!(f, "dimension mismatch for index '{index}': {expected} vs {got}") + } + Self::OutputNdimMismatch { array_ndim, spec_ndim } => { + write!(f, "output has {array_ndim} dimensions but spec has {spec_ndim} output indices") + } + Self::OutputDimMismatch { axis, expected, got } => { + write!(f, "output dimension {axis} is {got} but expected {expected}") + } + Self::NonEmptyScalarOutput => { + write!(f, "scalar output requires empty output indices (use '...->')") + } + } + } +} + +impl std::error::Error for InvalidSpec {} + +/// Parsed einsum specification. All index chars are stored as slot indices +/// (`b'a'` → 0, `b'b'` → 1, ..., `b'z'` → 25). +/// +/// The RHS of the spec may contain multiple comma-separated output groups +/// (e.g. `"ab,bc->ac,ca"`). Single-output functions check `outputs.len() == 1`. +pub(crate) struct Spec { + inputs: Vec>, + outputs: Vec>, +} + +impl Spec { + /// Convenience: returns the first (and usually only) output pattern. + pub(crate) fn output(&self) -> &[u8] { + &self.outputs[0] + } + + /// All unique output slots across all outputs. + pub(crate) fn all_output_slots(&self) -> Vec { + let mut seen = [false; 26]; + let mut slots = Vec::new(); + for out in &self.outputs { + for &s in out { + if !seen[s as usize] { + seen[s as usize] = true; + slots.push(s); + } + } + } + slots + } +} + +pub(crate) fn parse_spec(spec: &str, expected_inputs: usize) -> Result { + let spec = spec.replace(' ', ""); + + let (lhs, rhs) = spec + .split_once("->") + .ok_or(InvalidSpec::MissingArrow)?; + + let mut inputs: Vec> = Vec::new(); + for part in lhs.split(',') { + let mut slots = Vec::new(); + for ch in part.bytes() { + if !ch.is_ascii_lowercase() { + return Err(InvalidSpec::InvalidIndex { ch: ch as char }); + } + slots.push(ch - b'a'); + } + inputs.push(slots); + } + + if inputs.len() != expected_inputs { + return Err(InvalidSpec::WrongInputCount { + expected: expected_inputs, + got: inputs.len(), + }); + } + + for (i, inp) in inputs.iter().enumerate() { + if inp.is_empty() { + return Err(InvalidSpec::EmptyInput { input: i }); + } + } + + let mut outputs: Vec> = Vec::new(); + for part in rhs.split(',') { + let mut slots = Vec::new(); + for ch in part.bytes() { + if !ch.is_ascii_lowercase() { + return Err(InvalidSpec::InvalidIndex { ch: ch as char }); + } + slots.push(ch - b'a'); + } + outputs.push(slots); + } + + // Validate: every output index must appear in at least one input + let mut seen = [false; 26]; + for inp in &inputs { + for &s in inp { + seen[s as usize] = true; + } + } + for out in &outputs { + for &s in out { + if !seen[s as usize] { + return Err(InvalidSpec::UnboundOutputIndex { index: slot_to_char(s) }); + } + } + } + + Ok(Spec { inputs, outputs }) +} + +/// Validates that array dimensions match the spec. Returns dims as a `[usize; 26]` +/// array (indexed by slot). Unused slots are 0. +pub(crate) fn validate_dims>( + spec: &Spec, + arrays: &[&Arr], +) -> Result<[usize; 26], InvalidSpec> { + for (i, (inp, arr)) in spec.inputs.iter().zip(arrays.iter()).enumerate() { + if arr.ndim() != inp.len() { + return Err(InvalidSpec::InputNdimMismatch { + input: i, + array_ndim: arr.ndim(), + spec_ndim: inp.len(), + }); + } + } + + let mut dims = [0usize; 26]; + let mut set = [false; 26]; + for (pi, inp) in spec.inputs.iter().enumerate() { + for (pos, &s) in inp.iter().enumerate() { + let si = s as usize; + let d = arrays[pi].dim(pos); + if set[si] { + if dims[si] != d { + return Err(InvalidSpec::DimensionMismatch { + index: slot_to_char(s), + expected: dims[si], + got: d, + }); + } + } else { + dims[si] = d; + set[si] = true; + } + } + } + + Ok(dims) +} + +/// Collects all unique indices in order of first appearance, as a SlotList. +fn all_slots_ordered(spec: &Spec) -> SlotList { + let mut seen = [false; 26]; + let mut slots = [0u8; 26]; + let mut len = 0u8; + for inp in &spec.inputs { + for &s in inp { + if !seen[s as usize] { + seen[s as usize] = true; + slots[len as usize] = s; + len += 1; + } + } + } + SlotList { slots, len } +} + +/// Stack-only index buffer: fixed array + length, no heap. +struct Idx { + data: [usize; 26], + len: u8, +} + +impl Idx { + const ZERO: Self = Idx { data: [0; 26], len: 0 }; + + #[inline(always)] + fn as_slice(&self) -> &[usize] { + &self.data[..self.len as usize] + } +} + +/// Precomputed gather pattern: slot indices stored on the stack. +struct Pattern { + slots: [u8; 26], + len: u8, +} + +impl Pattern { + fn from_slots(slot_indices: &[u8]) -> Self { + let mut slots = [0u8; 26]; + slots[..slot_indices.len()].copy_from_slice(slot_indices); + Pattern { + slots, + len: slot_indices.len() as u8, + } + } + + /// Gather index values from `vals` into `out` according to this pattern. + #[inline(always)] + fn gather(&self, vals: &[usize; 26], out: &mut Idx) { + out.len = self.len; + for i in 0..self.len as usize { + out.data[i] = vals[self.slots[i] as usize]; + } + } +} + +/// Precomputed loop-slot list stored on the stack. +struct SlotList { + slots: [u8; 26], + len: u8, +} + +impl SlotList { + fn from_slots(slot_indices: &[u8]) -> Self { + let mut slots = [0u8; 26]; + slots[..slot_indices.len()].copy_from_slice(slot_indices); + SlotList { + slots, + len: slot_indices.len() as u8, + } + } + + fn as_slice(&self) -> &[u8] { + &self.slots[..self.len as usize] + } + + fn contains(&self, s: u8) -> bool { + self.as_slice().contains(&s) + } + + fn filtered_complement(all: &[u8], free: &SlotList) -> Self { + let mut slots = [0u8; 26]; + let mut len = 0u8; + for &s in all { + if !free.contains(s) { + slots[len as usize] = s; + len += 1; + } + } + SlotList { slots, len } + } +} + +/// Recursive loop nest over slots. `loop_slots[i]` is a slot index, +/// `dims` and `vals` are flat [usize; 26] arrays. +#[inline(always)] +fn loop_nest( + loop_slots: &[u8], + dims: &[usize; 26], + vals: &mut [usize; 26], + emit: &mut impl FnMut(&[usize; 26]), +) { + if loop_slots.is_empty() { + emit(vals); + return; + } + let s = loop_slots[0] as usize; + let rest = &loop_slots[1..]; + let n = dims[s]; + for v in 0..n { + vals[s] = v; + loop_nest(rest, dims, vals, emit); + } +} + +// Iterative variant using an explicit counter stack. +// Benchmarked ~same as recursive. +#[cfg(any())] +#[inline(always)] +fn loop_nest_iterative( + loop_slots: &[u8], + dims: &[usize; 26], + vals: &mut [usize; 26], + emit: &mut impl FnMut(&[usize; 26]), +) { + let depth = loop_slots.len(); + if depth == 0 { + emit(vals); + return; + } + let mut counters = [0usize; 26]; + for i in 0..depth { + vals[loop_slots[i] as usize] = 0; + } + loop { + emit(vals); + let mut level = depth - 1; + loop { + let s = loop_slots[level] as usize; + counters[level] += 1; + if counters[level] < dims[s] { + vals[s] = counters[level]; + break; + } + counters[level] = 0; + vals[s] = 0; + if level == 0 { + return; + } + level -= 1; + } + } +} + +/// Validate output array dimensions against the spec. +pub(crate) fn validate_output>( + spec: &Spec, + dims: &[usize; 26], + out: &Arr, +) -> Result<(), InvalidSpec> { + if out.ndim() != spec.output().len() { + return Err(InvalidSpec::OutputNdimMismatch { + array_ndim: out.ndim(), + spec_ndim: spec.output().len(), + }); + } + for (pos, &s) in spec.output().iter().enumerate() { + if out.dim(pos) != dims[s as usize] { + return Err(InvalidSpec::OutputDimMismatch { + axis: pos, + expected: dims[s as usize], + got: out.dim(pos), + }); + } + } + Ok(()) +} + +/// `einsum_ary(spec, inputs, out)` — N-ary einsum with tensor output. +/// +/// Generalises [`einsum_binary`] and [`einsum_unary`] to an arbitrary number +/// of inputs. The spec must contain exactly `inputs.len()` comma-separated +/// input index groups. +/// +/// For scalar output, pass a 0-dimensional output tensor (`ndim() == 0`, +/// one element at index `&[]`) and use an empty output spec (e.g. `"i,i->"`). +/// +/// The specialised binary/unary functions are ~1.5× faster due to stack-only +/// buffers; prefer them for the 1- and 2-input cases on hot paths. +/// +/// ``` +/// # use einsum_dyn::{NDIndex, einsum_ary}; +/// # struct T { m: Vec, d: Vec } +/// # impl T { fn new(d: Vec) -> Self { let n = d.iter().product(); Self { m: vec![0.0; n], d } } fn li(&self, ix: &[usize]) -> usize { let mut i=0; let mut s=1; for (&k,&d) in ix.iter().rev().zip(self.d.iter().rev()) { i+=k*s; s*=d; } i } } +/// # impl NDIndex for T { fn ndim(&self)->usize{self.d.len()} fn dim(&self,a:usize)->usize{self.d[a]} fn get(&self,ix:&[usize])->f32{self.m[self.li(ix)]} fn set(&mut self,ix:&[usize],v:f32){let i=self.li(ix);self.m[i]=v;} } +/// let mut a = T::new(vec![2, 3]); +/// a.m = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; +/// let mut b = T::new(vec![3, 2]); +/// b.m = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; +/// let mut c = T::new(vec![2, 2]); +/// einsum_ary("ab,bc->ac", &[&a, &b], &mut c).unwrap(); +/// assert_eq!(c.m, vec![58.0, 64.0, 139.0, 154.0]); +/// ``` +pub fn einsum_ary(spec: &str, inputs: &[&In], out: &mut Out) -> Result<(), InvalidSpec> +where + T: Default + Copy + AddAssign + Mul, + In: NDIndex, + Out: NDIndex, +{ + let spec = parse_spec(spec, inputs.len())?; + let dims = validate_dims(&spec, inputs)?; + validate_output(&spec, &dims, out)?; + + let free_slots = SlotList::from_slots(&spec.output()); + let all = all_slots_ordered(&spec); + let contracted_slots = SlotList::filtered_complement(all.as_slice(), &free_slots); + + let pats: Vec = spec.inputs.iter().map(|inp| Pattern::from_slots(inp)).collect(); + let pat_out = Pattern::from_slots(&spec.output()); + + let n = inputs.len(); + let mut vals = [0usize; 26]; + let mut bufs: Vec = (0..n).map(|_| Idx::ZERO).collect(); + let mut buf_out = Idx::ZERO; + + if contracted_slots.len == 0 { + // No contraction — direct assignment + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |vals| { + for i in 0..n { + pats[i].gather(vals, &mut bufs[i]); + } + // Multiply all input values; if any is sparse-absent, result is zero + let first = match inputs[0].get_opt(bufs[0].as_slice()) { + Some(v) => v, + None => { pat_out.gather(vals, &mut buf_out); out.set(buf_out.as_slice(), Default::default()); return; } + }; + let mut product = first; + for i in 1..n { + match inputs[i].get_opt(bufs[i].as_slice()) { + Some(v) => product = product * v, + None => { pat_out.gather(vals, &mut buf_out); out.set(buf_out.as_slice(), Default::default()); return; } + } + } + pat_out.gather(vals, &mut buf_out); + out.set(buf_out.as_slice(), product); + }); + } else { + // With contraction — accumulate per output element + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |free_vals| { + let mut acc: T = Default::default(); + let mut inner_vals = *free_vals; + loop_nest( + contracted_slots.as_slice(), + &dims, + &mut inner_vals, + &mut |vals| { + for i in 0..n { + pats[i].gather(vals, &mut bufs[i]); + } + let first = match inputs[0].get_opt(bufs[0].as_slice()) { + Some(v) => v, + None => return, + }; + let mut product = first; + for i in 1..n { + match inputs[i].get_opt(bufs[i].as_slice()) { + Some(v) => product = product * v, + None => return, + } + } + acc += product; + }, + ); + pat_out.gather(free_vals, &mut buf_out); + out.set(buf_out.as_slice(), acc); + }); + } + + Ok(()) +} + +/// `einsum_binary(spec, a, b, out)` — binary einsum with tensor output. +/// +/// Spec format: `"ab,bc->ac"` (numpy-style). +/// All indices in the output must appear in at least one input. +/// Indices present in inputs but absent from the output are contracted (summed over). +/// The output array must already have the correct shape. +pub fn einsum_binary(spec: &str, a: &In, b: &In, out: &mut Out) -> Result<(), InvalidSpec> +where + T: Default + Copy + AddAssign + Mul, + In: NDIndex, + Out: NDIndex, +{ + let spec = parse_spec(spec, 2)?; + let dims = validate_dims(&spec, &[a, b])?; + validate_output(&spec, &dims, out)?; + + let free_slots = SlotList::from_slots(&spec.output()); + let all = all_slots_ordered(&spec); + let contracted_slots = SlotList::filtered_complement(all.as_slice(), &free_slots); + + let pat_a = Pattern::from_slots(&spec.inputs[0]); + let pat_b = Pattern::from_slots(&spec.inputs[1]); + let pat_out = Pattern::from_slots(&spec.output()); + + let mut vals = [0usize; 26]; + let mut buf_a = Idx::ZERO; + let mut buf_b = Idx::ZERO; + let mut buf_out = Idx::ZERO; + + if contracted_slots.len == 0 { + // No contraction — direct assignment + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |vals| { + pat_a.gather(vals, &mut buf_a); + pat_b.gather(vals, &mut buf_b); + pat_out.gather(vals, &mut buf_out); + let v = match (a.get_opt(buf_a.as_slice()), b.get_opt(buf_b.as_slice())) { + (Some(av), Some(bv)) => av * bv, + _ => Default::default(), + }; + out.set(buf_out.as_slice(), v); + }); + } else { + // With contraction — accumulate per output element + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |free_vals| { + let mut acc: T = Default::default(); + let mut inner_vals = *free_vals; + loop_nest( + contracted_slots.as_slice(), + &dims, + &mut inner_vals, + &mut |vals| { + pat_a.gather(vals, &mut buf_a); + pat_b.gather(vals, &mut buf_b); + if let (Some(av), Some(bv)) = + (a.get_opt(buf_a.as_slice()), b.get_opt(buf_b.as_slice())) + { + acc += av * bv; + } + }, + ); + pat_out.gather(free_vals, &mut buf_out); + out.set(buf_out.as_slice(), acc); + }); + } + + Ok(()) +} + +/// `einsum_unary(spec, a, out)` — unary einsum with tensor output. +/// +/// Spec format: `"ab->ba"` (numpy-style). +pub fn einsum_unary(spec: &str, a: &In, out: &mut Out) -> Result<(), InvalidSpec> +where + T: Default + Copy + AddAssign + Mul, + In: NDIndex, + Out: NDIndex, +{ + let spec = parse_spec(spec, 1)?; + let dims = validate_dims(&spec, &[a])?; + validate_output(&spec, &dims, out)?; + + let free_slots = SlotList::from_slots(&spec.output()); + let all = all_slots_ordered(&spec); + let contracted_slots = SlotList::filtered_complement(all.as_slice(), &free_slots); + + let pat_a = Pattern::from_slots(&spec.inputs[0]); + let pat_out = Pattern::from_slots(&spec.output()); + let mut vals = [0usize; 26]; + let mut buf_a = Idx::ZERO; + let mut buf_out = Idx::ZERO; + + if contracted_slots.len == 0 { + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |vals| { + pat_a.gather(vals, &mut buf_a); + pat_out.gather(vals, &mut buf_out); + let v = a.get_opt(buf_a.as_slice()).unwrap_or_default(); + out.set(buf_out.as_slice(), v); + }); + } else { + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |free_vals| { + let mut acc: T = Default::default(); + let mut inner_vals = *free_vals; + loop_nest( + contracted_slots.as_slice(), + &dims, + &mut inner_vals, + &mut |vals| { + pat_a.gather(vals, &mut buf_a); + if let Some(av) = a.get_opt(buf_a.as_slice()) { + acc += av; + } + }, + ); + pat_out.gather(free_vals, &mut buf_out); + out.set(buf_out.as_slice(), acc); + }); + } + + Ok(()) +} + +/// `einsum_binary_scalar(spec, a, b)` — binary einsum with scalar output. +/// +/// Spec format: `"ab,ab->"` (empty output = scalar). +pub fn einsum_binary_scalar(spec: &str, a: &Arr, b: &Arr) -> Result +where + T: Default + Copy + AddAssign + Mul, + Arr: NDIndex, +{ + let spec = parse_spec(spec, 2)?; + let dims = validate_dims(&spec, &[a, b])?; + + if !spec.output().is_empty() { + return Err(InvalidSpec::NonEmptyScalarOutput); + } + + let all = all_slots_ordered(&spec); + let pat_a = Pattern::from_slots(&spec.inputs[0]); + let pat_b = Pattern::from_slots(&spec.inputs[1]); + let mut vals = [0usize; 26]; + let mut buf_a = Idx::ZERO; + let mut buf_b = Idx::ZERO; + let mut acc: T = Default::default(); + + loop_nest(all.as_slice(), &dims, &mut vals, &mut |vals| { + pat_a.gather(vals, &mut buf_a); + pat_b.gather(vals, &mut buf_b); + if let (Some(av), Some(bv)) = + (a.get_opt(buf_a.as_slice()), b.get_opt(buf_b.as_slice())) + { + acc += av * bv; + } + }); + + Ok(acc) +} + +/// `einsum_unary_scalar(spec, a)` — unary einsum with scalar output. +/// +/// Spec format: `"aa->"` (empty output = scalar). +pub fn einsum_unary_scalar(spec: &str, a: &Arr) -> Result +where + T: Default + Copy + AddAssign + Mul, + Arr: NDIndex, +{ + let spec = parse_spec(spec, 1)?; + let dims = validate_dims(&spec, &[a])?; + + if !spec.output().is_empty() { + return Err(InvalidSpec::NonEmptyScalarOutput); + } + + let all = all_slots_ordered(&spec); + let pat_a = Pattern::from_slots(&spec.inputs[0]); + let mut vals = [0usize; 26]; + let mut buf_a = Idx::ZERO; + let mut acc: T = Default::default(); + + loop_nest(all.as_slice(), &dims, &mut vals, &mut |vals| { + pat_a.gather(vals, &mut buf_a); + if let Some(av) = a.get_opt(buf_a.as_slice()) { + acc += av; + } + }); + + Ok(acc) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Minimal dense tensor for testing. + struct Tensor { + m: Vec, + d: Vec, + } + + impl Tensor { + fn new(d: Vec) -> Self { + let n: usize = d.iter().product(); + Self { + m: vec![0.0; n], + d, + } + } + + fn linear_index(&self, ix: &[usize]) -> usize { + let mut idx = 0usize; + let mut stride = 1usize; + for (&k, &dim) in ix.iter().rev().zip(self.d.iter().rev()) { + idx += k * stride; + stride *= dim; + } + idx + } + } + + impl NDIndex for Tensor { + fn ndim(&self) -> usize { + self.d.len() + } + + fn dim(&self, axis: usize) -> usize { + self.d[axis] + } + + fn get(&self, ix: &[usize]) -> f32 { + self.m[self.linear_index(ix)] + } + + fn set(&mut self, ix: &[usize], v: f32) { + let i = self.linear_index(ix); + self.m[i] = v; + } + } + + fn set_matrix(t: &mut Tensor, vals: &[f32]) { + for (i, &v) in vals.iter().enumerate() { + t.m[i] = v; + } + } + + #[test] + fn test_matmul() { + let mut a = Tensor::new(vec![2, 3]); + let mut b = Tensor::new(vec![3, 2]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + set_matrix(&mut b, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]); + + let mut c = Tensor::new(vec![2, 2]); + einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap(); + + assert_eq!(c.get(&[0, 0]), 58.0); + assert_eq!(c.get(&[0, 1]), 64.0); + assert_eq!(c.get(&[1, 0]), 139.0); + assert_eq!(c.get(&[1, 1]), 154.0); + } + + #[test] + fn test_transpose() { + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + let mut t = Tensor::new(vec![3, 2]); + einsum_unary("ab->ba", &a, &mut t).unwrap(); + + assert_eq!(t.get(&[0, 0]), 1.0); + assert_eq!(t.get(&[0, 1]), 4.0); + assert_eq!(t.get(&[1, 0]), 2.0); + assert_eq!(t.get(&[1, 1]), 5.0); + assert_eq!(t.get(&[2, 0]), 3.0); + assert_eq!(t.get(&[2, 1]), 6.0); + } + + #[test] + fn test_outer_product() { + let mut a = Tensor::new(vec![3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0]); + let mut b = Tensor::new(vec![2]); + set_matrix(&mut b, &[4.0, 5.0]); + + let mut c = Tensor::new(vec![3, 2]); + einsum_binary("a,b->ab", &a, &b, &mut c).unwrap(); + + assert_eq!(c.get(&[0, 0]), 4.0); + assert_eq!(c.get(&[0, 1]), 5.0); + assert_eq!(c.get(&[1, 0]), 8.0); + assert_eq!(c.get(&[1, 1]), 10.0); + assert_eq!(c.get(&[2, 0]), 12.0); + assert_eq!(c.get(&[2, 1]), 15.0); + } + + #[test] + fn test_vecmat() { + let mut v = Tensor::new(vec![2]); + set_matrix(&mut v, &[1.0, 2.0]); + let mut m = Tensor::new(vec![2, 2]); + set_matrix(&mut m, &[3.0, 4.0, 5.0, 6.0]); + + let mut r = Tensor::new(vec![2]); + einsum_binary("a,ab->b", &v, &m, &mut r).unwrap(); + + assert_eq!(r.get(&[0]), 13.0); + assert_eq!(r.get(&[1]), 16.0); + } + + #[test] + fn test_dot() { + let mut a = Tensor::new(vec![4]); + let mut b = Tensor::new(vec![4]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0]); + set_matrix(&mut b, &[5.0, 6.0, 7.0, 8.0]); + + let result: f32 = einsum_binary_scalar("i,i->", &a, &b).unwrap(); + assert_eq!(result, 70.0); + } + + #[test] + fn test_trace() { + let mut m = Tensor::new(vec![3, 3]); + set_matrix( + &mut m, + &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], + ); + + let result: f32 = einsum_unary_scalar("aa->", &m).unwrap(); + assert_eq!(result, 15.0); + } + + #[test] + fn test_frobenius2() { + let mut a = Tensor::new(vec![2, 2]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0]); + + let result: f32 = einsum_binary_scalar("ab,ab->", &a, &a).unwrap(); + assert_eq!(result, 30.0); + } + + #[test] + fn test_attention() { + let (b, h, q_len, k_len, dim) = (2, 2, 3, 2, 4); + let mut q = Tensor::new(vec![b, h, q_len, dim]); + let mut k = Tensor::new(vec![b, h, k_len, dim]); + + for bi in 0..b { + for hi in 0..h { + for qi in 0..q_len { + for di in 0..dim { + let v = + (bi + 1) as f32 * (hi + 1) as f32 * (qi + 1) as f32 + di as f32; + q.set(&[bi, hi, qi, di], v); + } + } + for ki in 0..k_len { + for di in 0..dim { + let v = + (bi + 1) as f32 * (hi + 1) as f32 * (ki + 1) as f32 * (di + 1) as f32; + k.set(&[bi, hi, ki, di], v); + } + } + } + } + + let mut out = Tensor::new(vec![b, h, q_len, k_len]); + einsum_binary("bhqd,bhkd->bhqk", &q, &k, &mut out).unwrap(); + + for bi in 0..b { + for hi in 0..h { + for qi in 0..q_len { + for ki in 0..k_len { + let mut expected = 0.0f32; + for di in 0..dim { + expected += q.get(&[bi, hi, qi, di]) * k.get(&[bi, hi, ki, di]); + } + let actual = out.get(&[bi, hi, qi, ki]); + assert!( + (actual - expected).abs() < 1e-3, + "mismatch at [{bi},{hi},{qi},{ki}]: got {actual}, expected {expected}" + ); + } + } + } + } + } + + #[test] + fn test_err_missing_arrow() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("ab,bc", &a, &b, &mut c).unwrap_err(), + InvalidSpec::MissingArrow + ); + } + + #[test] + fn test_err_invalid_index() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("aB,bc->ac", &a, &b, &mut c).unwrap_err(), + InvalidSpec::InvalidIndex { ch: 'B' } + ); + // Invalid char in output + assert_eq!( + einsum_binary("ab,bc->a1", &a, &b, &mut c).unwrap_err(), + InvalidSpec::InvalidIndex { ch: '1' } + ); + } + + #[test] + fn test_err_wrong_input_count() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + // Binary function but spec has 1 input + assert_eq!( + einsum_binary("ab->ab", &a, &b, &mut c).unwrap_err(), + InvalidSpec::WrongInputCount { expected: 2, got: 1 } + ); + // Unary function but spec has 2 inputs + assert_eq!( + einsum_unary("ab,bc->ac", &a, &mut c).unwrap_err(), + InvalidSpec::WrongInputCount { expected: 1, got: 2 } + ); + } + + #[test] + fn test_err_empty_input() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary(",bc->bc", &a, &b, &mut c).unwrap_err(), + InvalidSpec::EmptyInput { input: 0 } + ); + } + + #[test] + fn test_err_unbound_output_index() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("ab,bc->az", &a, &b, &mut c).unwrap_err(), + InvalidSpec::UnboundOutputIndex { index: 'z' } + ); + } + + #[test] + fn test_err_input_ndim_mismatch() { + // a is 2D but spec says 3 indices + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("abc,bd->ad", &a, &b, &mut c).unwrap_err(), + InvalidSpec::InputNdimMismatch { input: 0, array_ndim: 2, spec_ndim: 3 } + ); + } + + #[test] + fn test_err_dimension_mismatch() { + // a is 2×3, b is 3×2, spec says first dims must match (a=2 vs a=3) + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("ab,ac->bc", &a, &b, &mut c).unwrap_err(), + InvalidSpec::DimensionMismatch { index: 'a', expected: 2, got: 3 } + ); + } + + #[test] + fn test_err_output_ndim_mismatch() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + // Output is 1D but spec says 2D + let mut c = Tensor::new(vec![2]); + assert_eq!( + einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap_err(), + InvalidSpec::OutputNdimMismatch { array_ndim: 1, spec_ndim: 2 } + ); + } + + #[test] + fn test_err_output_dim_mismatch() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + // Output should be 2×2 but we give 2×3 + let mut c = Tensor::new(vec![2, 3]); + assert_eq!( + einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap_err(), + InvalidSpec::OutputDimMismatch { axis: 1, expected: 2, got: 3 } + ); + } + + #[test] + fn test_err_non_empty_scalar_output() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![2, 3]); + assert_eq!( + einsum_binary_scalar::("ab,ab->a", &a, &b).unwrap_err(), + InvalidSpec::NonEmptyScalarOutput + ); + assert_eq!( + einsum_unary_scalar::("ab->a", &a).unwrap_err(), + InvalidSpec::NonEmptyScalarOutput + ); + } + + #[test] + fn test_unary_row_sum() { + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + let mut out = Tensor::new(vec![2]); + einsum_unary("ab->a", &a, &mut out).unwrap(); + + assert_eq!(out.get(&[0]), 6.0); + assert_eq!(out.get(&[1]), 15.0); + } + + // --- einsum_ary tests --- + + #[test] + fn test_ary_matmul_as_binary() { + // einsum_ary with 2 inputs should match einsum_binary + let mut a = Tensor::new(vec![2, 3]); + let mut b = Tensor::new(vec![3, 2]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + set_matrix(&mut b, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]); + + let mut c = Tensor::new(vec![2, 2]); + einsum_ary("ab,bc->ac", &[&a, &b], &mut c).unwrap(); + + assert_eq!(c.get(&[0, 0]), 58.0); + assert_eq!(c.get(&[0, 1]), 64.0); + assert_eq!(c.get(&[1, 0]), 139.0); + assert_eq!(c.get(&[1, 1]), 154.0); + } + + #[test] + fn test_ary_transpose_as_unary() { + // einsum_ary with 1 input should match einsum_unary + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + let mut t = Tensor::new(vec![3, 2]); + einsum_ary("ab->ba", &[&a], &mut t).unwrap(); + + assert_eq!(t.get(&[0, 0]), 1.0); + assert_eq!(t.get(&[0, 1]), 4.0); + assert_eq!(t.get(&[1, 0]), 2.0); + assert_eq!(t.get(&[2, 1]), 6.0); + } + + #[test] + fn test_ary_three_input_chain() { + // A(2×3) × B(3×4) × C(4×2) -> D(2×2) + // spec: "ab,bc,cd->ad" contracts b and c + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let mut b = Tensor::new(vec![3, 4]); + set_matrix(&mut b, &[1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0]); + let mut c = Tensor::new(vec![4, 2]); + set_matrix(&mut c, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + + let mut d = Tensor::new(vec![2, 2]); + einsum_ary("ab,bc,cd->ad", &[&a, &b, &c], &mut d).unwrap(); + + // AB = [[1,2,3,0],[4,5,6,0]], ABC = AB×C = [[1*1+2*3+3*5, 1*2+2*4+3*6],[4*1+5*3+6*5, 4*2+5*4+6*6]] + // = [[22, 28],[49, 64]] + assert_eq!(d.get(&[0, 0]), 22.0); + assert_eq!(d.get(&[0, 1]), 28.0); + assert_eq!(d.get(&[1, 0]), 49.0); + assert_eq!(d.get(&[1, 1]), 64.0); + } + + #[test] + fn test_ary_outer_product_three() { + // No contraction: a(i) × b(j) × c(k) -> out(i,j,k) + let mut a = Tensor::new(vec![2]); + set_matrix(&mut a, &[2.0, 3.0]); + let mut b = Tensor::new(vec![2]); + set_matrix(&mut b, &[5.0, 7.0]); + let mut c = Tensor::new(vec![2]); + set_matrix(&mut c, &[11.0, 13.0]); + + let mut out = Tensor::new(vec![2, 2, 2]); + einsum_ary("a,b,c->abc", &[&a, &b, &c], &mut out).unwrap(); + + // out[i,j,k] = a[i]*b[j]*c[k] + assert_eq!(out.get(&[0, 0, 0]), 2.0 * 5.0 * 11.0); + assert_eq!(out.get(&[0, 0, 1]), 2.0 * 5.0 * 13.0); + assert_eq!(out.get(&[0, 1, 0]), 2.0 * 7.0 * 11.0); + assert_eq!(out.get(&[1, 1, 1]), 3.0 * 7.0 * 13.0); + } + + #[test] + fn test_ary_scalar_dot() { + // Scalar output via 0-dim tensor: "i,i->" (dot product) + let mut a = Tensor::new(vec![4]); + let mut b = Tensor::new(vec![4]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0]); + set_matrix(&mut b, &[5.0, 6.0, 7.0, 8.0]); + + let mut out = Tensor::new(vec![]); // 0-dim + einsum_ary("i,i->", &[&a, &b], &mut out).unwrap(); + + assert_eq!(out.get(&[]), 70.0); + } + + #[test] + fn test_ary_scalar_trace() { + // Scalar output via 0-dim tensor: "aa->" (trace) + let mut m = Tensor::new(vec![3, 3]); + set_matrix(&mut m, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); + + let mut out = Tensor::new(vec![]); + einsum_ary("aa->", &[&m], &mut out).unwrap(); + + assert_eq!(out.get(&[]), 15.0); + } + + #[test] + fn test_ary_scalar_three_input() { + // "i,i,i->" — element-wise product summed to scalar + let mut a = Tensor::new(vec![3]); + let mut b = Tensor::new(vec![3]); + let mut c = Tensor::new(vec![3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0]); + set_matrix(&mut b, &[4.0, 5.0, 6.0]); + set_matrix(&mut c, &[7.0, 8.0, 9.0]); + + let mut out = Tensor::new(vec![]); + einsum_ary("i,i,i->", &[&a, &b, &c], &mut out).unwrap(); + + // 1*4*7 + 2*5*8 + 3*6*9 = 28 + 80 + 162 = 270 + assert_eq!(out.get(&[]), 270.0); + } + + #[test] + fn test_ary_row_sum_unary() { + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + let mut out = Tensor::new(vec![2]); + einsum_ary("ab->a", &[&a], &mut out).unwrap(); + + assert_eq!(out.get(&[0]), 6.0); + assert_eq!(out.get(&[1]), 15.0); + } +} + +#[cfg(test)] +#[cfg(feature = "ndarray")] +mod ndarray_tests { + use super::*; + use ndarray::{ArrayD, IxDyn}; + + fn arr1(data: &[f64]) -> ArrayD { + ArrayD::from_shape_vec(IxDyn(&[data.len()]), data.to_vec()).unwrap() + } + + fn arr2(rows: usize, cols: usize, data: &[f64]) -> ArrayD { + ArrayD::from_shape_vec(IxDyn(&[rows, cols]), data.to_vec()).unwrap() + } + + fn zeros(shape: &[usize]) -> ArrayD { + ArrayD::zeros(IxDyn(shape)) + } + + #[test] + fn test_ndindex_trait() { + let a = arr2(2, 2, &[1.0, 2.0, 3.0, 4.0]); + assert_eq!(NDIndex::ndim(&a), 2); + assert_eq!(NDIndex::dim(&a, 0), 2); + assert_eq!(NDIndex::dim(&a, 1), 2); + assert_eq!(NDIndex::get(&a, &[0, 1]), 2.0); + assert_eq!(NDIndex::get(&a, &[1, 0]), 3.0); + + let mut b = a.clone(); + NDIndex::set(&mut b, &[0, 0], 99.0); + assert_eq!(NDIndex::get(&b, &[0, 0]), 99.0); + } + + #[test] + fn test_matmul() { + let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let b = arr2(3, 2, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]); + let mut c = zeros(&[2, 2]); + + einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap(); + + assert_eq!(c[IxDyn(&[0, 0])], 58.0); + assert_eq!(c[IxDyn(&[0, 1])], 64.0); + assert_eq!(c[IxDyn(&[1, 0])], 139.0); + assert_eq!(c[IxDyn(&[1, 1])], 154.0); + } + + #[test] + fn test_transpose() { + let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let mut t = zeros(&[3, 2]); + + einsum_unary("ab->ba", &a, &mut t).unwrap(); + + assert_eq!(t[IxDyn(&[0, 0])], 1.0); + assert_eq!(t[IxDyn(&[0, 1])], 4.0); + assert_eq!(t[IxDyn(&[1, 0])], 2.0); + assert_eq!(t[IxDyn(&[2, 1])], 6.0); + } + + #[test] + fn test_dot() { + let a = arr1(&[1.0, 2.0, 3.0, 4.0]); + let b = arr1(&[5.0, 6.0, 7.0, 8.0]); + + let result: f64 = einsum_binary_scalar("i,i->", &a, &b).unwrap(); + assert_eq!(result, 70.0); + } + + #[test] + fn test_trace() { + let m = arr2(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); + let result: f64 = einsum_unary_scalar("aa->", &m).unwrap(); + assert_eq!(result, 15.0); + } + + #[test] + fn test_outer_product() { + let a = arr1(&[1.0, 2.0, 3.0]); + let b = arr1(&[4.0, 5.0]); + let mut c = zeros(&[3, 2]); + + einsum_binary("a,b->ab", &a, &b, &mut c).unwrap(); + + assert_eq!(c[IxDyn(&[0, 0])], 4.0); + assert_eq!(c[IxDyn(&[1, 1])], 10.0); + assert_eq!(c[IxDyn(&[2, 0])], 12.0); + } + + #[test] + fn test_row_sum() { + let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let mut out = zeros(&[2]); + + einsum_unary("ab->a", &a, &mut out).unwrap(); + + assert_eq!(out[IxDyn(&[0])], 6.0); + assert_eq!(out[IxDyn(&[1])], 15.0); + } +} diff --git a/einsum-dyn/src/sparse.rs b/einsum-dyn/src/sparse.rs new file mode 100644 index 0000000..da2ed2c --- /dev/null +++ b/einsum-dyn/src/sparse.rs @@ -0,0 +1,1494 @@ +//! Sparse einsum: multiple approaches for sparse 2D matrix einsum. +//! +//! # Approaches +//! +//! ## 1. Baseline (`einsum_binary` with `get_opt`) +//! The existing dense-loop einsum from `lib.rs`. Iterates the full n×n×n index +//! space for matmul, but skips zero products via `get_opt()`. Complexity: O(n³). +//! No new code needed — use `einsum_binary` directly with any `NDIndex` impl. +//! +//! ## 2. Sparse-driven with dense accumulator (`einsum_sparse_driven`) +//! Equivalent to matmul (C = A × B). Only supports the `"ab,bc->ac"` pattern. +//! Iterates only non-zero entries of inputs using the `Sparse2D` trait: +//! loops A's rows, for each NZ (k, a_val) in row i, iterates B's row k +//! sparsely. Dense `Vec` accumulator per output row. +//! Complexity: O(Σ_i Σ_{k∈row(i)} row_nnz_B(k)) = O(flops). +//! +//! ## 3. Custom VM (`einsum_vm`) +//! Compiles the einsum spec into a tree of VM operations. A greedy scheduler +//! decides which loops iterate sparsely (via `Sparse2D::row_entry`) vs densely. +//! The VM interpreter walks the tree recursively. Same asymptotic complexity +//! as approach 2, but handles arbitrary 2D specs without hardcoding patterns. +//! +//! ## 4. Sparse-driven with hash accumulator (`einsum_sparse_hash`) +//! Same sparse iteration as approach 2, but accumulates into a `HashMap` +//! per output row instead of a dense vector. Better when the output is very sparse +//! (each row has few non-zeros), since it avoids allocating/clearing an n-wide +//! accumulator. Worse for dense output due to hashing overhead. + +use std::collections::HashMap; +use std::ops::{Add, AddAssign, Mul}; + +use crate::{NDIndex, InvalidSpec}; + +// ═══════════════════════════════════════════════════════════════════════════ +// Sparse2D trait +// ═══════════════════════════════════════════════════════════════════════════ + +/// Extension of `NDIndex` for sparse 2D arrays (matrices). +/// +/// Provides structured row-wise access to non-zero entries, enabling +/// sparse-driven einsum execution that skips entire zero regions. +pub trait Sparse2D: NDIndex { + /// Total number of structural non-zeros. + fn nnz(&self) -> usize; + + /// Number of rows (axis 0 dimension). + fn n_rows(&self) -> usize; + + /// Number of non-zero entries in the given row. + fn row_nnz(&self, row: usize) -> usize; + + /// Get the `idx`-th non-zero entry in the given row as `(col, value)`. + /// `idx` must be in `0..row_nnz(row)`. + fn row_entry(&self, row: usize, idx: usize) -> (usize, T); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Approach 2: Sparse-driven with dense accumulator +// ═══════════════════════════════════════════════════════════════════════════ + +/// Sparse-driven binary einsum with dense row accumulator. +/// +/// Only supports the standard matmul pattern: A's axis-1 index must equal +/// B's axis-0 index (the contracted/inner dimension). Any letter names work +/// (`"ab,bc->ac"`, `"xy,yz->xz"`, etc.) but the structure must be matmul. +/// For other specs, use `einsum_vm_oneshot` or `einsum_sparse_hash`. +/// +/// Both inputs must implement `Sparse2D`. The output must be a writable +/// `NDIndex` (typically dense) with the correct shape. +pub fn einsum_sparse_driven( + spec_str: &str, + a: &S, + b: &S, + out: &mut Out, +) -> Result<(), InvalidSpec> +where + T: Default + Copy + PartialEq + AddAssign + Mul, + S: Sparse2D, + Out: NDIndex, +{ + let spec = crate::parse_spec(spec_str, 2)?; + let dims = crate::validate_dims::(&spec, &[a, b])?; + crate::validate_output::(&spec, &dims, out)?; + + let a_slots = &spec.inputs[0]; + let b_slots = &spec.inputs[1]; + let out_slots = &spec.output(); + + assert_eq!(a_slots.len(), 2, "sparse-driven requires 2D inputs"); + assert_eq!(b_slots.len(), 2, "sparse-driven requires 2D inputs"); + + let (a0, a1) = (a_slots[0], a_slots[1]); + let (_b0, b1) = (b_slots[0], b_slots[1]); + + // Verify A's col index == B's row index. + // This restricts us to the matmul pattern (C = A × B) — the only + // structural arrangement where both inputs can be walked row-wise + // through CSR without transposing. + if a1 != b_slots[0] { + panic!( + "einsum_sparse_driven only supports matmul pattern \ + (A's axis-1 == B's axis-0), got spec '{spec_str}'" + ); + } + + // Map output slot positions: which output axis gets which slot value + let out_pos_a0 = out_slots.iter().position(|&s| s == a0); + let out_pos_b1 = out_slots.iter().position(|&s| s == b1); + let n_out_dims = out_slots.len(); + + // Dense accumulator sized to the full output column range. + // Track which columns were touched to avoid clearing the entire vector. + let n_cols_out = dims[b1 as usize]; + let mut acc = vec![T::default(); n_cols_out]; + let mut nz_cols: Vec = Vec::new(); + let mut out_ix = vec![0usize; n_out_dims]; + + let n_rows_a = a.n_rows(); + + for i in 0..n_rows_a { + // Scatter: iterate A's row i, then B's matching rows + for ai in 0..a.row_nnz(i) { + let (k, a_val) = a.row_entry(i, ai); + for bi in 0..b.row_nnz(k) { + let (j, b_val) = b.row_entry(k, bi); + if acc[j] == T::default() { + nz_cols.push(j); + } + acc[j] += a_val * b_val; + } + } + + // Write touched columns to output, then clear them + if let Some(p) = out_pos_a0 { + out_ix[p] = i; + } + for &j in &nz_cols { + if let Some(p) = out_pos_b1 { + out_ix[p] = j; + } + out.set(&out_ix[..n_out_dims], acc[j]); + acc[j] = T::default(); + } + nz_cols.clear(); + } + + Ok(()) +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Approach 3: Custom VM +// ═══════════════════════════════════════════════════════════════════════════ + +/// VM operation — flat bytecode. `DenseLoop` and `SparseRowLoop` mark loop +/// starts; `LoopEnd` marks the end of the innermost enclosing loop. +/// Each loop-start stores the pc of its matching `LoopEnd` (and vice versa) +/// so the interpreter never needs to scan for matching brackets. +#[derive(Debug)] +pub enum VmOp { + /// Dense loop: iterate `slot` from 0 to `dim-1`. + /// `end_pc` points one past the matching `LoopEnd`. + /// When `fused`, the body is a single `MulAcc` — the loop runs the + /// multiply-accumulate inline without recursing into `exec_at`. + DenseLoop { slot: u8, dim: usize, end_pc: usize, fused: bool }, + /// Sparse row iteration: for each non-zero in `input[input_idx]` + /// at the row given by `vals[row_slot]`, set `vals[col_slot]` to the + /// column index. + /// `end_pc` points one past the matching `LoopEnd`. + /// When `fused`, the body is a single `MulAcc` — the loop runs the + /// multiply-accumulate inline without recursing into `exec_at`. + SparseRowLoop { + input_idx: usize, + row_slot: u8, + col_slot: u8, + end_pc: usize, + fused: bool, + }, + /// End of the enclosing loop. `start_pc` points back to the loop-start. + LoopEnd { start_pc: usize }, + /// Initialize a dense accumulator of size `dim`, indexed by `acc_slot`. + /// Placed just before the loops that should accumulate. + AccStart { acc_slot: u8, acc_out_pos: u8, dim: usize }, + /// Flush the dense accumulator to the output (scatter-gather: only write + /// and clear touched entries), then reset. + AccFlush, + /// Read input values at current slot positions, multiply, and + /// accumulate into the output (or into the active accumulator). + MulAcc, +} + +/// Precompiled VM program for sparse einsum. +pub struct VmProgram { + /// Flat bytecode. + pub ops: Vec, + /// Input slot patterns: input_patterns[i] = list of slot indices for input i. + input_patterns: Vec>, + /// Output slot patterns: one per output in the spec. + output_patterns: Vec>, + /// For each input: `Some(loop_index)` of the SparseRowLoop that fully covers + /// it (both axes iterated by that one loop), or `None` if `get_opt` is needed. + /// When `Some`, `MulAcc` reads the cached sparse value instead of calling `get_opt`. + sparse_value_source: Vec>, +} + +/// Compile an einsum spec into a VM program. +/// +/// Accepts inputs of any dimensionality. The compiler uses a greedy strategy: +/// 1. For each 2D sparse input (`is_sparse_2d()`), its axis-1 slot can be +/// iterated sparsely when its axis-0 slot is already fixed. +/// 2. Non-sparse or higher-dimensional inputs use dense loops for all slots. +/// 3. When choosing dense loops, prefer axis-0 slots of sparse inputs first +/// (so their axis-1 can be sparse in the next level). +/// +/// The `inputs` slice is used to check dimensionality and `is_sparse_2d()`. +/// It is not retained — execution uses separately provided inputs. +pub fn compile_vm( + spec_str: &str, + inputs: &[&dyn NDIndex], +) -> Result { + let n_inputs = inputs.len(); + let spec = crate::parse_spec(spec_str, n_inputs)?; + + // Validate dimensions + let mut dims = [0usize; 26]; + let mut dim_set = [false; 26]; + for (pi, inp_spec) in spec.inputs.iter().enumerate() { + let arr = inputs[pi]; + if arr.ndim() != inp_spec.len() { + return Err(InvalidSpec::InputNdimMismatch { + input: pi, + array_ndim: arr.ndim(), + spec_ndim: inp_spec.len(), + }); + } + for (pos, &s) in inp_spec.iter().enumerate() { + let si = s as usize; + let d = arr.dim(pos); + if dim_set[si] { + if dims[si] != d { + return Err(InvalidSpec::DimensionMismatch { + index: (s + b'a') as char, + expected: dims[si], + got: d, + }); + } + } else { + dims[si] = d; + dim_set[si] = true; + } + } + } + + // Identify which inputs are sparse-2D eligible for SparseRowLoop. + // An input qualifies when it has exactly 2 spec indices AND is_sparse_2d(). + // For qualifying inputs, record (axis_0_slot, axis_1_slot). + let sparse_axes: Vec> = spec + .inputs + .iter() + .zip(inputs.iter()) + .map(|(inp_spec, arr)| { + if inp_spec.len() == 2 && arr.is_sparse_2d() { + Some((inp_spec[0], inp_spec[1])) + } else { + None + } + }) + .collect(); + + // Collect all unique slots in order of first appearance + let mut all_slots = Vec::new(); + let mut seen = [false; 26]; + for inp in &spec.inputs { + for &s in inp { + if !seen[s as usize] { + seen[s as usize] = true; + all_slots.push(s); + } + } + } + for out in &spec.outputs { + for &s in out { + if !seen[s as usize] { + seen[s as usize] = true; + all_slots.push(s); + } + } + } + + // Greedy scheduler: decide loop order, then emit flat bytecode. + let mut fixed = [false; 26]; + let mut loop_order: Vec = Vec::new(); + let mut n_fixed = 0usize; + + while n_fixed < all_slots.len() { + let mut found_sparse = false; + + // Try to find a slot that can be sparse-iterated: + // s is axis-1 of a sparse-2D input whose axis-0 is already fixed. + for &s in &all_slots { + if fixed[s as usize] { + continue; + } + for (idx, axes) in sparse_axes.iter().enumerate() { + if let Some((ax0, ax1)) = axes { + if *ax1 == s && fixed[*ax0 as usize] { + loop_order.push(VmOp::SparseRowLoop { + input_idx: idx, + row_slot: *ax0, + col_slot: s, + end_pc: 0, // patched below + fused: false, // patched below + }); + fixed[s as usize] = true; + n_fixed += 1; + found_sparse = true; + break; + } + } + } + if found_sparse { + break; + } + } + + if !found_sparse { + // Dense loop. Prefer axis-0 slots of sparse inputs (enables future sparse). + let mut best = None; + for &s in &all_slots { + if fixed[s as usize] { + continue; + } + let is_sparse_ax0 = sparse_axes + .iter() + .any(|axes| matches!(axes, Some((ax0, _)) if *ax0 == s)); + if is_sparse_ax0 || best.is_none() { + best = Some(s); + if is_sparse_ax0 { + break; + } + } + } + let s = best.unwrap(); + loop_order.push(VmOp::DenseLoop { + slot: s, + dim: dims[s as usize], + end_pc: 0, // patched below + fused: false, // patched below + }); + fixed[s as usize] = true; + n_fixed += 1; + } + } + + // For each input, check if a single SparseRowLoop covers both its axes. + // If so, MulAcc can use the cached sparse value instead of get_opt(). + let sparse_value_source: Vec> = spec + .inputs + .iter() + .enumerate() + .map(|(inp_idx, inp_spec)| { + if inp_spec.len() != 2 { + return None; + } + // Find a SparseRowLoop in loop_order that iterates this input + for (loop_idx, op) in loop_order.iter().enumerate() { + if let VmOp::SparseRowLoop { input_idx, row_slot, col_slot, .. } = op { + if *input_idx == inp_idx { + // Check this loop covers both axes of this input + if sparse_axes[inp_idx] == Some((*row_slot, *col_slot)) { + return Some(loop_idx); + } + } + } + } + None + }) + .collect(); + + // Accumulator: if the innermost loop's slot appears in the output pattern + // and there's at least one other output slot, we can use a dense accumulator. + // We emit AccStart just inside the outermost output-slot loop and AccFlush + // just before each iteration's end of that same loop. + // + // For "ab,bc->ac": output=[a,c], innermost loop is c. + // acc_slot=c, flush_loop_idx=0 (the 'a' loop). + // Bytecode: FOR a | AccStart | FOR b(sparse) | FOR c(sparse) | MulAcc | + // LoopEnd(c) | LoopEnd(b) | AccFlush | LoopEnd(a) + // Accumulator optimization: only for single-output specs. + // With multi-output, different outputs may have different index layouts, + // so the dense accumulator trick doesn't generalise cleanly. + let all_output_slots = spec.all_output_slots(); + let mut acc_info: Option<(u8, u8, usize, usize)> = None; // (acc_slot, acc_out_pos, acc_dim, flush_loop_idx) + if spec.outputs.len() == 1 { + if let Some(last_op) = loop_order.last() { + let inner_slot = match last_op { + VmOp::DenseLoop { slot, .. } => *slot, + VmOp::SparseRowLoop { col_slot, .. } => *col_slot, + _ => unreachable!(), + }; + if let Some(pos) = spec.output().iter().position(|&s| s == inner_slot) { + for (i, op) in loop_order.iter().enumerate().rev() { + let s = match op { + VmOp::DenseLoop { slot, .. } => *slot, + VmOp::SparseRowLoop { col_slot, .. } => *col_slot, + _ => unreachable!(), + }; + if s != inner_slot && all_output_slots.contains(&s) { + acc_info = Some((inner_slot, pos as u8, dims[inner_slot as usize], i)); + break; + } + } + } + } + } + + // Emit flat bytecode. + // Layout: loops... MulAcc LoopEnd... with AccStart/AccFlush injected. + let n_loops = loop_order.len(); + let mut ops: Vec = Vec::with_capacity(n_loops * 2 + 4); + + // Emit loop-starts, injecting AccStart after the flush loop + for (i, op) in loop_order.into_iter().enumerate() { + ops.push(op); + if let Some((acc_slot, acc_out_pos, dim, flush_idx)) = acc_info { + if i == flush_idx { + ops.push(VmOp::AccStart { acc_slot, acc_out_pos, dim }); + } + } + } + ops.push(VmOp::MulAcc); + + // Emit LoopEnds (innermost first), injecting AccFlush before the flush loop's LoopEnd + for i in 0..n_loops { + let loop_idx = n_loops - 1 - i; + if let Some((_, _, _, flush_idx)) = acc_info { + if loop_idx == flush_idx { + ops.push(VmOp::AccFlush); + } + } + // Find this loop's start_pc by scanning back for the loop_idx-th loop-start + // The loop-starts are at varying positions due to AccStart injection. + // Track them: loop_idx corresponds to the loop_idx-th loop-start in ops. + let start_pc = ops.iter().enumerate() + .filter(|(_, op)| matches!(op, VmOp::DenseLoop{..} | VmOp::SparseRowLoop{..})) + .nth(loop_idx) + .unwrap().0; + let end_pc = ops.len() + 1; // one past the LoopEnd we're about to push + match &mut ops[start_pc] { + VmOp::DenseLoop { end_pc: ep, .. } => *ep = end_pc, + VmOp::SparseRowLoop { end_pc: ep, .. } => *ep = end_pc, + _ => unreachable!(), + } + ops.push(VmOp::LoopEnd { start_pc }); + } + + // Fusion: if a loop's body is just MulAcc (followed by LoopEnd), mark it fused. + // The loop will inline the multiply-accumulate without recursing. + for pc in 0..ops.len() { + let is_loop = matches!(&ops[pc], VmOp::DenseLoop{..} | VmOp::SparseRowLoop{..}); + if is_loop && matches!(&ops[pc + 1], VmOp::MulAcc) { + match &mut ops[pc] { + VmOp::DenseLoop { fused, .. } => *fused = true, + VmOp::SparseRowLoop { fused, .. } => *fused = true, + _ => unreachable!(), + } + } + } + + Ok(VmProgram { + ops, + input_patterns: spec.inputs.clone(), + output_patterns: spec.outputs.clone(), + sparse_value_source, + }) +} + +/// Accumulator state for the VM interpreter. +struct AccState { + acc: Vec, + nz_cols: Vec, + acc_slot: u8, + acc_out_pos: u8, +} + +impl VmProgram { + /// Execute this compiled VM program. + /// + /// Inputs can be any mix of dense and sparse `NDIndex` implementations. + /// `SparseRowLoop` ops call `sparse_row_nnz` / `sparse_row_entry` on the + /// relevant input — the compiler only emits these for inputs where + /// `is_sparse_2d()` returned true. + /// + /// Optimizations over naive interpretation: + /// - Sparse values from `row_entry()` are cached and reused in `MulAcc`, + /// avoiding redundant `get_opt()` binary searches. + /// - `AccStart`/`AccFlush` ops enable scatter-gather accumulation for the + /// innermost output dimension. + pub fn exec( + &self, + inputs: &[&dyn NDIndex], + outs: &mut [&mut dyn NDIndex], + ) where + T: Default + Copy + Add + AddAssign + Mul + PartialEq, + { + let mut vals = [0usize; 26]; + let mut buf = [0usize; 26]; + let mut sparse_vals: Vec = vec![T::default(); inputs.len()]; + let mut acc_state: Option> = None; + self.exec_at(0, &mut vals, &mut buf, &mut sparse_vals, &mut acc_state, inputs, outs); + } + + /// Execute bytecode starting at `pc`. Returns the pc after the + /// matching `LoopEnd` (or end of program). + fn exec_at( + &self, + mut pc: usize, + vals: &mut [usize; 26], + buf: &mut [usize; 26], + sparse_vals: &mut [T], + acc_state: &mut Option>, + inputs: &[&dyn NDIndex], + outs: &mut [&mut dyn NDIndex], + ) -> usize + where + T: Default + Copy + Add + AddAssign + Mul + PartialEq, + { + let ops = &self.ops; + while pc < ops.len() { + match &ops[pc] { + VmOp::DenseLoop { slot, dim, end_pc, fused } => { + let s = *slot as usize; + if *fused { + for v in 0..*dim { + vals[s] = v; + self.mul_acc(vals, buf, sparse_vals, acc_state, inputs, outs); + } + } else { + for v in 0..*dim { + vals[s] = v; + self.exec_at(pc + 1, vals, buf, sparse_vals, acc_state, inputs, outs); + } + } + pc = *end_pc; + } + VmOp::SparseRowLoop { + input_idx, + row_slot, + col_slot, + end_pc, + fused, + } => { + let row = vals[*row_slot as usize]; + let cs = *col_slot as usize; + let input = inputs[*input_idx]; + let nnz = input.sparse_row_nnz(row); + if *fused { + for ei in 0..nnz { + let (col, val) = input.sparse_row_entry(row, ei); + vals[cs] = col; + sparse_vals[*input_idx] = val; + self.mul_acc(vals, buf, sparse_vals, acc_state, inputs, outs); + } + } else { + for ei in 0..nnz { + let (col, val) = input.sparse_row_entry(row, ei); + vals[cs] = col; + sparse_vals[*input_idx] = val; + self.exec_at(pc + 1, vals, buf, sparse_vals, acc_state, inputs, outs); + } + } + pc = *end_pc; + } + VmOp::LoopEnd { .. } => { + return pc + 1; + } + VmOp::AccStart { acc_slot, acc_out_pos, dim } => { + *acc_state = Some(AccState { + acc: vec![T::default(); *dim], + nz_cols: Vec::new(), + acc_slot: *acc_slot, + acc_out_pos: *acc_out_pos, + }); + pc += 1; + } + VmOp::AccFlush => { + // AccFlush only emitted for single-output specs + if let Some(st) = acc_state { + let pattern = &self.output_patterns[0]; + let len = pattern.len(); + for &j in st.nz_cols.iter() { + for (i, &s) in pattern.iter().enumerate() { + buf[i] = vals[s as usize]; + } + buf[st.acc_out_pos as usize] = j; + outs[0].set(&buf[..len], st.acc[j]); + st.acc[j] = T::default(); + } + st.nz_cols.clear(); + } + pc += 1; + } + VmOp::MulAcc => { + self.mul_acc(vals, buf, sparse_vals, acc_state, inputs, outs); + pc += 1; + } + } + } + pc + } + + /// Compute one multiply-accumulate step: read inputs, multiply, write to + /// accumulator or output(s). + #[inline] + fn mul_acc( + &self, + vals: &[usize; 26], + buf: &mut [usize; 26], + sparse_vals: &[T], + acc_state: &mut Option>, + inputs: &[&dyn NDIndex], + outs: &mut [&mut dyn NDIndex], + ) where + T: Default + Copy + Add + AddAssign + Mul + PartialEq, + { + let mut product = None::; + for (i, pattern) in self.input_patterns.iter().enumerate() { + let v = if self.sparse_value_source[i].is_some() { + Some(sparse_vals[i]) + } else { + let len = pattern.len(); + for (p, &s) in pattern.iter().enumerate() { + buf[p] = vals[s as usize]; + } + inputs[i].get_opt(&buf[..len]) + }; + if let Some(v) = v { + product = Some(match product { + Some(p) => p * v, + None => v, + }); + } else { + product = None; + break; + } + } + if let Some(p) = product { + if let Some(st) = acc_state { + // Single-output accumulator path + let idx = vals[st.acc_slot as usize]; + if st.acc[idx] == T::default() { + st.nz_cols.push(idx); + } + st.acc[idx] += p; + } else { + // Write to all outputs + for (oi, pattern) in self.output_patterns.iter().enumerate() { + let len = pattern.len(); + for (i, &s) in pattern.iter().enumerate() { + buf[i] = vals[s as usize]; + } + let cur = outs[oi].get(&buf[..len]); + outs[oi].set(&buf[..len], cur + p); + } + } + } + } +} + +/// Convenience: compile and execute in one call. +/// +/// Accepts any number of inputs and outputs of any dimensionality. +/// All inputs must be the same concrete type; for mixed types use +/// [`einsum_vm_oneshot_dyn`]. +/// +/// Multiple outputs are supported: `"ab,bc->ac,ca"` writes to two output +/// tensors simultaneously from a single loop nest. +pub fn einsum_vm_oneshot( + spec_str: &str, + inputs: &[&In], + outs: &mut [&mut Out], +) -> Result<(), InvalidSpec> +where + T: Copy + Default + Add + AddAssign + Mul + PartialEq, + In: NDIndex, + Out: NDIndex, +{ + let dyn_inputs: Vec<&dyn NDIndex> = inputs.iter().map(|&x| x as &dyn NDIndex).collect(); + let mut dyn_outs: Vec<&mut dyn NDIndex> = outs.iter_mut().map(|o| { + let r: &mut Out = *o; + r as &mut dyn NDIndex + }).collect(); + einsum_vm_oneshot_dyn(spec_str, &dyn_inputs, &mut dyn_outs) +} + +/// Like [`einsum_vm_oneshot`] but accepts trait-object inputs, allowing +/// mixed concrete types (e.g. one sparse and one dense input). +pub fn einsum_vm_oneshot_dyn( + spec_str: &str, + inputs: &[&dyn NDIndex], + outs: &mut [&mut dyn NDIndex], +) -> Result<(), InvalidSpec> +where + T: Copy + Default + Add + AddAssign + Mul + PartialEq, +{ + let program = compile_vm(spec_str, inputs)?; + program.exec(inputs, outs); + Ok(()) +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Approach 4: Sparse-driven with hash accumulator +// ═══════════════════════════════════════════════════════════════════════════ + +/// Sparse-driven binary einsum with hash accumulator per output row. +/// +/// Equivalent to matmul (C = A × B) — same restriction as +/// `einsum_sparse_driven`: only supports the `"ab,bc->ac"` pattern. +/// +/// Uses a `HashMap` per row instead of a dense `Vec`. +/// Better when the output is very sparse; worse for dense output. +pub fn einsum_sparse_hash( + spec_str: &str, + a: &S, + b: &S, + out: &mut Out, +) -> Result<(), InvalidSpec> +where + T: Default + Copy + AddAssign + Mul + PartialEq, + S: Sparse2D, + Out: NDIndex, +{ + let spec = crate::parse_spec(spec_str, 2)?; + let dims = crate::validate_dims::(&spec, &[a, b])?; + crate::validate_output::(&spec, &dims, out)?; + + let a_slots = &spec.inputs[0]; + let b_slots = &spec.inputs[1]; + let out_slots = &spec.output(); + + assert_eq!(a_slots.len(), 2, "sparse-hash requires 2D inputs"); + assert_eq!(b_slots.len(), 2, "sparse-hash requires 2D inputs"); + + let (a0, a1) = (a_slots[0], a_slots[1]); + let b1 = b_slots[1]; + + if a1 != b_slots[0] { + panic!( + "einsum_sparse_hash requires A's axis-1 == B's axis-0, \ + got spec '{spec_str}'" + ); + } + + let out_pos_a0 = out_slots.iter().position(|&s| s == a0); + let out_pos_b1 = out_slots.iter().position(|&s| s == b1); + let n_out_dims = out_slots.len(); + + let n_rows_a = a.n_rows(); + let mut acc: HashMap = HashMap::new(); + let mut out_ix = vec![0usize; n_out_dims]; + + for i in 0..n_rows_a { + acc.clear(); + + for ai in 0..a.row_nnz(i) { + let (k, a_val) = a.row_entry(i, ai); + for bi in 0..b.row_nnz(k) { + let (j, b_val) = b.row_entry(k, bi); + *acc.entry(j).or_insert_with(T::default) += a_val * b_val; + } + } + + if let Some(p) = out_pos_a0 { + out_ix[p] = i; + } + for (&j, &v) in &acc { + if let Some(p) = out_pos_b1 { + out_ix[p] = j; + } + out.set(&out_ix[..n_out_dims], v); + } + } + + Ok(()) +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Display for VmOp (debugging / documentation) +// ═══════════════════════════════════════════════════════════════════════════ + +impl std::fmt::Display for VmProgram { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "VM Program:")?; + writeln!( + f, + " inputs: {:?}", + self.input_patterns + .iter() + .map(|p| p.iter().map(|&s| (s + b'a') as char).collect::()) + .collect::>() + )?; + writeln!( + f, + " outputs: {:?}", + self.output_patterns + .iter() + .map(|p| p.iter().map(|&s| (s + b'a') as char).collect::()) + .collect::>() + )?; + writeln!(f, " plan:")?; + let mut indent = 2usize; + for op in &self.ops { + let pad = " ".repeat(indent); + match op { + VmOp::DenseLoop { slot, dim, fused, .. } => { + let ch = (slot + b'a') as char; + let tag = if *fused { " [FUSED]" } else { "" }; + writeln!(f, "{pad}FOR {ch} IN 0..{dim}{tag}")?; + indent += 1; + } + VmOp::SparseRowLoop { + input_idx, + row_slot, + col_slot, + fused, + .. + } => { + let row_ch = (row_slot + b'a') as char; + let col_ch = (col_slot + b'a') as char; + let tag = if *fused { " [SPARSE,FUSED]" } else { " [SPARSE]" }; + writeln!( + f, + "{pad}FOR ({col_ch}, val) IN input[{input_idx}].row({row_ch}){tag}" + )?; + indent += 1; + } + VmOp::LoopEnd { .. } => { + indent -= 1; + } + VmOp::AccStart { acc_slot, dim, .. } => { + let ch = (acc_slot + b'a') as char; + writeln!(f, "{pad}ACC_START {ch}[0..{dim}]")?; + } + VmOp::AccFlush => { + writeln!(f, "{pad}ACC_FLUSH → output")?; + } + VmOp::MulAcc => { + writeln!(f, "{pad}MUL_ACC → acc")?; + } + } + } + Ok(()) + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Tests +// ═══════════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + use crate::einsum_binary; + + /// Minimal dense tensor for output. + struct DenseMat { + data: Vec, + rows: usize, + cols: usize, + } + + impl DenseMat { + fn new(rows: usize, cols: usize) -> Self { + Self { + data: vec![0; rows * cols], + rows, + cols, + } + } + } + + impl NDIndex for DenseMat { + fn ndim(&self) -> usize { + 2 + } + fn dim(&self, axis: usize) -> usize { + if axis == 0 { + self.rows + } else { + self.cols + } + } + fn get(&self, ix: &[usize]) -> u32 { + self.data[ix[0] * self.cols + ix[1]] + } + fn set(&mut self, ix: &[usize], v: u32) { + self.data[ix[0] * self.cols + ix[1]] = v; + } + } + + /// Minimal sparse matrix for testing (COO-based, implements Sparse2D). + struct SparseMat { + n: usize, + /// Sorted by (row, col). Each row's entries are contiguous. + row_ptr: Vec, + col_idx: Vec, + values: Vec, + } + + impl SparseMat { + /// Build from (row, col, val) triplets. n = matrix dimension. + fn from_triplets(n: usize, trips: &[(usize, usize, u32)]) -> Self { + let mut sorted = trips.to_vec(); + sorted.sort_by_key(|&(r, c, _)| (r, c)); + + let mut row_ptr = vec![0; n + 1]; + let mut col_idx = Vec::new(); + let mut values = Vec::new(); + + for &(r, c, v) in &sorted { + if v == 0 { + continue; + } + col_idx.push(c); + values.push(v); + row_ptr[r + 1] = col_idx.len(); + } + // Fill forward + for i in 1..=n { + if row_ptr[i] == 0 && i > 0 { + row_ptr[i] = row_ptr[i - 1]; + } + } + // Fix: ensure monotonic + for i in 1..=n { + row_ptr[i] = row_ptr[i].max(row_ptr[i - 1]); + } + + Self { + n, + row_ptr, + col_idx, + values, + } + } + } + + impl NDIndex for SparseMat { + fn ndim(&self) -> usize { + 2 + } + fn dim(&self, _axis: usize) -> usize { + self.n + } + fn get(&self, ix: &[usize]) -> u32 { + let r = ix[0]; + let c = ix[1]; + let start = self.row_ptr[r]; + let end = self.row_ptr[r + 1]; + for i in start..end { + if self.col_idx[i] == c { + return self.values[i]; + } + } + 0 + } + fn set(&mut self, _ix: &[usize], _v: u32) { + panic!("SparseMat is read-only"); + } + fn get_opt(&self, ix: &[usize]) -> Option { + let r = ix[0]; + let c = ix[1]; + let start = self.row_ptr[r]; + let end = self.row_ptr[r + 1]; + for i in start..end { + if self.col_idx[i] == c { + return Some(self.values[i]); + } + } + None + } + fn is_sparse_2d(&self) -> bool { true } + fn sparse_row_nnz(&self, row: usize) -> usize { + self.row_ptr[row + 1] - self.row_ptr[row] + } + fn sparse_row_entry(&self, row: usize, idx: usize) -> (usize, u32) { + let start = self.row_ptr[row]; + (self.col_idx[start + idx], self.values[start + idx]) + } + } + + impl Sparse2D for SparseMat { + fn nnz(&self) -> usize { + self.values.len() + } + fn n_rows(&self) -> usize { + self.n + } + fn row_nnz(&self, row: usize) -> usize { + self.row_ptr[row + 1] - self.row_ptr[row] + } + fn row_entry(&self, row: usize, idx: usize) -> (usize, u32) { + let start = self.row_ptr[row]; + (self.col_idx[start + idx], self.values[start + idx]) + } + } + + /// Reference matmul for verification. + fn naive_matmul(a: &SparseMat, b: &SparseMat) -> Vec> { + let n = a.n; + let mut c = vec![vec![0u32; n]; n]; + for i in 0..n { + for k in 0..n { + let a_ik = a.get(&[i, k]); + if a_ik == 0 { + continue; + } + for j in 0..n { + c[i][j] += a_ik * b.get(&[k, j]); + } + } + } + c + } + + #[test] + fn test_approach1_baseline() { + let a = SparseMat::from_triplets(3, &[(0, 1, 1), (1, 2, 1), (2, 0, 1)]); + let b = SparseMat::from_triplets(3, &[(0, 1, 1), (1, 2, 1), (2, 0, 1)]); + let mut out = DenseMat::new(3, 3); + + einsum_binary("ab,bc->ac", &a, &b, &mut out).unwrap(); + + let expected = naive_matmul(&a, &b); + for i in 0..3 { + for j in 0..3 { + assert_eq!( + out.get(&[i, j]), + expected[i][j], + "baseline mismatch at ({i},{j})" + ); + } + } + } + + #[test] + fn test_approach2_sparse_driven() { + let a = SparseMat::from_triplets(4, &[(0, 1, 2), (0, 2, 3), (1, 3, 1), (2, 3, 4)]); + let b = SparseMat::from_triplets(4, &[(1, 0, 5), (2, 0, 6), (3, 1, 7)]); + let mut out = DenseMat::new(4, 4); + + einsum_sparse_driven("ab,bc->ac", &a, &b, &mut out).unwrap(); + + let expected = naive_matmul(&a, &b); + for i in 0..4 { + for j in 0..4 { + assert_eq!( + out.get(&[i, j]), + expected[i][j], + "sparse-driven mismatch at ({i},{j})" + ); + } + } + } + + #[test] + fn test_approach3_vm() { + let a = SparseMat::from_triplets(4, &[(0, 1, 2), (0, 2, 3), (1, 3, 1), (2, 3, 4)]); + let b = SparseMat::from_triplets(4, &[(1, 0, 5), (2, 0, 6), (3, 1, 7)]); + let mut out = DenseMat::new(4, 4); + + einsum_vm_oneshot("ab,bc->ac", &[&a, &b], &mut [&mut out]).unwrap(); + + let expected = naive_matmul(&a, &b); + for i in 0..4 { + for j in 0..4 { + assert_eq!( + out.get(&[i, j]), + expected[i][j], + "VM mismatch at ({i},{j})" + ); + } + } + } + + #[test] + fn test_approach3_vm_display() { + let a = SparseMat::from_triplets(4, &[(0, 1, 1)]); + let b = SparseMat::from_triplets(4, &[(1, 0, 1)]); + let program = compile_vm("ab,bc->ac", &[&a as &dyn NDIndex, &b]).unwrap(); + let display = format!("{program}"); + assert!(display.contains("SPARSE"), "VM should use sparse loops:\n{display}"); + println!("{display}"); + } + + #[test] + fn test_approach4_sparse_hash() { + let a = SparseMat::from_triplets(4, &[(0, 1, 2), (0, 2, 3), (1, 3, 1), (2, 3, 4)]); + let b = SparseMat::from_triplets(4, &[(1, 0, 5), (2, 0, 6), (3, 1, 7)]); + let mut out = DenseMat::new(4, 4); + + einsum_sparse_hash("ab,bc->ac", &a, &b, &mut out).unwrap(); + + let expected = naive_matmul(&a, &b); + for i in 0..4 { + for j in 0..4 { + assert_eq!( + out.get(&[i, j]), + expected[i][j], + "hash mismatch at ({i},{j})" + ); + } + } + } + + #[test] + fn test_all_approaches_agree_diamond() { + // Diamond graph: 0→1, 0→2, 1→3, 2→3 + let a = SparseMat::from_triplets( + 4, + &[(0, 1, 1), (0, 2, 1), (1, 3, 1), (2, 3, 1)], + ); + let b = a.clone_sparse(); + let expected = naive_matmul(&a, &b); + + let mut out1 = DenseMat::new(4, 4); + einsum_binary("ab,bc->ac", &a, &b, &mut out1).unwrap(); + + let mut out2 = DenseMat::new(4, 4); + einsum_sparse_driven("ab,bc->ac", &a, &b, &mut out2).unwrap(); + + let mut out3 = DenseMat::new(4, 4); + einsum_vm_oneshot("ab,bc->ac", &[&a, &b], &mut [&mut out3]).unwrap(); + + let mut out4 = DenseMat::new(4, 4); + einsum_sparse_hash("ab,bc->ac", &a, &b, &mut out4).unwrap(); + + for i in 0..4 { + for j in 0..4 { + let e = expected[i][j]; + assert_eq!(out1.get(&[i, j]), e, "baseline@({i},{j})"); + assert_eq!(out2.get(&[i, j]), e, "sparse-driven@({i},{j})"); + assert_eq!(out3.get(&[i, j]), e, "VM@({i},{j})"); + assert_eq!(out4.get(&[i, j]), e, "hash@({i},{j})"); + } + } + } + + #[test] + fn test_identity_matmul() { + // A × I = A + let a = SparseMat::from_triplets(3, &[(0, 1, 5), (1, 2, 3), (2, 0, 7)]); + let id = SparseMat::from_triplets(3, &[(0, 0, 1), (1, 1, 1), (2, 2, 1)]); + + let mut out = DenseMat::new(3, 3); + einsum_sparse_driven("ab,bc->ac", &a, &id, &mut out).unwrap(); + + assert_eq!(out.get(&[0, 1]), 5); + assert_eq!(out.get(&[1, 2]), 3); + assert_eq!(out.get(&[2, 0]), 7); + assert_eq!(out.get(&[0, 0]), 0); + } + + /// Dense N-dimensional tensor for testing higher-dimensional VM inputs. + struct DenseTensor { + data: Vec, + shape: Vec, + } + + impl DenseTensor { + fn new(shape: Vec) -> Self { + let n: usize = shape.iter().product(); + Self { data: vec![0; n], shape } + } + fn linear_index(&self, ix: &[usize]) -> usize { + let mut idx = 0; + let mut stride = 1; + for (&k, &dim) in ix.iter().rev().zip(self.shape.iter().rev()) { + idx += k * stride; + stride *= dim; + } + idx + } + } + + impl NDIndex for DenseTensor { + fn ndim(&self) -> usize { self.shape.len() } + fn dim(&self, axis: usize) -> usize { self.shape[axis] } + fn get(&self, ix: &[usize]) -> u32 { self.data[self.linear_index(ix)] } + fn set(&mut self, ix: &[usize], v: u32) { + let i = self.linear_index(ix); + self.data[i] = v; + } + } + + #[test] + fn test_vm_3d_dense_inputs() { + // "abc,cd->abd": batched matmul with 3D × 2D → 3D + // All dense — no sparse loops, but the VM should handle it. + let mut a = DenseTensor::new(vec![2, 3, 4]); // batch=2, rows=3, inner=4 + let mut b = DenseTensor::new(vec![4, 5]); // inner=4, cols=5 + + // Fill with simple values + for i in 0..2 { + for j in 0..3 { + for k in 0..4 { + a.set(&[i, j, k], (i * 12 + j * 4 + k + 1) as u32); + } + } + } + for k in 0..4 { + for l in 0..5 { + b.set(&[k, l], (k * 5 + l + 1) as u32); + } + } + + let mut out = DenseTensor::new(vec![2, 3, 5]); + einsum_vm_oneshot( + "abc,cd->abd", + &[&a, &b], + &mut [&mut out], + ).unwrap(); + + // Verify against naive computation + for i in 0..2 { + for j in 0..3 { + for l in 0..5 { + let mut expected = 0u32; + for k in 0..4 { + expected += a.get(&[i, j, k]) * b.get(&[k, l]); + } + assert_eq!( + out.get(&[i, j, l]), expected, + "3D mismatch at ({i},{j},{l})" + ); + } + } + } + } + + #[test] + fn test_vm_mixed_2d_sparse_and_dense() { + // "ab,bc->ac": one sparse, one dense — VM should use sparse for the + // sparse input's axis and dense for the dense input. + let a = SparseMat::from_triplets(3, &[(0, 1, 2), (1, 2, 3)]); + let mut b = DenseTensor::new(vec![3, 3]); + for i in 0..3 { + for j in 0..3 { + b.set(&[i, j], (i * 3 + j + 1) as u32); + } + } + + let mut out = DenseMat::new(3, 3); + einsum_vm_oneshot_dyn( + "ab,bc->ac", + &[&a as &dyn NDIndex, &b as &dyn NDIndex], + &mut [&mut out as &mut dyn NDIndex], + ).unwrap(); + + // Verify: row 0 of A has only (1, 2), row 1 has only (2, 3) + // C[0, j] = A[0,1]*B[1,j] = 2 * B[1,j] + // C[1, j] = A[1,2]*B[2,j] = 3 * B[2,j] + for j in 0..3 { + assert_eq!(out.get(&[0, j]), 2 * b.get(&[1, j]), "mixed@(0,{j})"); + assert_eq!(out.get(&[1, j]), 3 * b.get(&[2, j]), "mixed@(1,{j})"); + assert_eq!(out.get(&[2, j]), 0, "mixed@(2,{j})"); + } + } + + #[test] + fn test_vm_4d_attention() { + // "bhqd,bhkd->bhqk": attention-style with all dense inputs + let (ba, h, q, k, d) = (1, 1, 2, 2, 3); + let mut qm = DenseTensor::new(vec![ba, h, q, d]); + let mut km = DenseTensor::new(vec![ba, h, k, d]); + + for qi in 0..q { + for di in 0..d { + qm.set(&[0, 0, qi, di], (qi * d + di + 1) as u32); + } + } + for ki in 0..k { + for di in 0..d { + km.set(&[0, 0, ki, di], (ki * d + di + 1) as u32); + } + } + + let mut out = DenseTensor::new(vec![ba, h, q, k]); + einsum_vm_oneshot( + "bhqd,bhkd->bhqk", + &[&qm, &km], + &mut [&mut out], + ).unwrap(); + + for qi in 0..q { + for ki in 0..k { + let mut expected = 0u32; + for di in 0..d { + expected += qm.get(&[0, 0, qi, di]) * km.get(&[0, 0, ki, di]); + } + assert_eq!( + out.get(&[0, 0, qi, ki]), expected, + "attention@(0,0,{qi},{ki})" + ); + } + } + } + + impl SparseMat { + fn clone_sparse(&self) -> Self { + Self { + n: self.n, + row_ptr: self.row_ptr.clone(), + col_idx: self.col_idx.clone(), + values: self.values.clone(), + } + } + } + + /// 0-dim scalar output tensor for VM scalar tests. + struct ScalarOut { + val: u32, + } + + impl ScalarOut { + fn new() -> Self { Self { val: 0 } } + } + + impl NDIndex for ScalarOut { + fn ndim(&self) -> usize { 0 } + fn dim(&self, _axis: usize) -> usize { panic!("0-dim has no axes") } + fn get(&self, _ix: &[usize]) -> u32 { self.val } + fn set(&mut self, _ix: &[usize], v: u32) { self.val = v; } + } + + #[test] + fn test_vm_scalar_dot_sparse() { + // "i,i->" dot product with sparse inputs + // a = [0, 2, 0, 3], b = [0, 5, 7, 0] + // dot = 0 + 2*5 + 0 + 0 = 10 + let a = SparseMat::from_triplets(1, &[(0, 1, 2), (0, 3, 3)]); + let b = SparseMat::from_triplets(1, &[(0, 1, 5), (0, 2, 7)]); + + // For dot product on 1D vectors stored as 1-row matrices, + // use spec "ab,ab->" (Frobenius inner product) since SparseMat is 2D + let mut out = ScalarOut::new(); + einsum_vm_oneshot("ab,ab->", &[&a, &b], &mut [&mut out]).unwrap(); + + // a[0,1]*b[0,1] + a[0,3]*b[0,3] = 2*5 + 0 = 10 + assert_eq!(out.val, 10); + } + + #[test] + fn test_vm_scalar_trace_sparse() { + // "aa->" trace of a sparse matrix + // Only diagonal entries matter: (0,0)=1, (1,1)=5, (2,2)=9 + let m = SparseMat::from_triplets(3, &[ + (0, 0, 1), (0, 2, 2), + (1, 1, 5), (1, 2, 3), + (2, 2, 9), + ]); + + let mut out = ScalarOut::new(); + einsum_vm_oneshot("aa->", &[&m], &mut [&mut out]).unwrap(); + + assert_eq!(out.val, 15); // 1 + 5 + 9 + } + + #[test] + fn test_vm_scalar_frobenius_sparse() { + // "ab,ab->" Frobenius inner product (sum of element-wise products) + let a = SparseMat::from_triplets(3, &[ + (0, 1, 2), (1, 0, 3), (2, 2, 4), + ]); + let b = SparseMat::from_triplets(3, &[ + (0, 1, 5), (1, 0, 7), (1, 1, 99), (2, 2, 2), + ]); + + let mut out = ScalarOut::new(); + einsum_vm_oneshot("ab,ab->", &[&a, &b], &mut [&mut out]).unwrap(); + + // overlapping entries: (0,1): 2*5=10, (1,0): 3*7=21, (2,2): 4*2=8 + assert_eq!(out.val, 39); // 10 + 21 + 8 + } + + #[test] + fn test_vm_scalar_dense_dot() { + // "i,i->" with dense 1D inputs + let mut a = DenseTensor::new(vec![4]); + let mut b = DenseTensor::new(vec![4]); + for i in 0..4 { + a.set(&[i], (i + 1) as u32); // [1, 2, 3, 4] + b.set(&[i], (i + 5) as u32); // [5, 6, 7, 8] + } + + let mut out = ScalarOut::new(); + einsum_vm_oneshot("i,i->", &[&a, &b], &mut [&mut out]).unwrap(); + + // 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70 + assert_eq!(out.val, 70); + } + + #[test] + fn test_vm_scalar_three_input() { + // "i,i,i->" element-wise triple product summed to scalar + let mut a = DenseTensor::new(vec![3]); + let mut b = DenseTensor::new(vec![3]); + let mut c = DenseTensor::new(vec![3]); + for i in 0..3 { + a.set(&[i], (i + 1) as u32); // [1, 2, 3] + b.set(&[i], (i + 4) as u32); // [4, 5, 6] + c.set(&[i], (i + 7) as u32); // [7, 8, 9] + } + + let mut out = ScalarOut::new(); + einsum_vm_oneshot("i,i,i->", &[&a, &b, &c], &mut [&mut out]).unwrap(); + + // 1*4*7 + 2*5*8 + 3*6*9 = 28 + 80 + 162 = 270 + assert_eq!(out.val, 270); + } + + // --- Multi-output tests --- + + #[test] + fn test_vm_multi_output_matmul_and_transpose() { + // "ab,bc->ac,ca": matmul result written to both C and C^T simultaneously + let a = SparseMat::from_triplets(3, &[(0, 1, 2), (1, 2, 3), (2, 0, 1)]); + let b = SparseMat::from_triplets(3, &[(0, 1, 4), (1, 0, 5), (2, 2, 6)]); + + let mut out_ac = DenseMat::new(3, 3); + let mut out_ca = DenseMat::new(3, 3); + einsum_vm_oneshot("ab,bc->ac,ca", &[&a, &b], &mut [&mut out_ac, &mut out_ca]).unwrap(); + + // Verify against naive matmul + let expected = naive_matmul(&a, &b); + for i in 0..3 { + for j in 0..3 { + assert_eq!(out_ac.get(&[i, j]), expected[i][j], "ac@({i},{j})"); + assert_eq!(out_ca.get(&[j, i]), expected[i][j], "ca@({j},{i})"); + } + } + } + + #[test] + fn test_vm_multi_output_dense() { + // "ab,bc->ac,ca" with dense inputs + let mut a = DenseTensor::new(vec![3, 3]); + let mut b = DenseTensor::new(vec![3, 3]); + for i in 0..3 { + for j in 0..3 { + a.set(&[i, j], (i * 3 + j + 1) as u32); + b.set(&[i, j], (i * 3 + j + 10) as u32); + } + } + + let mut out_ac = DenseTensor::new(vec![3, 3]); + let mut out_ca = DenseTensor::new(vec![3, 3]); + einsum_vm_oneshot("ab,bc->ac,ca", &[&a, &b], &mut [&mut out_ac, &mut out_ca]).unwrap(); + + // Verify C[i,j] = sum_k A[i,k]*B[k,j] + for i in 0..3 { + for j in 0..3 { + let mut expected = 0u32; + for k in 0..3 { + expected += a.get(&[i, k]) * b.get(&[k, j]); + } + assert_eq!(out_ac.get(&[i, j]), expected, "ac@({i},{j})"); + assert_eq!(out_ca.get(&[j, i]), expected, "ca@({j},{i})"); + } + } + } + + #[test] + fn test_vm_multi_output_matmul_and_scalar() { + // "ab,bc->ac," : matmul + scalar (Frobenius-like total sum) + // Wait — we need different contracted indices for the scalar. + // Actually "ab,ba->,ab" would be: scalar = sum_ab A[a,b]*B[b,a], + // and output ab = A[a,b]*B[b,a] (element-wise, no contraction for ab output) + // Let's test a simpler case: "ab,bc->ac,ac" (same output written twice) + let a = SparseMat::from_triplets(3, &[(0, 1, 2), (1, 2, 3)]); + let b = SparseMat::from_triplets(3, &[(1, 0, 5), (2, 2, 6)]); + + let mut out1 = DenseMat::new(3, 3); + let mut out2 = DenseMat::new(3, 3); + einsum_vm_oneshot("ab,bc->ac,ac", &[&a, &b], &mut [&mut out1, &mut out2]).unwrap(); + + let expected = naive_matmul(&a, &b); + for i in 0..3 { + for j in 0..3 { + assert_eq!(out1.get(&[i, j]), expected[i][j], "out1@({i},{j})"); + assert_eq!(out2.get(&[i, j]), expected[i][j], "out2@({i},{j})"); + } + } + } +} diff --git a/experiments/eval-ffi/src/source.rs b/experiments/eval-ffi/src/source.rs index be21a48..0588800 100644 --- a/experiments/eval-ffi/src/source.rs +++ b/experiments/eval-ffi/src/source.rs @@ -8,6 +8,9 @@ pub struct ExprSource { pub ptr: *const u8, // len: usize, pub position: usize, + /// Opaque context pointer. Pure functions can use this to access + /// extra state (e.g. tensor store) passed in by the eval scope. + pub context: *mut (), } #[cfg(feature = "std")] @@ -71,7 +74,7 @@ impl ExprSource { } pub fn new(ptr: *const u8) -> Self { // let len = vec.len(); - Self { ptr, position: 0 } + Self { ptr, position: 0, context: core::ptr::null_mut() } } pub fn read(&mut self) -> SourceItem { let byte = unsafe { *self.ptr.add(self.position) }; diff --git a/experiments/eval/src/lib.rs b/experiments/eval/src/lib.rs index ba90304..7595bf7 100644 --- a/experiments/eval/src/lib.rs +++ b/experiments/eval/src/lib.rs @@ -37,6 +37,8 @@ pub struct EvalScope { /// These are used to create ExprSinks for stack frames. /// This avoids repeated allocations during evaluation. alloc_pool: Vec>, + /// Opaque context pointer propagated to ExprSource for pure functions. + pub context: *mut (), } macro_rules! alloc { @@ -59,6 +61,7 @@ impl EvalScope { stack: Vec::new(), alloc_pool: Vec::new(), expr: ExprSource::new(core::ptr::null()), + context: core::ptr::null_mut(), } } pub fn add_func(&mut self, name: &str, func: FuncPtr, ty: FuncType) { @@ -136,7 +139,9 @@ impl EvalScope { let prev_frame = parent_frames.last_mut().unwrap(); let mut data = core::mem::take(&mut top_frame.sink).finish(); let offset = prev_frame.sink.as_ref().len(); - (top_frame.func)(&mut ExprSource::new(data.as_ptr()), &mut prev_frame.sink)?; + let mut src = ExprSource::new(data.as_ptr()); + src.context = self.context; + (top_frame.func)(&mut src, &mut prev_frame.sink)?; trace!(target: "eval", "{:?} ==> {:?}", mork_expr::serialize(&data[..]), mork_expr::serialize(&prev_frame.sink.as_ref()[offset..])); let top = self.stack.pop().unwrap(); // return buffer to pool diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 6a42ac4..502394f 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -26,6 +26,8 @@ itertools = "0.14.0" base64 = "0.22.1" hex = "0.4.3" subprocess = { version = "0.2.13" } +einsum-dyn = { path = "../einsum-dyn" } +num-traits = "0.2" [features] default = ["grounding", "specialize_io"] diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index b566759..94bc701 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -8,3 +8,4 @@ pub mod space; mod sources; mod sinks; mod pure; +pub mod sparse; diff --git a/kernel/src/sinks.rs b/kernel/src/sinks.rs index f5e2c62..0ac826e 100644 --- a/kernel/src/sinks.rs +++ b/kernel/src/sinks.rs @@ -37,6 +37,7 @@ pub(crate) enum WriteResourceRequest { BTM(&'static [u8]), ACT(&'static str), Z3(&'static str), + TensorStore, } impl WriteResourceRequest { @@ -64,6 +65,12 @@ impl WriteResourceRequest { _ => { None } } } + WriteResourceRequest::TensorStore => { + match other { + WriteResourceRequest::TensorStore => { Some(WriteResourceRequest::TensorStore) } + _ => { None } + } + } } } } @@ -86,6 +93,11 @@ impl PartialOrd for WriteResourceRequest { if s == o { Some(Ordering::Equal) } else { None } } else { None } } + WriteResourceRequest::TensorStore => { + if let WriteResourceRequest::TensorStore = other { + Some(Ordering::Equal) + } else { None } + } } } } @@ -93,7 +105,8 @@ impl PartialOrd for WriteResourceRequest { pub(crate) enum WriteResource<'w, 'a, 'k> { BTM(&'w mut WriteZipperTracked<'a, 'k, ()>), ACT(()), - Z3(&'w mut subprocess::Popen) + Z3(&'w mut subprocess::Popen), + TensorStore(&'w mut std::collections::HashMap, crate::sparse::SparseTensorF64>), } // trait JoinLattice { @@ -1058,6 +1071,7 @@ impl Sink for PureSink { fn new(e: Expr) -> Self { let mut scope = EvalScope::new(); pure::register(&mut scope); + crate::sparse::register(&mut scope); PureSink { e, unique: PathMap::new(), scope } } fn request(&self) -> impl Iterator { @@ -1191,7 +1205,268 @@ impl Sink for Z3Sink { } +// ============================================================================ +// Tensor sinks — all tensor I/O goes through WriteResource::TensorStore +// ============================================================================ + +/// Helper: parse symbol args from an expression path, skipping a known header. +fn parse_symbol_args<'a>(path: &'a [u8], header_size: usize) -> Vec<&'a [u8]> { + let rest = &path[header_size..]; + let mut pos = 0; + let mut args = Vec::new(); + while pos < rest.len() { + match byte_item(rest[pos]) { + Tag::SymbolSize(len) => { + let len = len as usize; + pos += 1; + if pos + len <= rest.len() { + args.push(&rest[pos..pos+len]); + } + pos += len; + } + _ => { pos += 1; } + } + } + args +} + +/// TensorCollectSink — accumulates (indices, value) tuples into a named SparseTensorF64. +/// Syntax: (tensor_collect name $i0 $i1 ... $val) +pub struct TensorCollectSink { + e: Expr, + name: Vec, + rank: usize, + entries: Vec<(Vec, f64)>, +} + +impl TensorCollectSink { + const HEADER_SIZE: usize = 16; // Arity(N) + SymbolSize(14) + "tensor_collect" +} + +impl Sink for TensorCollectSink { + fn new(e: Expr) -> Self { + let arity = unsafe { + if let Tag::Arity(a) = byte_item(*e.ptr) { a as usize } else { panic!("tensor_collect: expected arity") } + }; + let rank = arity.saturating_sub(3); + let name = unsafe { + let name_tag = *e.ptr.add(Self::HEADER_SIZE); + if let Tag::SymbolSize(len) = byte_item(name_tag) { + std::slice::from_raw_parts(e.ptr.add(Self::HEADER_SIZE + 1), len as usize).to_vec() + } else { + panic!("tensor_collect: second arg must be a symbol (tensor name)") + } + }; + TensorCollectSink { e, name, rank, entries: Vec::new() } + } + + fn request(&self) -> impl Iterator { + std::iter::once(WriteResourceRequest::TensorStore) + } + + fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It, path: &[u8]) where 'a: 'w, 'k: 'w { + // Parse args after header + name + let name_len = self.name.len(); + let args = parse_symbol_args(path, Self::HEADER_SIZE + 1 + name_len); + + if args.len() >= self.rank + 1 { + let mut indices = Vec::with_capacity(self.rank); + for i in 0..self.rank { + if let Ok(s) = std::str::from_utf8(args[i]) { + if let Ok(idx) = s.parse::() { + indices.push(idx); + } else { return; } + } else { return; } + } + if let Ok(s) = std::str::from_utf8(args[self.rank]) { + if let Ok(val) = s.parse::() { + self.entries.push((indices, val)); + } + } + } + } + + fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It) -> bool where 'a: 'w, 'k: 'w { + if self.entries.is_empty() { return false; } + let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() }; + + let mut tensor = crate::sparse::SparseTensorF64::new(self.rank); + for (indices, value) in self.entries.drain(..) { + tensor.set(&indices, value); + } + store.insert(self.name.clone(), tensor); + true + } +} + +/// TensorEinsumSink — runs einsum on named tensors. +/// Syntax: (tensor_einsum "spec" input1 input2 ... output) +/// Variables are bound from pattern matching; actual values arrive in sink(). +pub struct TensorEinsumSink { + e: Expr, + spec: Vec, + input_names: Vec>, + output_name: Vec, + parsed: bool, +} + +impl TensorEinsumSink { + // "tensor_einsum" is 13 chars → header = Arity + SymbolSize(13) + "tensor_einsum" = 15 + const HEADER_SIZE: usize = 15; +} + +impl Sink for TensorEinsumSink { + fn new(e: Expr) -> Self { + let span = unsafe { e.span().as_ref().unwrap() }; + let args = parse_symbol_args(span, Self::HEADER_SIZE); + let spec = args.first().map(|a| a.to_vec()).unwrap_or_default(); + let mut names: Vec> = args.get(1..).unwrap_or(&[]).iter().map(|a| a.to_vec()).collect(); + let output_name = names.pop().unwrap_or_default(); + TensorEinsumSink { e, spec, input_names: names, output_name, parsed: !args.is_empty() } + } + fn request(&self) -> impl Iterator { + std::iter::once(WriteResourceRequest::TensorStore) + } + fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It, _path: &[u8]) where 'a: 'w, 'k: 'w {} + fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It) -> bool where 'a: 'w, 'k: 'w { + if !self.parsed { return false; } + let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() }; + // Strip quotes from spec if present (MORK parser includes them) + let spec_bytes = if self.spec.starts_with(b"\"") && self.spec.ends_with(b"\"") { + &self.spec[1..self.spec.len()-1] + } else { + &self.spec[..] + }; + let spec_str = match std::str::from_utf8(spec_bytes) { + Ok(s) => s, + Err(_) => { log::error!(target: "tensor", "tensor_einsum: invalid spec"); return false; } + }; + + let inputs: Vec<&crate::sparse::SparseTensorF64> = self.input_names.iter() + .filter_map(|name| store.get(name)) + .collect(); + if inputs.len() != self.input_names.len() { + log::error!(target: "tensor", "tensor_einsum: missing input tensor(s)"); + return false; + } + + // Parse spec to determine output shape + let arrow = match spec_str.find("->") { + Some(a) => a, + None => { log::error!(target: "tensor", "tensor_einsum: spec missing '->'"); return false; } + }; + let output_spec = &spec_str[arrow+2..]; + let input_specs: Vec<&str> = spec_str[..arrow].split(',').collect(); + + let mut dim_map = std::collections::HashMap::new(); + for (i, ispec) in input_specs.iter().enumerate() { + if let Some(inp) = inputs.get(i) { + for (axis, ch) in ispec.bytes().enumerate() { + let d = inp.dims.get(axis).copied().unwrap_or(0); + dim_map.entry(ch).or_insert(d); + } + } + } + let out_dims: Vec = output_spec.bytes() + .map(|ch| dim_map.get(&ch).copied().unwrap_or(1)) + .collect(); + + let mut output = crate::sparse::SparseTensorF64::with_dims(out_dims); + let dyn_inputs: Vec<&dyn einsum_dyn::NDIndex> = inputs.iter() + .map(|t| *t as &dyn einsum_dyn::NDIndex) + .collect(); + let mut dyn_out: &mut dyn einsum_dyn::NDIndex = &mut output; + if let Err(e) = einsum_dyn::sparse::einsum_vm_oneshot_dyn(spec_str, &dyn_inputs, &mut [dyn_out]) { + log::error!(target: "tensor", "tensor_einsum failed: {}", e); + return false; + } + store.insert(self.output_name.clone(), output); + true + } +} + +/// TensorBinopSink — element-wise add or mul on named tensors. +/// Syntax: (tensor_add A B C) or (tensor_mul A B C) +pub struct TensorBinopSink { + e: Expr, + op: TensorBinop, + header_size: usize, + a_name: Vec, + b_name: Vec, + c_name: Vec, + parsed: bool, +} + +enum TensorBinop { Add, Mul } + +impl TensorBinopSink { + fn new_with_op(e: Expr, op: TensorBinop, header_size: usize) -> Self { + TensorBinopSink { e, op, header_size, a_name: Vec::new(), b_name: Vec::new(), c_name: Vec::new(), parsed: false } + } +} + +impl Sink for TensorBinopSink { + fn new(e: Expr) -> Self { panic!("use new_with_op") } + fn request(&self) -> impl Iterator { + std::iter::once(WriteResourceRequest::TensorStore) + } + fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It, path: &[u8]) where 'a: 'w, 'k: 'w { + if self.parsed { return; } + let args = parse_symbol_args(path, self.header_size); + self.a_name = args.first().map(|a| a.to_vec()).unwrap_or_default(); + self.b_name = args.get(1).map(|a| a.to_vec()).unwrap_or_default(); + self.c_name = args.get(2).map(|a| a.to_vec()).unwrap_or_default(); + self.parsed = true; + } + fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It) -> bool where 'a: 'w, 'k: 'w { + if !self.parsed { return false; } + let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() }; + let (a, b) = match (store.get(&self.a_name), store.get(&self.b_name)) { + (Some(a), Some(b)) => (a, b), + _ => { log::error!(target: "tensor", "tensor binop: missing input"); return false; } + }; + let c = match self.op { + TensorBinop::Add => a.add(b), + TensorBinop::Mul => a.mul(b), + }; + store.insert(self.c_name.clone(), c); + true + } +} + +/// TensorFreeSink — removes a named tensor. +/// Syntax: (tensor_free A) +pub struct TensorFreeSink { + e: Expr, + name: Vec, + parsed: bool, +} + +impl Sink for TensorFreeSink { + fn new(e: Expr) -> Self { + TensorFreeSink { e, name: Vec::new(), parsed: false } + } + fn request(&self) -> impl Iterator { + std::iter::once(WriteResourceRequest::TensorStore) + } + fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It, path: &[u8]) where 'a: 'w, 'k: 'w { + if self.parsed { return; } + // "tensor_free" is 11 chars → header = 13 + let args = parse_symbol_args(path, 13); + self.name = args.first().map(|a| a.to_vec()).unwrap_or_default(); + self.parsed = true; + } + fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It) -> bool where 'a: 'w, 'k: 'w { + let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() }; + store.remove(&self.name).is_some() + } +} + pub enum ASink { AddSink(AddSink), RemoveSink(RemoveSink), HeadSink(HeadSink), CountSink(CountSink), HashSink(HashSink), SumSink(SumSink), AndSink(AndSink), ACTSink(ACTSink), + TensorCollectSink(TensorCollectSink), + TensorEinsumSink(TensorEinsumSink), + TensorBinopSink(TensorBinopSink), + TensorFreeSink(TensorFreeSink), #[cfg(feature = "wasm")] WASMSink(WASMSink), #[cfg(feature = "grounding")] @@ -1211,6 +1486,14 @@ impl ASink { pub fn compat(e: Expr) -> Self { ASink::CompatSink(CompatSink::new(e)) } + /// Set an opaque context pointer on any PureSink's EvalScope. + /// This allows pure functions to access external state (e.g. tensor store). + pub fn set_context(&mut self, ctx: *mut ()) { + #[cfg(feature = "grounding")] + if let ASink::PureSink(s) = self { + s.scope.context = ctx; + } + } } impl Sink for ASink { @@ -1271,6 +1554,35 @@ impl Sink for ASink { return ASink::Z3Sink(Z3Sink::new(e)); #[cfg(not(feature = "z3"))] panic!("MORK was not built with the z3 feature, yet trying to call {:?}", e); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(14)) && + *e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'c' && + *e.ptr.offset(10) == b'o' && *e.ptr.offset(11) == b'l' && *e.ptr.offset(12) == b'l' && *e.ptr.offset(13) == b'e' && + *e.ptr.offset(14) == b'c' && *e.ptr.offset(15) == b't' } { + return ASink::TensorCollectSink(TensorCollectSink::new(e)); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(13)) && + *e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'e' && + *e.ptr.offset(10) == b'i' && *e.ptr.offset(11) == b'n' && *e.ptr.offset(12) == b's' && *e.ptr.offset(13) == b'u' && + *e.ptr.offset(14) == b'm' } { + return ASink::TensorEinsumSink(TensorEinsumSink::new(e)); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(10)) && + *e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'a' && + *e.ptr.offset(10) == b'd' && *e.ptr.offset(11) == b'd' } { + // "tensor_add" = 10 chars, header = 12 + return ASink::TensorBinopSink(TensorBinopSink::new_with_op(e, TensorBinop::Add, 12)); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(10)) && + *e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'm' && + *e.ptr.offset(10) == b'u' && *e.ptr.offset(11) == b'l' } { + // "tensor_mul" = 10 chars, header = 12 + return ASink::TensorBinopSink(TensorBinopSink::new_with_op(e, TensorBinop::Mul, 12)); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(11)) && + *e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'f' && + *e.ptr.offset(10) == b'r' && *e.ptr.offset(11) == b'e' && *e.ptr.offset(12) == b'e' } { + return ASink::TensorFreeSink(TensorFreeSink::new(e)); } else { panic!("unrecognized sink") } @@ -1300,6 +1612,10 @@ impl Sink for ASink { ASink::FMinSink(s) => { for i in s.request().into_iter() { yield i } } ASink::FMaxSink(s) => { for i in s.request().into_iter() { yield i } } ASink::FProdSink(s) => { for i in s.request().into_iter() { yield i } } + ASink::TensorCollectSink(s) => { for i in s.request().into_iter() { yield i } } + ASink::TensorEinsumSink(s) => { for i in s.request().into_iter() { yield i } } + ASink::TensorBinopSink(s) => { for i in s.request().into_iter() { yield i } } + ASink::TensorFreeSink(s) => { for i in s.request().into_iter() { yield i } } } } } @@ -1326,6 +1642,10 @@ impl Sink for ASink { ASink::FMinSink(s) => { s.sink(it, path) } ASink::FMaxSink(s) => { s.sink(it, path) } ASink::FProdSink(s) => { s.sink(it, path) } + ASink::TensorCollectSink(s) => { s.sink(it, path) } + ASink::TensorEinsumSink(s) => { s.sink(it, path) } + ASink::TensorBinopSink(s) => { s.sink(it, path) } + ASink::TensorFreeSink(s) => { s.sink(it, path) } } } @@ -1352,6 +1672,10 @@ impl Sink for ASink { ASink::FMinSink(s) => { s.finalize(it) } ASink::FMaxSink(s) => { s.finalize(it) } ASink::FProdSink(s) => { s.finalize(it) } + ASink::TensorCollectSink(s) => { s.finalize(it) } + ASink::TensorEinsumSink(s) => { s.finalize(it) } + ASink::TensorBinopSink(s) => { s.finalize(it) } + ASink::TensorFreeSink(s) => { s.finalize(it) } } } } diff --git a/kernel/src/space.rs b/kernel/src/space.rs index f954638..4811a9d 100644 --- a/kernel/src/space.rs +++ b/kernel/src/space.rs @@ -40,6 +40,7 @@ pub struct Space { pub sm: SharedMappingHandle, pub mmaps: HashMap>, pub z3s: HashMap>, + pub tensors: HashMap, crate::sparse::SparseTensorF64>, pub last_merkleize: Instant, pub timing: bool } @@ -444,7 +445,7 @@ macro_rules! sexpr { impl Space { pub fn new() -> Self { - Self { btm: PathMap::new(), sm: SharedMapping::new(), mmaps: HashMap::new(), z3s: HashMap::new(), last_merkleize: Instant::now(), timing: false } + Self { btm: PathMap::new(), sm: SharedMapping::new(), mmaps: HashMap::new(), z3s: HashMap::new(), tensors: HashMap::new(), last_merkleize: Instant::now(), timing: false } } pub fn parse_sexpr(&mut self, r: &[u8], buf: *mut u8) -> Result<(Expr, usize), ParserError> { @@ -1078,6 +1079,7 @@ impl Space { unsafe fn write_handler<'w, 'a, 'k>(zh_wzs: (*mut ZipperHead<'w, 'a, ()>, *mut Vec>), mmaps: *mut HashMap>, z3s: *mut HashMap>, + tensors: *mut HashMap, crate::sparse::SparseTensorF64>, request: &WriteResourceRequest) -> WriteResource<'w, 'a, 'k> where 'w : 'a { match *request { WriteResourceRequest::BTM(p) => { @@ -1103,6 +1105,9 @@ impl Space { }).as_mut(); WriteResource::Z3(instance) } + WriteResourceRequest::TensorStore => { + WriteResource::TensorStore(tensors.as_mut().unwrap()) + } } } @@ -1488,6 +1493,7 @@ impl Space { #[cfg(feature="specialize_io")] pub fn transform_multi_multi_o(&mut self, pat_expr: Expr, tpl_expr: Expr, add: Expr) -> (usize, bool) { use crate::sinks::*; + let tensors_ptr = (&self.tensors as *const HashMap, crate::sparse::SparseTensorF64>).cast_mut(); let mut buffer = Vec::with_capacity(1 << 32); unsafe { buffer.set_len(1 << 32); } let mut tpl_args = Vec::with_capacity(64); @@ -1511,7 +1517,7 @@ impl Space { template_prefixes.iter().enumerate().for_each(|(i, request)| { if subsumption[i] == i { placements[i] = template_resources.len(); - template_resources.push(unsafe { Self::write_handler((zh_ptr, outstanding_wzs_ptr), acts_ptr, z3s_ptr, request) }); + template_resources.push(unsafe { Self::write_handler((zh_ptr, outstanding_wzs_ptr), acts_ptr, z3s_ptr, tensors_ptr, request) }); } }); for i in 0..subsumption.len() { @@ -1523,7 +1529,7 @@ impl Space { let mut assignments: Vec<(u8, u8)> = vec![]; let mut trace: Vec<(u8, u8)> = vec![]; - + let mut ass = Vec::with_capacity(64); let mut astack = Vec::with_capacity(64); @@ -1569,6 +1575,7 @@ impl Space { }); for (i, s) in sinks.iter_mut().enumerate() { + s.set_context(tensors_ptr.cast()); let wz = unsafe { std::ptr::read(&template_resources[subsumption[i]]) }; any_new |= s.finalize(std::iter::once(wz)); } @@ -1581,6 +1588,8 @@ impl Space { pub fn transform_multi_multi_io(&mut self, pat_expr: Expr, tpl_expr: Expr, add: Expr, no_source: bool, no_sink: bool) -> (usize, bool) { use crate::sinks::*; + // Set tensor store pointer for sinks/pure functions that need it + let tensors_ptr = (&self.tensors as *const HashMap, crate::sparse::SparseTensorF64>).cast_mut(); let mut buffer = Vec::with_capacity(1 << 32); unsafe { buffer.set_len(1 << 32); } let mut tpl_args = Vec::with_capacity(64); @@ -1604,7 +1613,7 @@ impl Space { template_prefixes.iter().enumerate().for_each(|(i, request)| { if subsumption[i] == i { placements[i] = template_resources.len(); - template_resources.push(unsafe { Self::write_handler((zh_ptr, outstanding_wzs_ptr), acts_ptr, z3s_ptr, request) }); + template_resources.push(unsafe { Self::write_handler((zh_ptr, outstanding_wzs_ptr), acts_ptr, z3s_ptr, tensors_ptr, request) }); } }); for i in 0..subsumption.len() { @@ -1663,6 +1672,7 @@ impl Space { }); for (i, s) in sinks.iter_mut().enumerate() { + s.set_context(tensors_ptr.cast()); let wz = unsafe { std::ptr::read(&template_resources[subsumption[i]]) }; any_new |= s.finalize(std::iter::once(wz)); } @@ -1672,7 +1682,7 @@ impl Space { (touched, any_new) } - + // (exec (, ) // (, )) pub fn interpret(&mut self, rt: Expr) -> Result<(), &'static str> { diff --git a/kernel/src/sparse.rs b/kernel/src/sparse.rs new file mode 100644 index 0000000..d1ecf54 --- /dev/null +++ b/kernel/src/sparse.rs @@ -0,0 +1,361 @@ +use std::collections::HashMap; +use std::hash::Hasher; +use num_traits::Zero; +use pathmap::PathMap; +use pathmap::ring::{AlgebraicResult, Lattice}; +use pathmap::utils::ints::indices_to_bob; + +// ============================================================================ +// FAddMulF64 — Lattice wrapper for f64 (join=add, meet=mul) +// ============================================================================ + +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct FAddMulF64(pub f64); + +impl std::ops::Deref for FAddMulF64 { + type Target = f64; + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl std::hash::Hash for FAddMulF64 { + fn hash(&self, state: &mut H) { + self.0.to_bits().hash(state); + } +} + +impl Lattice for FAddMulF64 { + fn pjoin(&self, other: &Self) -> AlgebraicResult { + if self.0.is_zero() { return AlgebraicResult::Identity(1) } + if other.0.is_zero() { return AlgebraicResult::Identity(2) } + let s = self.0 + other.0; + if self.0 * other.0 < 0f64 && s.abs() < 1e-15 { return AlgebraicResult::None } + AlgebraicResult::Element(FAddMulF64(s)) + } + + fn pmeet(&self, other: &Self) -> AlgebraicResult { + let s = self.0 * other.0; + if s.abs() < 1e-15 { return AlgebraicResult::None } + AlgebraicResult::Element(FAddMulF64(s)) + } +} + +// ============================================================================ +// SparseTensorF64 — PathMap-backed sparse tensor with BOB encoding +// ============================================================================ + +pub struct SparseTensorF64 { + pub m: PathMap, + pub d: usize, + pub dims: Vec, + p: Vec, +} + +impl SparseTensorF64 { + pub fn new(rank: usize) -> Self { + Self { m: PathMap::new(), d: rank, dims: vec![0; rank], p: Vec::new() } + } + + pub fn with_dims(dims: Vec) -> Self { + let d = dims.len(); + Self { m: PathMap::new(), d, dims, p: Vec::new() } + } + + pub fn set(&mut self, ix: &[usize], v: f64) { + for (i, &idx) in ix.iter().enumerate() { + if idx >= self.dims[i] { + self.dims[i] = idx + 1; + } + } + let path = Self::index_to_path_static(ix); + self.m.insert(&path[..], v); + } + + pub fn get(&self, ix: &[usize]) -> Option { + let path = Self::index_to_path_static(ix); + self.m.get(&path[..]).copied() + } + + pub fn remove(&mut self, ix: &[usize]) -> Option { + let path = Self::index_to_path_static(ix); + self.m.remove(&path[..]) + } + + pub fn nnz(&self) -> usize { + self.m.val_count() + } + + fn index_to_path_static(ix: &[usize]) -> Vec { + let mut p = Vec::new(); + let len = indices_to_bob(ix, &mut vec![]); + p.extend(std::iter::repeat_n(0u8, 64 - len)); + indices_to_bob(ix, &mut p); + p + } + + // Safety: FAddMulF64 has the same layout as f64 + fn vf(&self) -> &PathMap { + unsafe { (&self.m as *const PathMap as *const PathMap).as_ref().unwrap_unchecked() } + } + + fn from_vf(m: PathMap, d: usize, dims: Vec) -> Self { + unsafe { Self { m: std::mem::transmute::, PathMap>(m), d, dims, p: Vec::new() } } + } + + pub fn add(&self, other: &Self) -> Self { + let dims = self.dims.iter().zip(&other.dims).map(|(&a, &b)| a.max(b)).collect(); + Self::from_vf(self.vf().join(other.vf()), self.d, dims) + } + + pub fn mul(&self, other: &Self) -> Self { + let dims = self.dims.iter().zip(&other.dims).map(|(&a, &b)| a.max(b)).collect(); + Self::from_vf(self.vf().meet(other.vf()), self.d, dims) + } +} + +// ============================================================================ +// NDIndex implementation for einsum-dyn compatibility +// ============================================================================ + +impl einsum_dyn::NDIndex for SparseTensorF64 { + fn ndim(&self) -> usize { self.d } + fn dim(&self, axis: usize) -> usize { self.dims[axis] } + + fn get(&self, indices: &[usize]) -> f64 { + self.get(indices).unwrap_or(0.0) + } + + fn set(&mut self, indices: &[usize], val: f64) { + if val.abs() < 1e-15 { + self.remove(indices); + } else { + self.set(indices, val); + } + } + + fn get_opt(&self, indices: &[usize]) -> Option { + self.get(indices) + } + + fn is_sparse_2d(&self) -> bool { false } +} + +// ============================================================================ +// Pure functions (access tensor store via ExprSource.context) +// ============================================================================ + +use eval_ffi::{ExprSource, ExprSink, EvalError}; +use mork_expr::SourceItem; +use eval::{EvalScope, FuncType}; + +/// Read tensor store from the ExprSource context pointer. +/// The context is set by PureSink via ASink::set_context before eval. +unsafe fn tensor_store_from_context(expr: &ExprSource) -> Option<&HashMap, SparseTensorF64>> { + (expr.context as *const HashMap, SparseTensorF64>).as_ref() +} + +/// (tensor_get name i0 i1 ... iN) -> f64 value at that index +pub extern "C" fn tensor_get(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + let expr = unsafe { &mut *expr }; + let sink = unsafe { &mut *sink }; + let items = expr.consume_head_check(b"tensor_get")?; + if items < 2 { return Err(EvalError::from("tensor_get requires name and indices")) } + + let SourceItem::Symbol(name) = expr.read() else { + return Err(EvalError::from("tensor_get: first arg must be tensor name")) + }; + let name = name.to_vec(); + + let mut indices: Vec = Vec::new(); + for _ in 0..(items - 1) { + let idx = expr.consume::()?; + indices.push(idx as usize); + } + + let store = unsafe { tensor_store_from_context(expr) }; + let val = store + .and_then(|s| s.get(&name)) + .and_then(|t| t.get(&indices)) + .unwrap_or(0.0); + + sink.write(SourceItem::Symbol(&val.to_be_bytes()[..]))?; + Ok(()) +} + +/// (tensor_nnz name) -> u64 count of non-zeros +pub extern "C" fn tensor_nnz(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + let expr = unsafe { &mut *expr }; + let sink = unsafe { &mut *sink }; + let items = expr.consume_head_check(b"tensor_nnz")?; + if items != 1 { return Err(EvalError::from("tensor_nnz takes one argument")) } + + let SourceItem::Symbol(name) = expr.read() else { + return Err(EvalError::from("tensor_nnz: arg must be tensor name")) + }; + let name = name.to_vec(); + + let store = unsafe { tensor_store_from_context(expr) }; + let nnz = store + .and_then(|s| s.get(&name)) + .map(|t| t.nnz()) + .unwrap_or(0) as u64; + + sink.write(SourceItem::Symbol(&nnz.to_be_bytes()[..]))?; + Ok(()) +} + +pub fn register(scope: &mut EvalScope) { + scope.add_func("tensor_get", tensor_get, FuncType::Pure); + scope.add_func("tensor_nnz", tensor_nnz, FuncType::Pure); +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sparse_tensor_basic() { + let mut t = SparseTensorF64::new(2); + t.set(&[0, 1], 3.0); + t.set(&[1, 2], 5.0); + t.set(&[2, 0], 7.0); + + assert_eq!(t.get(&[0, 1]), Some(3.0)); + assert_eq!(t.get(&[1, 2]), Some(5.0)); + assert_eq!(t.get(&[2, 0]), Some(7.0)); + assert_eq!(t.get(&[0, 0]), None); + assert_eq!(t.nnz(), 3); + assert_eq!(t.dims, vec![3, 3]); + } + + #[test] + fn test_sparse_tensor_4d() { + let mut t = SparseTensorF64::new(4); + t.set(&[1, 2, 3, 4], 42.0); + t.set(&[0, 0, 0, 0], 1.0); + + assert_eq!(t.get(&[1, 2, 3, 4]), Some(42.0)); + assert_eq!(t.get(&[0, 0, 0, 0]), Some(1.0)); + assert_eq!(t.get(&[1, 1, 1, 1]), None); + assert_eq!(t.nnz(), 2); + } + + #[test] + fn test_sparse_tensor_add() { + let mut a = SparseTensorF64::new(2); + a.set(&[0, 0], 1.0); + a.set(&[0, 1], 2.0); + + let mut b = SparseTensorF64::new(2); + b.set(&[0, 0], 10.0); + b.set(&[1, 0], 20.0); + + let c = a.add(&b); + assert_eq!(c.get(&[0, 0]), Some(11.0)); + assert_eq!(c.get(&[0, 1]), Some(2.0)); + assert_eq!(c.get(&[1, 0]), Some(20.0)); + } + + #[test] + fn test_sparse_tensor_mul() { + let mut a = SparseTensorF64::new(2); + a.set(&[0, 0], 3.0); + a.set(&[0, 1], 5.0); + + let mut b = SparseTensorF64::new(2); + b.set(&[0, 0], 2.0); + b.set(&[1, 1], 4.0); + + let c = a.mul(&b); + assert_eq!(c.get(&[0, 0]), Some(6.0)); + assert_eq!(c.get(&[0, 1]), None); + assert_eq!(c.get(&[1, 1]), None); + } + + #[test] + fn test_ndindex_impl() { + use einsum_dyn::NDIndex; + let mut t = SparseTensorF64::with_dims(vec![3, 3]); + >::set(&mut t, &[0, 1], 5.0); + assert_eq!(>::get(&t, &[0, 1]), 5.0); + assert_eq!(>::get(&t, &[0, 0]), 0.0); + assert_eq!(t.get_opt(&[0, 1]), Some(5.0)); + assert_eq!(t.get_opt(&[0, 0]), None); + } + + #[test] + fn test_einsum_matmul() { + use einsum_dyn::NDIndex; + let mut a = SparseTensorF64::with_dims(vec![2, 2]); + a.set(&[0, 0], 1.0); a.set(&[0, 1], 2.0); + a.set(&[1, 0], 3.0); a.set(&[1, 1], 4.0); + + let mut b = SparseTensorF64::with_dims(vec![2, 2]); + b.set(&[0, 0], 5.0); b.set(&[0, 1], 6.0); + b.set(&[1, 0], 7.0); b.set(&[1, 1], 8.0); + + let mut c = SparseTensorF64::with_dims(vec![2, 2]); + + let inputs: Vec<&dyn einsum_dyn::NDIndex> = vec![&a, &b]; + let mut out: &mut dyn einsum_dyn::NDIndex = &mut c; + einsum_dyn::sparse::einsum_vm_oneshot_dyn("ab,bc->ac", &inputs, &mut [out]).unwrap(); + + assert_eq!(>::get(&c, &[0, 0]), 19.0); + assert_eq!(>::get(&c, &[0, 1]), 22.0); + assert_eq!(>::get(&c, &[1, 0]), 43.0); + assert_eq!(>::get(&c, &[1, 1]), 50.0); + } + + #[test] + fn test_end_to_end_collect_einsum() { + use crate::space::Space; + + let mut s = Space::new(); + + // Load matrix A = [[1, 2], [3, 4]] as (a row col val) triples + // Load matrix B = [[5, 6], [7, 8]] as (b row col val) triples + // Matrix A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]] + // Store as (a row col val) triples, collect into tensors, einsum, check result + s.add_all_sexpr(r#" + (a 0 0 1.0) + (a 0 1 2.0) + (a 1 0 3.0) + (a 1 1 4.0) + + (b 0 0 5.0) + (b 0 1 6.0) + (b 1 0 7.0) + (b 1 1 8.0) + + (exec P1 (, (a $r $c $v)) (O (tensor_collect A $r $c $v))) + (exec P2 (, (b $r $c $v)) (O (tensor_collect B $r $c $v))) + (exec P3 (,) (O (tensor_einsum "ab,bc->ac" A B C))) + "#.as_bytes()).unwrap(); + + s.metta_calculus(100); + + // Verify tensors were collected + assert!(s.tensors.contains_key(b"A".as_slice()), "tensor A should exist"); + assert!(s.tensors.contains_key(b"B".as_slice()), "tensor B should exist"); + assert!(s.tensors.contains_key(b"C".as_slice()), "tensor C should exist"); + + let a = s.tensors.get(b"A".as_slice()).unwrap(); + assert_eq!(a.nnz(), 4); + assert_eq!(a.get(&[0, 0]), Some(1.0)); + assert_eq!(a.get(&[1, 1]), Some(4.0)); + + let b = s.tensors.get(b"B".as_slice()).unwrap(); + assert_eq!(b.nnz(), 4); + + // C = A * B = [[19, 22], [43, 50]] + let c = s.tensors.get(b"C".as_slice()).unwrap(); + assert_eq!(c.get(&[0, 0]), Some(19.0)); + assert_eq!(c.get(&[0, 1]), Some(22.0)); + assert_eq!(c.get(&[1, 0]), Some(43.0)); + assert_eq!(c.get(&[1, 1]), Some(50.0)); + } +} From d18ac2f29181e2734f89c0f4966f65473fc6766f Mon Sep 17 00:00:00 2001 From: Igor Malovitsa Date: Sat, 4 Apr 2026 17:06:34 +0000 Subject: [PATCH 2/6] Add tensor_get to end-to-end test and fix index parsing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - tensor_get now parses indices from string symbols (not binary u32), matching how MeTTa S-expressions encode numbers - End-to-end test exercises full pipeline: load data → tensor_collect → tensor_einsum → tensor_get to verify C[0,0]=19.0 and C[1,1]=50.0 - TensorEinsumSink strips quotes from spec string - Command sinks (einsum, binop, free) parse args in new() for literal templates, or in sink() for variable-bound templates Co-Authored-By: Claude Opus 4.6 (1M context) --- kernel/src/sparse.rs | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/kernel/src/sparse.rs b/kernel/src/sparse.rs index d1ecf54..1b01aa5 100644 --- a/kernel/src/sparse.rs +++ b/kernel/src/sparse.rs @@ -168,8 +168,14 @@ pub extern "C" fn tensor_get(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu let mut indices: Vec = Vec::new(); for _ in 0..(items - 1) { - let idx = expr.consume::()?; - indices.push(idx as usize); + let SourceItem::Symbol(idx_bytes) = expr.read() else { + return Err(EvalError::from("tensor_get: index must be a symbol")) + }; + let idx = std::str::from_utf8(idx_bytes) + .map_err(|_| EvalError::from("tensor_get: index not utf8"))? + .parse::() + .map_err(|_| EvalError::from("tensor_get: index not a number"))?; + indices.push(idx); } let store = unsafe { tensor_store_from_context(expr) }; @@ -357,5 +363,26 @@ mod tests { assert_eq!(c.get(&[0, 1]), Some(22.0)); assert_eq!(c.get(&[1, 0]), Some(43.0)); assert_eq!(c.get(&[1, 1]), Some(50.0)); + + // Phase 2: test tensor_get directly through EvalScope + { + use eval_ffi::ExprSource; + let mut scope = eval::EvalScope::new(); + crate::sparse::register(&mut scope); + scope.context = (&s.tensors as *const HashMap, SparseTensorF64>).cast_mut().cast(); + + // Build expression: (tensor_get C 0 0) + let expr_bytes = mork_expr::construct!("tensor_get" "C" "0" "0").unwrap(); + let result = scope.eval(ExprSource::new(expr_bytes.as_ptr())).unwrap(); + // Result is SymbolSize(8) + 8 bytes of f64 + assert_eq!(result.len(), 9, "tensor_get should return 9 bytes (tag + f64)"); + let val = f64::from_be_bytes(result[1..9].try_into().unwrap()); + assert!((val - 19.0).abs() < 1e-10, "tensor_get C 0 0 = {} expected 19.0", val); + + let expr_bytes = mork_expr::construct!("tensor_get" "C" "1" "1").unwrap(); + let result = scope.eval(ExprSource::new(expr_bytes.as_ptr())).unwrap(); + let val = f64::from_be_bytes(result[1..9].try_into().unwrap()); + assert!((val - 50.0).abs() < 1e-10, "tensor_get C 1 1 = {} expected 50.0", val); + } } } From 03a3134608e7e72d64400b6ccffdce6d3af339a7 Mon Sep 17 00:00:00 2001 From: Igor Malovitsa Date: Thu, 23 Apr 2026 17:25:16 +0000 Subject: [PATCH 3/6] Pass eval context as third FuncPtr arg instead of ExprSource field The opaque context pointer now flows through the extern "C" call signature rather than being set on each ExprSource before dispatch. tensor_get/tensor_nnz read it directly from the arg; EvalScope.context and ASink::set_context remain the owner-facing knob. Co-Authored-By: Claude Opus 4.7 (1M context) --- experiments/eval-examples/src/lib.rs | 4 +-- experiments/eval-ffi/src/lib.rs | 2 +- experiments/eval-ffi/src/source.rs | 5 +--- experiments/eval/src/lib.rs | 7 +++-- kernel/src/pure.rs | 38 ++++++++++++++-------------- kernel/src/sparse.rs | 22 ++++++++++------ 6 files changed, 40 insertions(+), 38 deletions(-) diff --git a/experiments/eval-examples/src/lib.rs b/experiments/eval-examples/src/lib.rs index 7aaa84c..816e58c 100644 --- a/experiments/eval-examples/src/lib.rs +++ b/experiments/eval-examples/src/lib.rs @@ -1,7 +1,7 @@ use eval_ffi::{ExprSink, ExprSource, EvalError, SourceItem}; #[unsafe(export_name = "ground_mul")] -pub extern "C" fn ground_mul(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn ground_mul(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"*")?; @@ -16,7 +16,7 @@ pub extern "C" fn ground_mul(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu } #[unsafe(export_name = "ground_sum")] -pub extern "C" fn ground_sum(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn ground_sum(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"+")?; diff --git a/experiments/eval-ffi/src/lib.rs b/experiments/eval-ffi/src/lib.rs index 6e2cff3..8655aa1 100644 --- a/experiments/eval-ffi/src/lib.rs +++ b/experiments/eval-ffi/src/lib.rs @@ -10,7 +10,7 @@ pub use sink::{ExprSink}; pub use source::{ExprSource}; pub use mork_expr::{Tag, SourceItem}; -pub type FuncPtr = extern "C" fn(*mut ExprSource, *mut ExprSink) -> Result<(), EvalError>; +pub type FuncPtr = extern "C" fn(*mut ExprSource, *mut ExprSink, *mut ()) -> Result<(), EvalError>; #[derive(Debug)] pub enum EvalError { diff --git a/experiments/eval-ffi/src/source.rs b/experiments/eval-ffi/src/source.rs index 0588800..be21a48 100644 --- a/experiments/eval-ffi/src/source.rs +++ b/experiments/eval-ffi/src/source.rs @@ -8,9 +8,6 @@ pub struct ExprSource { pub ptr: *const u8, // len: usize, pub position: usize, - /// Opaque context pointer. Pure functions can use this to access - /// extra state (e.g. tensor store) passed in by the eval scope. - pub context: *mut (), } #[cfg(feature = "std")] @@ -74,7 +71,7 @@ impl ExprSource { } pub fn new(ptr: *const u8) -> Self { // let len = vec.len(); - Self { ptr, position: 0, context: core::ptr::null_mut() } + Self { ptr, position: 0 } } pub fn read(&mut self) -> SourceItem { let byte = unsafe { *self.ptr.add(self.position) }; diff --git a/experiments/eval/src/lib.rs b/experiments/eval/src/lib.rs index 7595bf7..202f845 100644 --- a/experiments/eval/src/lib.rs +++ b/experiments/eval/src/lib.rs @@ -7,11 +7,11 @@ use mork_expr::{item_source, SourceItem}; use eval_ffi::{EvalError, ExprSink, ExprSource, FuncPtr, Tag}; use log::trace; -extern "C" fn nothing(src: *mut ExprSource, snk: *mut ExprSink) -> Result<(), EvalError> { +extern "C" fn nothing(src: *mut ExprSource, snk: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { Ok(()) } -extern "C" fn quote(src: *mut ExprSource, snk: *mut ExprSink) -> Result<(), EvalError> { +extern "C" fn quote(src: *mut ExprSource, snk: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { unreachable!() } @@ -140,8 +140,7 @@ impl EvalScope { let mut data = core::mem::take(&mut top_frame.sink).finish(); let offset = prev_frame.sink.as_ref().len(); let mut src = ExprSource::new(data.as_ptr()); - src.context = self.context; - (top_frame.func)(&mut src, &mut prev_frame.sink)?; + (top_frame.func)(&mut src, &mut prev_frame.sink, self.context)?; trace!(target: "eval", "{:?} ==> {:?}", mork_expr::serialize(&data[..]), mork_expr::serialize(&prev_frame.sink.as_ref()[offset..])); let top = self.stack.pop().unwrap(); // return buffer to pool diff --git a/kernel/src/pure.rs b/kernel/src/pure.rs index e5ced75..d8114e7 100644 --- a/kernel/src/pure.rs +++ b/kernel/src/pure.rs @@ -9,7 +9,7 @@ use eval_ffi::{ExprSink, ExprSource, EvalError, Tag}; macro_rules! op { (num nary $name:ident($initial:expr, $t:ident: $tt:ty, $x:ident: $tx:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -23,7 +23,7 @@ macro_rules! op { } }; (num quaternary $name:ident($x:ident: $tx:ty, $y:ident: $ty:ty, $z:ident: $tz:ty, $w:ident: $tw:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -37,7 +37,7 @@ macro_rules! op { } }; (num ternary $name:ident($x:ident: $tx:ty, $y:ident: $ty:ty, $z:ident: $tz:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -50,7 +50,7 @@ macro_rules! op { } }; (num binary $name:ident($x:ident: $tx:ty, $y:ident: $ty:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -62,7 +62,7 @@ macro_rules! op { } }; (num unary $name:ident($x:ident: $tx:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -73,7 +73,7 @@ macro_rules! op { } }; (num nulary $name:ident() => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -83,7 +83,7 @@ macro_rules! op { } }; (num from_string $name:ident<$t:ty>) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -95,7 +95,7 @@ macro_rules! op { } }; (num to_string $name:ident<$t:ty>) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -703,7 +703,7 @@ op!(num nary min_f32(f32::INFINITY, t: f32, x: f32) => t.min(x)); op!(num from_string f32_from_string); op!(num to_string f32_to_string); -pub extern "C" fn encode_hex(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn encode_hex(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"encode_hex")?; @@ -716,7 +716,7 @@ pub extern "C" fn encode_hex(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu Ok(()) } -pub extern "C" fn decode_hex(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn decode_hex(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"decode_hex")?; @@ -729,7 +729,7 @@ pub extern "C" fn decode_hex(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu Ok(()) } -pub extern "C" fn decode_base64url(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn decode_base64url(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"decode_base64url")?; @@ -742,7 +742,7 @@ pub extern "C" fn decode_base64url(expr: *mut ExprSource, sink: *mut ExprSink) - Ok(()) } -pub extern "C" fn encode_base64url(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn encode_base64url(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"encode_base64url")?; @@ -755,7 +755,7 @@ pub extern "C" fn encode_base64url(expr: *mut ExprSource, sink: *mut ExprSink) - Ok(()) } -pub extern "C" fn hash_expr(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn hash_expr(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"hash_expr")?; @@ -767,7 +767,7 @@ pub extern "C" fn hash_expr(expr: *mut ExprSource, sink: *mut ExprSink) -> Resul Ok(()) } -pub extern "C" fn reverse_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn reverse_symbol(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"reverse_symbol")?; @@ -780,7 +780,7 @@ pub extern "C" fn reverse_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Ok(()) } -pub extern "C" fn collapse_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn collapse_symbol(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"collapse_symbol")?; @@ -800,7 +800,7 @@ pub extern "C" fn collapse_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Ok(()) } -pub extern "C" fn explode_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn explode_symbol(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"explode_symbol")?; @@ -813,7 +813,7 @@ pub extern "C" fn explode_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Ok(()) } -// pub extern "C" fn nth_expr(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +// pub extern "C" fn nth_expr(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { // let expr = unsafe { &mut *expr }; // let sink = unsafe { &mut *sink }; // let items = expr.consume_head_check(b"nth_expr")?; @@ -829,7 +829,7 @@ pub extern "C" fn explode_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> // (ifnz then else ) // The condition may be of any length ( is always len >= 1), // but all bytes in the must be b'\0' in order for the condition to be considered `false` -pub extern "C" fn ifnz(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn ifnz(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"ifnz")?; @@ -853,7 +853,7 @@ pub extern "C" fn ifnz(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), } } -pub extern "C" fn tuple(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn tuple(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"tuple")?; diff --git a/kernel/src/sparse.rs b/kernel/src/sparse.rs index 1b01aa5..7aeb095 100644 --- a/kernel/src/sparse.rs +++ b/kernel/src/sparse.rs @@ -141,21 +141,21 @@ impl einsum_dyn::NDIndex for SparseTensorF64 { } // ============================================================================ -// Pure functions (access tensor store via ExprSource.context) +// Pure functions (access tensor store via ctx arg) // ============================================================================ use eval_ffi::{ExprSource, ExprSink, EvalError}; use mork_expr::SourceItem; use eval::{EvalScope, FuncType}; -/// Read tensor store from the ExprSource context pointer. +/// Reinterpret the opaque context pointer as a tensor store reference. /// The context is set by PureSink via ASink::set_context before eval. -unsafe fn tensor_store_from_context(expr: &ExprSource) -> Option<&HashMap, SparseTensorF64>> { - (expr.context as *const HashMap, SparseTensorF64>).as_ref() +unsafe fn tensor_store_from_context(ctx: *mut ()) -> Option<&'static HashMap, SparseTensorF64>> { + (ctx as *const HashMap, SparseTensorF64>).as_ref() } /// (tensor_get name i0 i1 ... iN) -> f64 value at that index -pub extern "C" fn tensor_get(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn tensor_get(expr: *mut ExprSource, sink: *mut ExprSink, ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"tensor_get")?; @@ -178,7 +178,7 @@ pub extern "C" fn tensor_get(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu indices.push(idx); } - let store = unsafe { tensor_store_from_context(expr) }; + let store = unsafe { tensor_store_from_context(ctx) }; let val = store .and_then(|s| s.get(&name)) .and_then(|t| t.get(&indices)) @@ -189,7 +189,7 @@ pub extern "C" fn tensor_get(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu } /// (tensor_nnz name) -> u64 count of non-zeros -pub extern "C" fn tensor_nnz(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn tensor_nnz(expr: *mut ExprSource, sink: *mut ExprSink, ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"tensor_nnz")?; @@ -200,7 +200,7 @@ pub extern "C" fn tensor_nnz(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu }; let name = name.to_vec(); - let store = unsafe { tensor_store_from_context(expr) }; + let store = unsafe { tensor_store_from_context(ctx) }; let nnz = store .and_then(|s| s.get(&name)) .map(|t| t.nnz()) @@ -386,3 +386,9 @@ mod tests { } } } +/* +sinks create intermediate representation of tensors and they should not +put directly to tensor +binary ops (tensor add/mul) don't work +tensor_clear only clears the first item +*/ \ No newline at end of file From c0bdceac1dc2790e54a736d023df89183dfddbb8 Mon Sep 17 00:00:00 2001 From: Igor Malovitsa Date: Thu, 23 Apr 2026 17:35:55 +0000 Subject: [PATCH 4/6] Fix TensorFreeSink to free every matched tensor, not just the first The `parsed` guard in `sink()` captured only the first pattern match and dropped subsequent ones, so `(exec (, (to_free $n)) (O (tensor_free $n)))` freed one tensor regardless of how many facts matched. Collect all names across `sink()` calls and drain them in `finalize()`. The test exercises three concurrent matches with one non-match that must survive. Co-Authored-By: Claude Opus 4.7 (1M context) --- kernel/src/sinks.rs | 27 +++++++++++++++++++-------- kernel/src/sparse.rs | 29 ++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/kernel/src/sinks.rs b/kernel/src/sinks.rs index 0ac826e..a5ecd6b 100644 --- a/kernel/src/sinks.rs +++ b/kernel/src/sinks.rs @@ -1434,31 +1434,42 @@ impl Sink for TensorBinopSink { } } -/// TensorFreeSink — removes a named tensor. +/// TensorFreeSink — removes named tensors. /// Syntax: (tensor_free A) +/// +/// Names are buffered during `sink()` and the actual removals happen in +/// `finalize()`. We cannot free in `sink()` because a future source type +/// (tensor-as-source) could be iterating the tensor store while pattern +/// matches drive these calls — mutating the store mid-iteration would +/// invalidate the upstream zipper. pub struct TensorFreeSink { e: Expr, - name: Vec, - parsed: bool, + names: Vec>, } impl Sink for TensorFreeSink { fn new(e: Expr) -> Self { - TensorFreeSink { e, name: Vec::new(), parsed: false } + TensorFreeSink { e, names: Vec::new() } } fn request(&self) -> impl Iterator { std::iter::once(WriteResourceRequest::TensorStore) } fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It, path: &[u8]) where 'a: 'w, 'k: 'w { - if self.parsed { return; } // "tensor_free" is 11 chars → header = 13 let args = parse_symbol_args(path, 13); - self.name = args.first().map(|a| a.to_vec()).unwrap_or_default(); - self.parsed = true; + if let Some(name) = args.first() { + self.names.push(name.to_vec()); + } } fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It) -> bool where 'a: 'w, 'k: 'w { let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() }; - store.remove(&self.name).is_some() + let mut any = false; + for name in self.names.drain(..) { + if store.remove(&name).is_some() { + any = true; + } + } + any } } diff --git a/kernel/src/sparse.rs b/kernel/src/sparse.rs index 7aeb095..1e6dde6 100644 --- a/kernel/src/sparse.rs +++ b/kernel/src/sparse.rs @@ -385,10 +385,37 @@ mod tests { assert!((val - 50.0).abs() < 1e-10, "tensor_get C 1 1 = {} expected 50.0", val); } } + + #[test] + fn test_tensor_free_multi_match() { + use crate::space::Space; + + let mut s = Space::new(); + + for name in [b"A".as_slice(), b"B", b"C", b"D"] { + let mut t = SparseTensorF64::new(1); + t.set(&[0], 1.0); + s.tensors.insert(name.to_vec(), t); + } + + s.add_all_sexpr(r#" + (to_free A) + (to_free B) + (to_free C) + + (exec F (, (to_free $name)) (O (tensor_free $name))) + "#.as_bytes()).unwrap(); + + s.metta_calculus(100); + + assert!(!s.tensors.contains_key(b"A".as_slice()), "A should be freed"); + assert!(!s.tensors.contains_key(b"B".as_slice()), "B should be freed"); + assert!(!s.tensors.contains_key(b"C".as_slice()), "C should be freed"); + assert!(s.tensors.contains_key(b"D".as_slice()), "D should survive"); + } } /* sinks create intermediate representation of tensors and they should not put directly to tensor binary ops (tensor add/mul) don't work -tensor_clear only clears the first item */ \ No newline at end of file From cd9d4ed48ba195ad31c0564ac01c0c56846fdbc0 Mon Sep 17 00:00:00 2001 From: Igor Malovitsa Date: Thu, 23 Apr 2026 17:39:20 +0000 Subject: [PATCH 5/6] Test tensor_add/tensor_mul sink paths with constant operands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both pass — the basic TensorBinopSink path works for constant names under an empty `(,)` pattern. Does not yet cover variable-bound names or multi-match dispatch, which are the likelier breakage surfaces hinted at by the todo comment. Co-Authored-By: Claude Opus 4.7 (1M context) --- kernel/src/sparse.rs | 56 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/kernel/src/sparse.rs b/kernel/src/sparse.rs index 1e6dde6..ae2b9c5 100644 --- a/kernel/src/sparse.rs +++ b/kernel/src/sparse.rs @@ -386,6 +386,62 @@ mod tests { } } + #[test] + fn test_tensor_add_sink() { + use crate::space::Space; + + let mut s = Space::new(); + + let mut a = SparseTensorF64::new(2); + a.set(&[0, 0], 1.0); + a.set(&[0, 1], 2.0); + s.tensors.insert(b"A".to_vec(), a); + + let mut b = SparseTensorF64::new(2); + b.set(&[0, 0], 10.0); + b.set(&[1, 0], 20.0); + s.tensors.insert(b"B".to_vec(), b); + + s.add_all_sexpr(r#" + (exec F (,) (O (tensor_add A B C))) + "#.as_bytes()).unwrap(); + + s.metta_calculus(100); + + let c = s.tensors.get(b"C".as_slice()).expect("C should exist"); + assert_eq!(c.get(&[0, 0]), Some(11.0)); + assert_eq!(c.get(&[0, 1]), Some(2.0)); + assert_eq!(c.get(&[1, 0]), Some(20.0)); + } + + #[test] + fn test_tensor_mul_sink() { + use crate::space::Space; + + let mut s = Space::new(); + + let mut a = SparseTensorF64::new(2); + a.set(&[0, 0], 3.0); + a.set(&[0, 1], 5.0); + s.tensors.insert(b"A".to_vec(), a); + + let mut b = SparseTensorF64::new(2); + b.set(&[0, 0], 2.0); + b.set(&[1, 1], 4.0); + s.tensors.insert(b"B".to_vec(), b); + + s.add_all_sexpr(r#" + (exec F (,) (O (tensor_mul A B C))) + "#.as_bytes()).unwrap(); + + s.metta_calculus(100); + + let c = s.tensors.get(b"C".as_slice()).expect("C should exist"); + assert_eq!(c.get(&[0, 0]), Some(6.0)); + assert_eq!(c.get(&[0, 1]), None); + assert_eq!(c.get(&[1, 1]), None); + } + #[test] fn test_tensor_free_multi_match() { use crate::space::Space; From 070e72b16546bbac8beabb5da0301a44282e10d9 Mon Sep 17 00:00:00 2001 From: Igor Malovitsa Date: Thu, 23 Apr 2026 20:46:15 +0000 Subject: [PATCH 6/6] Write TensorCollectSink entries directly into the tensor store MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop the `entries: Vec<(Vec, f64)>` buffer — parse each matched tuple in `sink()` and call `SparseTensorF64::set` on the stored tensor directly. The first match replaces any existing tensor at `name` with a fresh one of the sink's rank, so repeated exec invocations start clean; subsequent matches in the same invocation accumulate. Tracked with one `initialized: bool` instead of the entry Vec. Safe as long as no concurrent tensor-as-source path reads `name` during the exec. Co-Authored-By: Claude Opus 4.7 (1M context) --- kernel/src/sinks.rs | 55 ++++++++++++++++++++++---------------------- kernel/src/sparse.rs | 2 -- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/kernel/src/sinks.rs b/kernel/src/sinks.rs index a5ecd6b..50a6887 100644 --- a/kernel/src/sinks.rs +++ b/kernel/src/sinks.rs @@ -1230,13 +1230,19 @@ fn parse_symbol_args<'a>(path: &'a [u8], header_size: usize) -> Vec<&'a [u8]> { args } -/// TensorCollectSink — accumulates (indices, value) tuples into a named SparseTensorF64. +/// TensorCollectSink — writes (indices, value) tuples directly into a named SparseTensorF64. /// Syntax: (tensor_collect name $i0 $i1 ... $val) +/// +/// The first matching tuple replaces any existing tensor at `name` with a +/// fresh one; subsequent matches in the same exec invocation accumulate +/// into it. This assumes no concurrent tensor-as-source path is reading +/// `name` during the exec — if that ever becomes possible the writes must +/// be buffered until `finalize()` (see TensorFreeSink for that pattern). pub struct TensorCollectSink { e: Expr, name: Vec, rank: usize, - entries: Vec<(Vec, f64)>, + initialized: bool, } impl TensorCollectSink { @@ -1257,45 +1263,38 @@ impl Sink for TensorCollectSink { panic!("tensor_collect: second arg must be a symbol (tensor name)") } }; - TensorCollectSink { e, name, rank, entries: Vec::new() } + TensorCollectSink { e, name, rank, initialized: false } } fn request(&self) -> impl Iterator { std::iter::once(WriteResourceRequest::TensorStore) } - fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It, path: &[u8]) where 'a: 'w, 'k: 'w { - // Parse args after header + name + fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It, path: &[u8]) where 'a: 'w, 'k: 'w { let name_len = self.name.len(); let args = parse_symbol_args(path, Self::HEADER_SIZE + 1 + name_len); - if args.len() >= self.rank + 1 { - let mut indices = Vec::with_capacity(self.rank); - for i in 0..self.rank { - if let Ok(s) = std::str::from_utf8(args[i]) { - if let Ok(idx) = s.parse::() { - indices.push(idx); - } else { return; } - } else { return; } - } - if let Ok(s) = std::str::from_utf8(args[self.rank]) { - if let Ok(val) = s.parse::() { - self.entries.push((indices, val)); - } - } + if args.len() < self.rank + 1 { return; } + + let mut indices = Vec::with_capacity(self.rank); + for i in 0..self.rank { + let Ok(s) = std::str::from_utf8(args[i]) else { return; }; + let Ok(idx) = s.parse::() else { return; }; + indices.push(idx); } - } + let Ok(s) = std::str::from_utf8(args[self.rank]) else { return; }; + let Ok(val) = s.parse::() else { return; }; - fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It) -> bool where 'a: 'w, 'k: 'w { - if self.entries.is_empty() { return false; } let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() }; - - let mut tensor = crate::sparse::SparseTensorF64::new(self.rank); - for (indices, value) in self.entries.drain(..) { - tensor.set(&indices, value); + if !self.initialized { + store.insert(self.name.clone(), crate::sparse::SparseTensorF64::new(self.rank)); + self.initialized = true; } - store.insert(self.name.clone(), tensor); - true + store.get_mut(&self.name).unwrap().set(&indices, val); + } + + fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It) -> bool where 'a: 'w, 'k: 'w { + self.initialized } } diff --git a/kernel/src/sparse.rs b/kernel/src/sparse.rs index ae2b9c5..77768c3 100644 --- a/kernel/src/sparse.rs +++ b/kernel/src/sparse.rs @@ -471,7 +471,5 @@ mod tests { } } /* -sinks create intermediate representation of tensors and they should not -put directly to tensor binary ops (tensor add/mul) don't work */ \ No newline at end of file