diff --git a/Cargo.toml b/Cargo.toml index 0ce14db..dba88d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "experiments/eval", "experiments/eval-ffi", "experiments/eval-examples", + "einsum-dyn/", ] default-members = ["kernel/"] diff --git a/einsum-dyn/Cargo.toml b/einsum-dyn/Cargo.toml new file mode 100644 index 0000000..6832d86 --- /dev/null +++ b/einsum-dyn/Cargo.toml @@ -0,0 +1,4 @@ +[package] +name = "einsum-dyn" +version = "0.1.0" +edition = "2024" diff --git a/einsum-dyn/src/lib.rs b/einsum-dyn/src/lib.rs new file mode 100644 index 0000000..cc6cbde --- /dev/null +++ b/einsum-dyn/src/lib.rs @@ -0,0 +1,1400 @@ +//! Runtime dynamic [Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation) +//! for arbitrary N-dimensional arrays. +//! +//! # Functions +//! +//! | Function | Inputs | Output | +//! |---|---|---| +//! | [`einsum_ary`] | N arrays (`&[&In]`) | tensor (`&mut Out`) | +//! | [`einsum_binary`] | two arrays | tensor (`&mut Out`) | +//! | [`einsum_unary`] | one array | tensor (`&mut Out`) | +//! | [`einsum_binary_scalar`] | two arrays | scalar (returned) | +//! | [`einsum_unary_scalar`] | one array | scalar (returned) | +//! +//! [`einsum_ary`] is the general-purpose entry point — it accepts any number +//! of inputs and subsumes both [`einsum_binary`] and [`einsum_unary`]. The +//! specialised functions remain for ergonomics and because they are ~1.5× +//! faster (stack-only buffers vs heap `Vec`s for patterns/indices). +//! +//! For scalar output with `einsum_ary`, pass a 0-dimensional output tensor +//! (`ndim() == 0`, single element at `&[]`). +//! +//! # Spec format +//! +//! Specs use lowercase letters `a`–`z` as index names, with `->` separating +//! inputs from output: +//! +//! - `"ab,bc->ac"` — matrix multiply (contract over `b`) +//! - `"ab->ba"` — transpose (no contraction) +//! - `"i,i->"` — dot product (scalar output, empty after `->`) +//! - `"aa->"` — trace (scalar output) +//! 
- `"ab,bc,cd->ad"` — 3-input chain contraction (N-ary)
+//!
+//! Indices present in inputs but absent from the output are contracted
+//! (summed over). All output indices must appear in at least one input.
+//!
+//! # Implementing `NDIndex`
+//!
+//! Any type can be used with these functions by implementing [`NDIndex`]:
+//!
+//! ```
+//! use einsum_dyn::{NDIndex, einsum_ary, einsum_binary, einsum_unary, einsum_binary_scalar, einsum_unary_scalar};
+//!
+//! struct MyTensor {
+//!     data: Vec<f64>,
+//!     shape: Vec<usize>,
+//! }
+//!
+//! impl MyTensor {
+//!     fn new(shape: Vec<usize>) -> Self {
+//!         let n = shape.iter().product();
+//!         Self { data: vec![0.0; n], shape }
+//!     }
+//!     fn linear_index(&self, ix: &[usize]) -> usize {
+//!         let mut idx = 0;
+//!         let mut stride = 1;
+//!         for (&k, &dim) in ix.iter().rev().zip(self.shape.iter().rev()) {
+//!             idx += k * stride;
+//!             stride *= dim;
+//!         }
+//!         idx
+//!     }
+//! }
+//!
+//! impl NDIndex<f64> for MyTensor {
+//!     fn ndim(&self) -> usize { self.shape.len() }
+//!     fn dim(&self, axis: usize) -> usize { self.shape[axis] }
+//!     fn get(&self, ix: &[usize]) -> f64 { self.data[self.linear_index(ix)] }
+//!     fn set(&mut self, ix: &[usize], v: f64) {
+//!         let i = self.linear_index(ix);
+//!         self.data[i] = v;
+//!     }
+//! }
+//!
+//! // Matrix multiply: C = A × B
+//! let mut a = MyTensor::new(vec![2, 3]);
+//! a.data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
+//! let mut b = MyTensor::new(vec![3, 2]);
+//! b.data = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
+//! let mut c = MyTensor::new(vec![2, 2]);
+//! einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap();
+//! assert_eq!(c.data, vec![58.0, 64.0, 139.0, 154.0]);
+//!
+//! // Transpose
+//! let mut t = MyTensor::new(vec![3, 2]);
+//! einsum_unary("ab->ba", &a, &mut t).unwrap();
+//! assert_eq!(t.data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
+//!
+//! // Dot product (scalar)
+//! let mut x = MyTensor::new(vec![3]);
+//! x.data = vec![1.0, 2.0, 3.0];
+//! let mut y = MyTensor::new(vec![3]);
+//! 
y.data = vec![4.0, 5.0, 6.0];
+//! let dot: f64 = einsum_binary_scalar("i,i->", &x, &y).unwrap();
+//! assert_eq!(dot, 32.0);
+//!
+//! // Trace (scalar)
+//! let mut m = MyTensor::new(vec![2, 2]);
+//! m.data = vec![1.0, 2.0, 3.0, 4.0];
+//! let tr: f64 = einsum_unary_scalar("aa->", &m).unwrap();
+//! assert_eq!(tr, 5.0);
+//!
+//! // N-ary: 3-input chain A(2×3) × B(3×2) × C(2×2) via einsum_ary
+//! let mut d = MyTensor::new(vec![2, 2]);
+//! einsum_ary("ab,bc,cd->ad", &[&a, &b, &c], &mut d).unwrap();
+//!
+//! // N-ary with scalar output (0-dim tensor)
+//! let mut scalar = MyTensor::new(vec![]);
+//! einsum_ary("i,i->", &[&x, &y], &mut scalar).unwrap();
+//! assert_eq!(scalar.data, vec![32.0]);
+//! ```
+//!
+//! # Feature flags
+//!
+//! - **`ndarray`** — implements `NDIndex` for [`ndarray::ArrayD`], so you
+//!   can pass dynamic-dimension ndarray arrays directly.
+
+pub mod sparse;
+
+use std::fmt;
+use std::ops::{AddAssign, Mul};
+
+/// Trait for N-dimensional array access.
+///
+/// Implement this for your tensor type to use the `einsum_*` functions.
+/// All index slices are ordered left-to-right (outermost dimension first).
+pub trait NDIndex<T> {
+    /// Number of dimensions (rank) of the array.
+    fn ndim(&self) -> usize;
+    /// Extent of the given axis (0-based, outermost first).
+    fn dim(&self, axis: usize) -> usize;
+    /// Read the element at `indices` (one index per dimension).
+    fn get(&self, indices: &[usize]) -> T;
+    /// Write `val` at `indices` (one index per dimension).
+    fn set(&mut self, indices: &[usize], val: T);
+
+    /// Returns `None` for structurally absent (zero) entries.
+    ///
+    /// Dense implementations use the default, which wraps `get` in `Some`.
+    /// Sparse implementations should override this to return `None` for
+    /// entries not present in the sparse structure.
+    fn get_opt(&self, indices: &[usize]) -> Option<T> {
+        Some(self.get(indices))
+    }
+
+    /// Whether this array is a 2D sparse matrix supporting row iteration
+    /// via `sparse_row_nnz` and `sparse_row_entry`. Default: false.
+    fn is_sparse_2d(&self) -> bool { false }
+
+    /// Number of non-zero entries in the given row.
+    /// Only meaningful when `is_sparse_2d()` returns true.
+    fn sparse_row_nnz(&self, _row: usize) -> usize { 0 }
+
+    /// Get the `idx`-th non-zero entry in the given row as `(col, value)`.
+    /// Only meaningful when `is_sparse_2d()` returns true.
+    fn sparse_row_entry(&self, _row: usize, _idx: usize) -> (usize, T) {
+        panic!("sparse_row_entry called on non-sparse array")
+    }
+}
+
+#[cfg(feature = "ndarray")]
+// NOTE(review): `T: Copy` is required because `get` reads the element out
+// of the indexed reference by value — confirm this matches the crate's
+// intended bound for the `ndarray` feature.
+impl<T: Copy> NDIndex<T> for ndarray::ArrayD<T> {
+    fn ndim(&self) -> usize {
+        self.ndim()
+    }
+    fn dim(&self, axis: usize) -> usize {
+        self.shape()[axis]
+    }
+    fn get(&self, ix: &[usize]) -> T {
+        self[ndarray::IxDyn(ix)]
+    }
+    fn set(&mut self, ix: &[usize], val: T) {
+        self[ndarray::IxDyn(ix)] = val;
+    }
+}
+
+/// Error returned when an einsum spec string is invalid.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum InvalidSpec {
+    /// Spec lacks the mandatory `->` separator.
+    MissingArrow,
+    /// An index character outside `a`..=`z` was found.
+    InvalidIndex { ch: char },
+    /// Number of comma-separated input groups differs from the arrays given.
+    WrongInputCount { expected: usize, got: usize },
+    /// An input group was empty (e.g. a stray comma on the LHS).
+    EmptyInput { input: usize },
+    /// An output index never appears in any input.
+    UnboundOutputIndex { index: char },
+    /// An input array's rank differs from its number of spec indices.
+    InputNdimMismatch { input: usize, array_ndim: usize, spec_ndim: usize },
+    /// The same index is bound to two different extents.
+    DimensionMismatch { index: char, expected: usize, got: usize },
+    /// Output array rank differs from the number of output indices.
+    OutputNdimMismatch { array_ndim: usize, spec_ndim: usize },
+    /// An output axis has the wrong extent.
+    OutputDimMismatch { axis: usize, expected: usize, got: usize },
+    /// A scalar einsum function was given a non-empty output group.
+    NonEmptyScalarOutput,
+}
+
+/// Convert a slot index back to its letter for error messages.
+fn slot_to_char(s: u8) -> char { + (s + b'a') as char +} + +impl fmt::Display for InvalidSpec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MissingArrow => write!(f, "missing '->'"), + Self::InvalidIndex { ch } => write!(f, "index '{ch}' is not a lowercase letter"), + Self::WrongInputCount { expected, got } => { + write!(f, "expected {expected} input(s), got {got}") + } + Self::EmptyInput { input } => write!(f, "input {input} has no indices"), + Self::UnboundOutputIndex { index } => { + write!(f, "output index '{index}' does not appear in any input") + } + Self::InputNdimMismatch { input, array_ndim, spec_ndim } => { + write!(f, "input {input} has {array_ndim} dimensions but spec has {spec_ndim} indices") + } + Self::DimensionMismatch { index, expected, got } => { + write!(f, "dimension mismatch for index '{index}': {expected} vs {got}") + } + Self::OutputNdimMismatch { array_ndim, spec_ndim } => { + write!(f, "output has {array_ndim} dimensions but spec has {spec_ndim} output indices") + } + Self::OutputDimMismatch { axis, expected, got } => { + write!(f, "output dimension {axis} is {got} but expected {expected}") + } + Self::NonEmptyScalarOutput => { + write!(f, "scalar output requires empty output indices (use '...->')") + } + } + } +} + +impl std::error::Error for InvalidSpec {} + +/// Parsed einsum specification. All index chars are stored as slot indices +/// (`b'a'` → 0, `b'b'` → 1, ..., `b'z'` → 25). +/// +/// The RHS of the spec may contain multiple comma-separated output groups +/// (e.g. `"ab,bc->ac,ca"`). Single-output functions check `outputs.len() == 1`. +pub(crate) struct Spec { + inputs: Vec>, + outputs: Vec>, +} + +impl Spec { + /// Convenience: returns the first (and usually only) output pattern. + pub(crate) fn output(&self) -> &[u8] { + &self.outputs[0] + } + + /// All unique output slots across all outputs. 
+ pub(crate) fn all_output_slots(&self) -> Vec { + let mut seen = [false; 26]; + let mut slots = Vec::new(); + for out in &self.outputs { + for &s in out { + if !seen[s as usize] { + seen[s as usize] = true; + slots.push(s); + } + } + } + slots + } +} + +pub(crate) fn parse_spec(spec: &str, expected_inputs: usize) -> Result { + let spec = spec.replace(' ', ""); + + let (lhs, rhs) = spec + .split_once("->") + .ok_or(InvalidSpec::MissingArrow)?; + + let mut inputs: Vec> = Vec::new(); + for part in lhs.split(',') { + let mut slots = Vec::new(); + for ch in part.bytes() { + if !ch.is_ascii_lowercase() { + return Err(InvalidSpec::InvalidIndex { ch: ch as char }); + } + slots.push(ch - b'a'); + } + inputs.push(slots); + } + + if inputs.len() != expected_inputs { + return Err(InvalidSpec::WrongInputCount { + expected: expected_inputs, + got: inputs.len(), + }); + } + + for (i, inp) in inputs.iter().enumerate() { + if inp.is_empty() { + return Err(InvalidSpec::EmptyInput { input: i }); + } + } + + let mut outputs: Vec> = Vec::new(); + for part in rhs.split(',') { + let mut slots = Vec::new(); + for ch in part.bytes() { + if !ch.is_ascii_lowercase() { + return Err(InvalidSpec::InvalidIndex { ch: ch as char }); + } + slots.push(ch - b'a'); + } + outputs.push(slots); + } + + // Validate: every output index must appear in at least one input + let mut seen = [false; 26]; + for inp in &inputs { + for &s in inp { + seen[s as usize] = true; + } + } + for out in &outputs { + for &s in out { + if !seen[s as usize] { + return Err(InvalidSpec::UnboundOutputIndex { index: slot_to_char(s) }); + } + } + } + + Ok(Spec { inputs, outputs }) +} + +/// Validates that array dimensions match the spec. Returns dims as a `[usize; 26]` +/// array (indexed by slot). Unused slots are 0. 
+pub(crate) fn validate_dims>( + spec: &Spec, + arrays: &[&Arr], +) -> Result<[usize; 26], InvalidSpec> { + for (i, (inp, arr)) in spec.inputs.iter().zip(arrays.iter()).enumerate() { + if arr.ndim() != inp.len() { + return Err(InvalidSpec::InputNdimMismatch { + input: i, + array_ndim: arr.ndim(), + spec_ndim: inp.len(), + }); + } + } + + let mut dims = [0usize; 26]; + let mut set = [false; 26]; + for (pi, inp) in spec.inputs.iter().enumerate() { + for (pos, &s) in inp.iter().enumerate() { + let si = s as usize; + let d = arrays[pi].dim(pos); + if set[si] { + if dims[si] != d { + return Err(InvalidSpec::DimensionMismatch { + index: slot_to_char(s), + expected: dims[si], + got: d, + }); + } + } else { + dims[si] = d; + set[si] = true; + } + } + } + + Ok(dims) +} + +/// Collects all unique indices in order of first appearance, as a SlotList. +fn all_slots_ordered(spec: &Spec) -> SlotList { + let mut seen = [false; 26]; + let mut slots = [0u8; 26]; + let mut len = 0u8; + for inp in &spec.inputs { + for &s in inp { + if !seen[s as usize] { + seen[s as usize] = true; + slots[len as usize] = s; + len += 1; + } + } + } + SlotList { slots, len } +} + +/// Stack-only index buffer: fixed array + length, no heap. +struct Idx { + data: [usize; 26], + len: u8, +} + +impl Idx { + const ZERO: Self = Idx { data: [0; 26], len: 0 }; + + #[inline(always)] + fn as_slice(&self) -> &[usize] { + &self.data[..self.len as usize] + } +} + +/// Precomputed gather pattern: slot indices stored on the stack. +struct Pattern { + slots: [u8; 26], + len: u8, +} + +impl Pattern { + fn from_slots(slot_indices: &[u8]) -> Self { + let mut slots = [0u8; 26]; + slots[..slot_indices.len()].copy_from_slice(slot_indices); + Pattern { + slots, + len: slot_indices.len() as u8, + } + } + + /// Gather index values from `vals` into `out` according to this pattern. 
+    #[inline(always)]
+    fn gather(&self, vals: &[usize; 26], out: &mut Idx) {
+        let used = self.len as usize;
+        out.len = self.len;
+        // Copy the value bound to each of this pattern's slots, in order.
+        for (dst, &slot) in out.data[..used].iter_mut().zip(&self.slots[..used]) {
+            *dst = vals[slot as usize];
+        }
+    }
+}
+
+/// Precomputed loop-slot list stored on the stack.
+struct SlotList {
+    slots: [u8; 26],
+    len: u8,
+}
+
+impl SlotList {
+    fn from_slots(slot_indices: &[u8]) -> Self {
+        let count = slot_indices.len();
+        let mut slots = [0u8; 26];
+        slots[..count].copy_from_slice(slot_indices);
+        SlotList { slots, len: count as u8 }
+    }
+
+    fn as_slice(&self) -> &[u8] {
+        &self.slots[..self.len as usize]
+    }
+
+    fn contains(&self, s: u8) -> bool {
+        self.as_slice().iter().any(|&x| x == s)
+    }
+
+    /// Slots of `all` that are NOT in `free`, keeping `all`'s order.
+    fn filtered_complement(all: &[u8], free: &SlotList) -> Self {
+        let mut result = SlotList { slots: [0u8; 26], len: 0 };
+        for &slot in all {
+            if !free.contains(slot) {
+                result.slots[result.len as usize] = slot;
+                result.len += 1;
+            }
+        }
+        result
+    }
+}
+
+/// Recursive loop nest over slots. `loop_slots[i]` is a slot index,
+/// `dims` and `vals` are flat [usize; 26] arrays.
+#[inline(always)]
+fn loop_nest(
+    loop_slots: &[u8],
+    dims: &[usize; 26],
+    vals: &mut [usize; 26],
+    emit: &mut impl FnMut(&[usize; 26]),
+) {
+    match loop_slots.split_first() {
+        // Innermost level reached: every loop slot has a value bound.
+        None => emit(vals),
+        Some((&head, rest)) => {
+            let slot = head as usize;
+            for value in 0..dims[slot] {
+                vals[slot] = value;
+                loop_nest(rest, dims, vals, emit);
+            }
+        }
+    }
+}
+
+// Iterative variant using an explicit counter stack.
+// Benchmarked ~same as recursive.
+#[cfg(any())]
+#[inline(always)]
+fn loop_nest_iterative(
+    loop_slots: &[u8],
+    dims: &[usize; 26],
+    vals: &mut [usize; 26],
+    emit: &mut impl FnMut(&[usize; 26]),
+) {
+    let depth = loop_slots.len();
+    if depth == 0 {
+        emit(vals);
+        return;
+    }
+    let mut counters = [0usize; 26];
+    for i in 0..depth {
+        vals[loop_slots[i] as usize] = 0;
+    }
+    loop {
+        emit(vals);
+        // Odometer increment: bump the innermost counter, carrying left.
+        let mut level = depth - 1;
+        loop {
+            let s = loop_slots[level] as usize;
+            counters[level] += 1;
+            if counters[level] < dims[s] {
+                vals[s] = counters[level];
+                break;
+            }
+            counters[level] = 0;
+            vals[s] = 0;
+            if level == 0 {
+                return;
+            }
+            level -= 1;
+        }
+    }
+}
+
+/// Validate output array dimensions against the spec.
+///
+/// Checks that `out` has exactly one axis per output index and that each
+/// axis extent equals the extent bound to that index in `dims`.
+pub(crate) fn validate_output<T, Arr: NDIndex<T>>(
+    spec: &Spec,
+    dims: &[usize; 26],
+    out: &Arr,
+) -> Result<(), InvalidSpec> {
+    if out.ndim() != spec.output().len() {
+        return Err(InvalidSpec::OutputNdimMismatch {
+            array_ndim: out.ndim(),
+            spec_ndim: spec.output().len(),
+        });
+    }
+    for (pos, &s) in spec.output().iter().enumerate() {
+        if out.dim(pos) != dims[s as usize] {
+            return Err(InvalidSpec::OutputDimMismatch {
+                axis: pos,
+                expected: dims[s as usize],
+                got: out.dim(pos),
+            });
+        }
+    }
+    Ok(())
+}
+
+/// `einsum_ary(spec, inputs, out)` — N-ary einsum with tensor output.
+///
+/// Generalises [`einsum_binary`] and [`einsum_unary`] to an arbitrary number
+/// of inputs. The spec must contain exactly `inputs.len()` comma-separated
+/// input index groups.
+///
+/// For scalar output, pass a 0-dimensional output tensor (`ndim() == 0`,
+/// one element at index `&[]`) and use an empty output spec (e.g. `"i,i->"`).
+///
+/// The specialised binary/unary functions are ~1.5× faster due to stack-only
+/// buffers; prefer them for the 1- and 2-input cases on hot paths.
+/// +/// ``` +/// # use einsum_dyn::{NDIndex, einsum_ary}; +/// # struct T { m: Vec, d: Vec } +/// # impl T { fn new(d: Vec) -> Self { let n = d.iter().product(); Self { m: vec![0.0; n], d } } fn li(&self, ix: &[usize]) -> usize { let mut i=0; let mut s=1; for (&k,&d) in ix.iter().rev().zip(self.d.iter().rev()) { i+=k*s; s*=d; } i } } +/// # impl NDIndex for T { fn ndim(&self)->usize{self.d.len()} fn dim(&self,a:usize)->usize{self.d[a]} fn get(&self,ix:&[usize])->f32{self.m[self.li(ix)]} fn set(&mut self,ix:&[usize],v:f32){let i=self.li(ix);self.m[i]=v;} } +/// let mut a = T::new(vec![2, 3]); +/// a.m = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; +/// let mut b = T::new(vec![3, 2]); +/// b.m = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; +/// let mut c = T::new(vec![2, 2]); +/// einsum_ary("ab,bc->ac", &[&a, &b], &mut c).unwrap(); +/// assert_eq!(c.m, vec![58.0, 64.0, 139.0, 154.0]); +/// ``` +pub fn einsum_ary(spec: &str, inputs: &[&In], out: &mut Out) -> Result<(), InvalidSpec> +where + T: Default + Copy + AddAssign + Mul, + In: NDIndex, + Out: NDIndex, +{ + let spec = parse_spec(spec, inputs.len())?; + let dims = validate_dims(&spec, inputs)?; + validate_output(&spec, &dims, out)?; + + let free_slots = SlotList::from_slots(&spec.output()); + let all = all_slots_ordered(&spec); + let contracted_slots = SlotList::filtered_complement(all.as_slice(), &free_slots); + + let pats: Vec = spec.inputs.iter().map(|inp| Pattern::from_slots(inp)).collect(); + let pat_out = Pattern::from_slots(&spec.output()); + + let n = inputs.len(); + let mut vals = [0usize; 26]; + let mut bufs: Vec = (0..n).map(|_| Idx::ZERO).collect(); + let mut buf_out = Idx::ZERO; + + if contracted_slots.len == 0 { + // No contraction — direct assignment + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |vals| { + for i in 0..n { + pats[i].gather(vals, &mut bufs[i]); + } + // Multiply all input values; if any is sparse-absent, result is zero + let first = match inputs[0].get_opt(bufs[0].as_slice()) { + 
Some(v) => v, + None => { pat_out.gather(vals, &mut buf_out); out.set(buf_out.as_slice(), Default::default()); return; } + }; + let mut product = first; + for i in 1..n { + match inputs[i].get_opt(bufs[i].as_slice()) { + Some(v) => product = product * v, + None => { pat_out.gather(vals, &mut buf_out); out.set(buf_out.as_slice(), Default::default()); return; } + } + } + pat_out.gather(vals, &mut buf_out); + out.set(buf_out.as_slice(), product); + }); + } else { + // With contraction — accumulate per output element + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |free_vals| { + let mut acc: T = Default::default(); + let mut inner_vals = *free_vals; + loop_nest( + contracted_slots.as_slice(), + &dims, + &mut inner_vals, + &mut |vals| { + for i in 0..n { + pats[i].gather(vals, &mut bufs[i]); + } + let first = match inputs[0].get_opt(bufs[0].as_slice()) { + Some(v) => v, + None => return, + }; + let mut product = first; + for i in 1..n { + match inputs[i].get_opt(bufs[i].as_slice()) { + Some(v) => product = product * v, + None => return, + } + } + acc += product; + }, + ); + pat_out.gather(free_vals, &mut buf_out); + out.set(buf_out.as_slice(), acc); + }); + } + + Ok(()) +} + +/// `einsum_binary(spec, a, b, out)` — binary einsum with tensor output. +/// +/// Spec format: `"ab,bc->ac"` (numpy-style). +/// All indices in the output must appear in at least one input. +/// Indices present in inputs but absent from the output are contracted (summed over). +/// The output array must already have the correct shape. 
+pub fn einsum_binary(spec: &str, a: &In, b: &In, out: &mut Out) -> Result<(), InvalidSpec> +where + T: Default + Copy + AddAssign + Mul, + In: NDIndex, + Out: NDIndex, +{ + let spec = parse_spec(spec, 2)?; + let dims = validate_dims(&spec, &[a, b])?; + validate_output(&spec, &dims, out)?; + + let free_slots = SlotList::from_slots(&spec.output()); + let all = all_slots_ordered(&spec); + let contracted_slots = SlotList::filtered_complement(all.as_slice(), &free_slots); + + let pat_a = Pattern::from_slots(&spec.inputs[0]); + let pat_b = Pattern::from_slots(&spec.inputs[1]); + let pat_out = Pattern::from_slots(&spec.output()); + + let mut vals = [0usize; 26]; + let mut buf_a = Idx::ZERO; + let mut buf_b = Idx::ZERO; + let mut buf_out = Idx::ZERO; + + if contracted_slots.len == 0 { + // No contraction — direct assignment + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |vals| { + pat_a.gather(vals, &mut buf_a); + pat_b.gather(vals, &mut buf_b); + pat_out.gather(vals, &mut buf_out); + let v = match (a.get_opt(buf_a.as_slice()), b.get_opt(buf_b.as_slice())) { + (Some(av), Some(bv)) => av * bv, + _ => Default::default(), + }; + out.set(buf_out.as_slice(), v); + }); + } else { + // With contraction — accumulate per output element + loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |free_vals| { + let mut acc: T = Default::default(); + let mut inner_vals = *free_vals; + loop_nest( + contracted_slots.as_slice(), + &dims, + &mut inner_vals, + &mut |vals| { + pat_a.gather(vals, &mut buf_a); + pat_b.gather(vals, &mut buf_b); + if let (Some(av), Some(bv)) = + (a.get_opt(buf_a.as_slice()), b.get_opt(buf_b.as_slice())) + { + acc += av * bv; + } + }, + ); + pat_out.gather(free_vals, &mut buf_out); + out.set(buf_out.as_slice(), acc); + }); + } + + Ok(()) +} + +/// `einsum_unary(spec, a, out)` — unary einsum with tensor output. +/// +/// Spec format: `"ab->ba"` (numpy-style). 
+pub fn einsum_unary<T, In, Out>(spec: &str, a: &In, out: &mut Out) -> Result<(), InvalidSpec>
+where
+    T: Default + Copy + AddAssign + Mul<Output = T>,
+    In: NDIndex<T>,
+    Out: NDIndex<T>,
+{
+    let spec = parse_spec(spec, 1)?;
+    let dims = validate_dims(&spec, &[a])?;
+    validate_output(&spec, &dims, out)?;
+
+    // Free slots appear in the output; contracted slots are summed away.
+    let free_slots = SlotList::from_slots(spec.output());
+    let all = all_slots_ordered(&spec);
+    let contracted_slots = SlotList::filtered_complement(all.as_slice(), &free_slots);
+
+    let pat_a = Pattern::from_slots(&spec.inputs[0]);
+    let pat_out = Pattern::from_slots(spec.output());
+    let mut vals = [0usize; 26];
+    let mut buf_a = Idx::ZERO;
+    let mut buf_out = Idx::ZERO;
+
+    if contracted_slots.len == 0 {
+        // Pure reindexing (e.g. transpose): copy element-for-element.
+        loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |vals| {
+            pat_a.gather(vals, &mut buf_a);
+            pat_out.gather(vals, &mut buf_out);
+            // Sparse-absent input entries materialise as zero in the output.
+            let v = a.get_opt(buf_a.as_slice()).unwrap_or_default();
+            out.set(buf_out.as_slice(), v);
+        });
+    } else {
+        // Reduction (e.g. row sum, trace): sum contracted indices per element.
+        loop_nest(free_slots.as_slice(), &dims, &mut vals, &mut |free_vals| {
+            let mut acc: T = Default::default();
+            let mut inner_vals = *free_vals;
+            loop_nest(
+                contracted_slots.as_slice(),
+                &dims,
+                &mut inner_vals,
+                &mut |vals| {
+                    pat_a.gather(vals, &mut buf_a);
+                    if let Some(av) = a.get_opt(buf_a.as_slice()) {
+                        acc += av;
+                    }
+                },
+            );
+            pat_out.gather(free_vals, &mut buf_out);
+            out.set(buf_out.as_slice(), acc);
+        });
+    }
+
+    Ok(())
+}
+
+/// `einsum_binary_scalar(spec, a, b)` — binary einsum with scalar output.
+///
+/// Spec format: `"ab,ab->"` (empty output = scalar).
+pub fn einsum_binary_scalar(spec: &str, a: &Arr, b: &Arr) -> Result +where + T: Default + Copy + AddAssign + Mul, + Arr: NDIndex, +{ + let spec = parse_spec(spec, 2)?; + let dims = validate_dims(&spec, &[a, b])?; + + if !spec.output().is_empty() { + return Err(InvalidSpec::NonEmptyScalarOutput); + } + + let all = all_slots_ordered(&spec); + let pat_a = Pattern::from_slots(&spec.inputs[0]); + let pat_b = Pattern::from_slots(&spec.inputs[1]); + let mut vals = [0usize; 26]; + let mut buf_a = Idx::ZERO; + let mut buf_b = Idx::ZERO; + let mut acc: T = Default::default(); + + loop_nest(all.as_slice(), &dims, &mut vals, &mut |vals| { + pat_a.gather(vals, &mut buf_a); + pat_b.gather(vals, &mut buf_b); + if let (Some(av), Some(bv)) = + (a.get_opt(buf_a.as_slice()), b.get_opt(buf_b.as_slice())) + { + acc += av * bv; + } + }); + + Ok(acc) +} + +/// `einsum_unary_scalar(spec, a)` — unary einsum with scalar output. +/// +/// Spec format: `"aa->"` (empty output = scalar). +pub fn einsum_unary_scalar(spec: &str, a: &Arr) -> Result +where + T: Default + Copy + AddAssign + Mul, + Arr: NDIndex, +{ + let spec = parse_spec(spec, 1)?; + let dims = validate_dims(&spec, &[a])?; + + if !spec.output().is_empty() { + return Err(InvalidSpec::NonEmptyScalarOutput); + } + + let all = all_slots_ordered(&spec); + let pat_a = Pattern::from_slots(&spec.inputs[0]); + let mut vals = [0usize; 26]; + let mut buf_a = Idx::ZERO; + let mut acc: T = Default::default(); + + loop_nest(all.as_slice(), &dims, &mut vals, &mut |vals| { + pat_a.gather(vals, &mut buf_a); + if let Some(av) = a.get_opt(buf_a.as_slice()) { + acc += av; + } + }); + + Ok(acc) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Minimal dense tensor for testing. 
+ struct Tensor { + m: Vec, + d: Vec, + } + + impl Tensor { + fn new(d: Vec) -> Self { + let n: usize = d.iter().product(); + Self { + m: vec![0.0; n], + d, + } + } + + fn linear_index(&self, ix: &[usize]) -> usize { + let mut idx = 0usize; + let mut stride = 1usize; + for (&k, &dim) in ix.iter().rev().zip(self.d.iter().rev()) { + idx += k * stride; + stride *= dim; + } + idx + } + } + + impl NDIndex for Tensor { + fn ndim(&self) -> usize { + self.d.len() + } + + fn dim(&self, axis: usize) -> usize { + self.d[axis] + } + + fn get(&self, ix: &[usize]) -> f32 { + self.m[self.linear_index(ix)] + } + + fn set(&mut self, ix: &[usize], v: f32) { + let i = self.linear_index(ix); + self.m[i] = v; + } + } + + fn set_matrix(t: &mut Tensor, vals: &[f32]) { + for (i, &v) in vals.iter().enumerate() { + t.m[i] = v; + } + } + + #[test] + fn test_matmul() { + let mut a = Tensor::new(vec![2, 3]); + let mut b = Tensor::new(vec![3, 2]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + set_matrix(&mut b, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]); + + let mut c = Tensor::new(vec![2, 2]); + einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap(); + + assert_eq!(c.get(&[0, 0]), 58.0); + assert_eq!(c.get(&[0, 1]), 64.0); + assert_eq!(c.get(&[1, 0]), 139.0); + assert_eq!(c.get(&[1, 1]), 154.0); + } + + #[test] + fn test_transpose() { + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + let mut t = Tensor::new(vec![3, 2]); + einsum_unary("ab->ba", &a, &mut t).unwrap(); + + assert_eq!(t.get(&[0, 0]), 1.0); + assert_eq!(t.get(&[0, 1]), 4.0); + assert_eq!(t.get(&[1, 0]), 2.0); + assert_eq!(t.get(&[1, 1]), 5.0); + assert_eq!(t.get(&[2, 0]), 3.0); + assert_eq!(t.get(&[2, 1]), 6.0); + } + + #[test] + fn test_outer_product() { + let mut a = Tensor::new(vec![3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0]); + let mut b = Tensor::new(vec![2]); + set_matrix(&mut b, &[4.0, 5.0]); + + let mut c = Tensor::new(vec![3, 2]); + einsum_binary("a,b->ab", &a, &b, &mut 
c).unwrap(); + + assert_eq!(c.get(&[0, 0]), 4.0); + assert_eq!(c.get(&[0, 1]), 5.0); + assert_eq!(c.get(&[1, 0]), 8.0); + assert_eq!(c.get(&[1, 1]), 10.0); + assert_eq!(c.get(&[2, 0]), 12.0); + assert_eq!(c.get(&[2, 1]), 15.0); + } + + #[test] + fn test_vecmat() { + let mut v = Tensor::new(vec![2]); + set_matrix(&mut v, &[1.0, 2.0]); + let mut m = Tensor::new(vec![2, 2]); + set_matrix(&mut m, &[3.0, 4.0, 5.0, 6.0]); + + let mut r = Tensor::new(vec![2]); + einsum_binary("a,ab->b", &v, &m, &mut r).unwrap(); + + assert_eq!(r.get(&[0]), 13.0); + assert_eq!(r.get(&[1]), 16.0); + } + + #[test] + fn test_dot() { + let mut a = Tensor::new(vec![4]); + let mut b = Tensor::new(vec![4]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0]); + set_matrix(&mut b, &[5.0, 6.0, 7.0, 8.0]); + + let result: f32 = einsum_binary_scalar("i,i->", &a, &b).unwrap(); + assert_eq!(result, 70.0); + } + + #[test] + fn test_trace() { + let mut m = Tensor::new(vec![3, 3]); + set_matrix( + &mut m, + &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], + ); + + let result: f32 = einsum_unary_scalar("aa->", &m).unwrap(); + assert_eq!(result, 15.0); + } + + #[test] + fn test_frobenius2() { + let mut a = Tensor::new(vec![2, 2]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0]); + + let result: f32 = einsum_binary_scalar("ab,ab->", &a, &a).unwrap(); + assert_eq!(result, 30.0); + } + + #[test] + fn test_attention() { + let (b, h, q_len, k_len, dim) = (2, 2, 3, 2, 4); + let mut q = Tensor::new(vec![b, h, q_len, dim]); + let mut k = Tensor::new(vec![b, h, k_len, dim]); + + for bi in 0..b { + for hi in 0..h { + for qi in 0..q_len { + for di in 0..dim { + let v = + (bi + 1) as f32 * (hi + 1) as f32 * (qi + 1) as f32 + di as f32; + q.set(&[bi, hi, qi, di], v); + } + } + for ki in 0..k_len { + for di in 0..dim { + let v = + (bi + 1) as f32 * (hi + 1) as f32 * (ki + 1) as f32 * (di + 1) as f32; + k.set(&[bi, hi, ki, di], v); + } + } + } + } + + let mut out = Tensor::new(vec![b, h, q_len, k_len]); + 
einsum_binary("bhqd,bhkd->bhqk", &q, &k, &mut out).unwrap(); + + for bi in 0..b { + for hi in 0..h { + for qi in 0..q_len { + for ki in 0..k_len { + let mut expected = 0.0f32; + for di in 0..dim { + expected += q.get(&[bi, hi, qi, di]) * k.get(&[bi, hi, ki, di]); + } + let actual = out.get(&[bi, hi, qi, ki]); + assert!( + (actual - expected).abs() < 1e-3, + "mismatch at [{bi},{hi},{qi},{ki}]: got {actual}, expected {expected}" + ); + } + } + } + } + } + + #[test] + fn test_err_missing_arrow() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("ab,bc", &a, &b, &mut c).unwrap_err(), + InvalidSpec::MissingArrow + ); + } + + #[test] + fn test_err_invalid_index() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("aB,bc->ac", &a, &b, &mut c).unwrap_err(), + InvalidSpec::InvalidIndex { ch: 'B' } + ); + // Invalid char in output + assert_eq!( + einsum_binary("ab,bc->a1", &a, &b, &mut c).unwrap_err(), + InvalidSpec::InvalidIndex { ch: '1' } + ); + } + + #[test] + fn test_err_wrong_input_count() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + // Binary function but spec has 1 input + assert_eq!( + einsum_binary("ab->ab", &a, &b, &mut c).unwrap_err(), + InvalidSpec::WrongInputCount { expected: 2, got: 1 } + ); + // Unary function but spec has 2 inputs + assert_eq!( + einsum_unary("ab,bc->ac", &a, &mut c).unwrap_err(), + InvalidSpec::WrongInputCount { expected: 1, got: 2 } + ); + } + + #[test] + fn test_err_empty_input() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary(",bc->bc", &a, &b, &mut c).unwrap_err(), + InvalidSpec::EmptyInput { input: 0 } + ); + } + + #[test] + fn test_err_unbound_output_index() { + let a = Tensor::new(vec![2, 3]); + 
let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("ab,bc->az", &a, &b, &mut c).unwrap_err(), + InvalidSpec::UnboundOutputIndex { index: 'z' } + ); + } + + #[test] + fn test_err_input_ndim_mismatch() { + // a is 2D but spec says 3 indices + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("abc,bd->ad", &a, &b, &mut c).unwrap_err(), + InvalidSpec::InputNdimMismatch { input: 0, array_ndim: 2, spec_ndim: 3 } + ); + } + + #[test] + fn test_err_dimension_mismatch() { + // a is 2×3, b is 3×2, spec says first dims must match (a=2 vs a=3) + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + let mut c = Tensor::new(vec![2, 2]); + assert_eq!( + einsum_binary("ab,ac->bc", &a, &b, &mut c).unwrap_err(), + InvalidSpec::DimensionMismatch { index: 'a', expected: 2, got: 3 } + ); + } + + #[test] + fn test_err_output_ndim_mismatch() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + // Output is 1D but spec says 2D + let mut c = Tensor::new(vec![2]); + assert_eq!( + einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap_err(), + InvalidSpec::OutputNdimMismatch { array_ndim: 1, spec_ndim: 2 } + ); + } + + #[test] + fn test_err_output_dim_mismatch() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![3, 2]); + // Output should be 2×2 but we give 2×3 + let mut c = Tensor::new(vec![2, 3]); + assert_eq!( + einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap_err(), + InvalidSpec::OutputDimMismatch { axis: 1, expected: 2, got: 3 } + ); + } + + #[test] + fn test_err_non_empty_scalar_output() { + let a = Tensor::new(vec![2, 3]); + let b = Tensor::new(vec![2, 3]); + assert_eq!( + einsum_binary_scalar::("ab,ab->a", &a, &b).unwrap_err(), + InvalidSpec::NonEmptyScalarOutput + ); + assert_eq!( + einsum_unary_scalar::("ab->a", &a).unwrap_err(), + InvalidSpec::NonEmptyScalarOutput + ); + } + + #[test] + fn 
test_unary_row_sum() { + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + let mut out = Tensor::new(vec![2]); + einsum_unary("ab->a", &a, &mut out).unwrap(); + + assert_eq!(out.get(&[0]), 6.0); + assert_eq!(out.get(&[1]), 15.0); + } + + // --- einsum_ary tests --- + + #[test] + fn test_ary_matmul_as_binary() { + // einsum_ary with 2 inputs should match einsum_binary + let mut a = Tensor::new(vec![2, 3]); + let mut b = Tensor::new(vec![3, 2]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + set_matrix(&mut b, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]); + + let mut c = Tensor::new(vec![2, 2]); + einsum_ary("ab,bc->ac", &[&a, &b], &mut c).unwrap(); + + assert_eq!(c.get(&[0, 0]), 58.0); + assert_eq!(c.get(&[0, 1]), 64.0); + assert_eq!(c.get(&[1, 0]), 139.0); + assert_eq!(c.get(&[1, 1]), 154.0); + } + + #[test] + fn test_ary_transpose_as_unary() { + // einsum_ary with 1 input should match einsum_unary + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + let mut t = Tensor::new(vec![3, 2]); + einsum_ary("ab->ba", &[&a], &mut t).unwrap(); + + assert_eq!(t.get(&[0, 0]), 1.0); + assert_eq!(t.get(&[0, 1]), 4.0); + assert_eq!(t.get(&[1, 0]), 2.0); + assert_eq!(t.get(&[2, 1]), 6.0); + } + + #[test] + fn test_ary_three_input_chain() { + // A(2×3) × B(3×4) × C(4×2) -> D(2×2) + // spec: "ab,bc,cd->ad" contracts b and c + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let mut b = Tensor::new(vec![3, 4]); + set_matrix(&mut b, &[1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0]); + let mut c = Tensor::new(vec![4, 2]); + set_matrix(&mut c, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + + let mut d = Tensor::new(vec![2, 2]); + einsum_ary("ab,bc,cd->ad", &[&a, &b, &c], &mut d).unwrap(); + + // AB = [[1,2,3,0],[4,5,6,0]], ABC = AB×C = [[1*1+2*3+3*5, 1*2+2*4+3*6],[4*1+5*3+6*5, 4*2+5*4+6*6]] + // = [[22, 28],[49, 64]] + 
assert_eq!(d.get(&[0, 0]), 22.0); + assert_eq!(d.get(&[0, 1]), 28.0); + assert_eq!(d.get(&[1, 0]), 49.0); + assert_eq!(d.get(&[1, 1]), 64.0); + } + + #[test] + fn test_ary_outer_product_three() { + // No contraction: a(i) × b(j) × c(k) -> out(i,j,k) + let mut a = Tensor::new(vec![2]); + set_matrix(&mut a, &[2.0, 3.0]); + let mut b = Tensor::new(vec![2]); + set_matrix(&mut b, &[5.0, 7.0]); + let mut c = Tensor::new(vec![2]); + set_matrix(&mut c, &[11.0, 13.0]); + + let mut out = Tensor::new(vec![2, 2, 2]); + einsum_ary("a,b,c->abc", &[&a, &b, &c], &mut out).unwrap(); + + // out[i,j,k] = a[i]*b[j]*c[k] + assert_eq!(out.get(&[0, 0, 0]), 2.0 * 5.0 * 11.0); + assert_eq!(out.get(&[0, 0, 1]), 2.0 * 5.0 * 13.0); + assert_eq!(out.get(&[0, 1, 0]), 2.0 * 7.0 * 11.0); + assert_eq!(out.get(&[1, 1, 1]), 3.0 * 7.0 * 13.0); + } + + #[test] + fn test_ary_scalar_dot() { + // Scalar output via 0-dim tensor: "i,i->" (dot product) + let mut a = Tensor::new(vec![4]); + let mut b = Tensor::new(vec![4]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0]); + set_matrix(&mut b, &[5.0, 6.0, 7.0, 8.0]); + + let mut out = Tensor::new(vec![]); // 0-dim + einsum_ary("i,i->", &[&a, &b], &mut out).unwrap(); + + assert_eq!(out.get(&[]), 70.0); + } + + #[test] + fn test_ary_scalar_trace() { + // Scalar output via 0-dim tensor: "aa->" (trace) + let mut m = Tensor::new(vec![3, 3]); + set_matrix(&mut m, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); + + let mut out = Tensor::new(vec![]); + einsum_ary("aa->", &[&m], &mut out).unwrap(); + + assert_eq!(out.get(&[]), 15.0); + } + + #[test] + fn test_ary_scalar_three_input() { + // "i,i,i->" — element-wise product summed to scalar + let mut a = Tensor::new(vec![3]); + let mut b = Tensor::new(vec![3]); + let mut c = Tensor::new(vec![3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0]); + set_matrix(&mut b, &[4.0, 5.0, 6.0]); + set_matrix(&mut c, &[7.0, 8.0, 9.0]); + + let mut out = Tensor::new(vec![]); + einsum_ary("i,i,i->", &[&a, &b, &c], &mut out).unwrap(); + + // 
1*4*7 + 2*5*8 + 3*6*9 = 28 + 80 + 162 = 270 + assert_eq!(out.get(&[]), 270.0); + } + + #[test] + fn test_ary_row_sum_unary() { + let mut a = Tensor::new(vec![2, 3]); + set_matrix(&mut a, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + + let mut out = Tensor::new(vec![2]); + einsum_ary("ab->a", &[&a], &mut out).unwrap(); + + assert_eq!(out.get(&[0]), 6.0); + assert_eq!(out.get(&[1]), 15.0); + } +} + +#[cfg(test)] +#[cfg(feature = "ndarray")] +mod ndarray_tests { + use super::*; + use ndarray::{ArrayD, IxDyn}; + + fn arr1(data: &[f64]) -> ArrayD { + ArrayD::from_shape_vec(IxDyn(&[data.len()]), data.to_vec()).unwrap() + } + + fn arr2(rows: usize, cols: usize, data: &[f64]) -> ArrayD { + ArrayD::from_shape_vec(IxDyn(&[rows, cols]), data.to_vec()).unwrap() + } + + fn zeros(shape: &[usize]) -> ArrayD { + ArrayD::zeros(IxDyn(shape)) + } + + #[test] + fn test_ndindex_trait() { + let a = arr2(2, 2, &[1.0, 2.0, 3.0, 4.0]); + assert_eq!(NDIndex::ndim(&a), 2); + assert_eq!(NDIndex::dim(&a, 0), 2); + assert_eq!(NDIndex::dim(&a, 1), 2); + assert_eq!(NDIndex::get(&a, &[0, 1]), 2.0); + assert_eq!(NDIndex::get(&a, &[1, 0]), 3.0); + + let mut b = a.clone(); + NDIndex::set(&mut b, &[0, 0], 99.0); + assert_eq!(NDIndex::get(&b, &[0, 0]), 99.0); + } + + #[test] + fn test_matmul() { + let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let b = arr2(3, 2, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]); + let mut c = zeros(&[2, 2]); + + einsum_binary("ab,bc->ac", &a, &b, &mut c).unwrap(); + + assert_eq!(c[IxDyn(&[0, 0])], 58.0); + assert_eq!(c[IxDyn(&[0, 1])], 64.0); + assert_eq!(c[IxDyn(&[1, 0])], 139.0); + assert_eq!(c[IxDyn(&[1, 1])], 154.0); + } + + #[test] + fn test_transpose() { + let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let mut t = zeros(&[3, 2]); + + einsum_unary("ab->ba", &a, &mut t).unwrap(); + + assert_eq!(t[IxDyn(&[0, 0])], 1.0); + assert_eq!(t[IxDyn(&[0, 1])], 4.0); + assert_eq!(t[IxDyn(&[1, 0])], 2.0); + assert_eq!(t[IxDyn(&[2, 1])], 6.0); + } + + #[test] + fn test_dot() { + let a 
= arr1(&[1.0, 2.0, 3.0, 4.0]); + let b = arr1(&[5.0, 6.0, 7.0, 8.0]); + + let result: f64 = einsum_binary_scalar("i,i->", &a, &b).unwrap(); + assert_eq!(result, 70.0); + } + + #[test] + fn test_trace() { + let m = arr2(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); + let result: f64 = einsum_unary_scalar("aa->", &m).unwrap(); + assert_eq!(result, 15.0); + } + + #[test] + fn test_outer_product() { + let a = arr1(&[1.0, 2.0, 3.0]); + let b = arr1(&[4.0, 5.0]); + let mut c = zeros(&[3, 2]); + + einsum_binary("a,b->ab", &a, &b, &mut c).unwrap(); + + assert_eq!(c[IxDyn(&[0, 0])], 4.0); + assert_eq!(c[IxDyn(&[1, 1])], 10.0); + assert_eq!(c[IxDyn(&[2, 0])], 12.0); + } + + #[test] + fn test_row_sum() { + let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + let mut out = zeros(&[2]); + + einsum_unary("ab->a", &a, &mut out).unwrap(); + + assert_eq!(out[IxDyn(&[0])], 6.0); + assert_eq!(out[IxDyn(&[1])], 15.0); + } +} diff --git a/einsum-dyn/src/sparse.rs b/einsum-dyn/src/sparse.rs new file mode 100644 index 0000000..da2ed2c --- /dev/null +++ b/einsum-dyn/src/sparse.rs @@ -0,0 +1,1494 @@ +//! Sparse einsum: multiple approaches for sparse 2D matrix einsum. +//! +//! # Approaches +//! +//! ## 1. Baseline (`einsum_binary` with `get_opt`) +//! The existing dense-loop einsum from `lib.rs`. Iterates the full n×n×n index +//! space for matmul, but skips zero products via `get_opt()`. Complexity: O(n³). +//! No new code needed — use `einsum_binary` directly with any `NDIndex` impl. +//! +//! ## 2. Sparse-driven with dense accumulator (`einsum_sparse_driven`) +//! Equivalent to matmul (C = A × B). Only supports the `"ab,bc->ac"` pattern. +//! Iterates only non-zero entries of inputs using the `Sparse2D` trait: +//! loops A's rows, for each NZ (k, a_val) in row i, iterates B's row k +//! sparsely. Dense `Vec` accumulator per output row. +//! Complexity: O(Σ_i Σ_{k∈row(i)} row_nnz_B(k)) = O(flops). +//! +//! ## 3. Custom VM (`einsum_vm`) +//! 
Compiles the einsum spec into a tree of VM operations. A greedy scheduler
+//! decides which loops iterate sparsely (via `Sparse2D::row_entry`) vs densely.
+//! The VM interpreter walks the tree recursively. Same asymptotic complexity
+//! as approach 2, but handles arbitrary 2D specs without hardcoding patterns.
+//!
+//! ## 4. Sparse-driven with hash accumulator (`einsum_sparse_hash`)
+//! Same sparse iteration as approach 2, but accumulates into a `HashMap`
+//! per output row instead of a dense vector. Better when the output is very sparse
+//! (each row has few non-zeros), since it avoids allocating/clearing an n-wide
+//! accumulator. Worse for dense output due to hashing overhead.
+
+use std::collections::HashMap;
+use std::ops::{Add, AddAssign, Mul};
+
+use crate::{NDIndex, InvalidSpec};
+
+// ═══════════════════════════════════════════════════════════════════════════
+// Sparse2D trait
+// ═══════════════════════════════════════════════════════════════════════════
+
+/// Extension of `NDIndex` for sparse 2D arrays (matrices).
+///
+/// Provides structured row-wise access to non-zero entries, enabling
+/// sparse-driven einsum execution that skips entire zero regions.
+pub trait Sparse2D<T>: NDIndex<T> {
+    /// Total number of structural non-zeros.
+    fn nnz(&self) -> usize;
+
+    /// Number of rows (axis 0 dimension).
+    fn n_rows(&self) -> usize;
+
+    /// Number of non-zero entries in the given row.
+    fn row_nnz(&self, row: usize) -> usize;
+
+    /// Get the `idx`-th non-zero entry in the given row as `(col, value)`.
+    /// `idx` must be in `0..row_nnz(row)`.
+    fn row_entry(&self, row: usize, idx: usize) -> (usize, T);
+}
+
+// ═══════════════════════════════════════════════════════════════════════════
+// Approach 2: Sparse-driven with dense accumulator
+// ═══════════════════════════════════════════════════════════════════════════
+
+/// Sparse-driven binary einsum with dense row accumulator.
+///
+/// Only supports the standard matmul pattern: A's axis-1 index must equal
+/// B's axis-0 index (the contracted/inner dimension). Any letter names work
+/// (`"ab,bc->ac"`, `"xy,yz->xz"`, etc.) but the structure must be matmul.
+/// For other specs, use `einsum_vm_oneshot` or `einsum_sparse_hash`.
+///
+/// Both inputs must implement `Sparse2D`. The output must be a writable
+/// `NDIndex` (typically dense) with the correct shape.
+pub fn einsum_sparse_driven<T, S, Out>(
+    spec_str: &str,
+    a: &S,
+    b: &S,
+    out: &mut Out,
+) -> Result<(), InvalidSpec>
+where
+    T: Default + Copy + PartialEq + AddAssign + Mul<Output = T>,
+    S: Sparse2D<T>,
+    Out: NDIndex<T>,
+{
+    let spec = crate::parse_spec(spec_str, 2)?;
+    let dims = crate::validate_dims::<T, _>(&spec, &[a, b])?;
+    crate::validate_output::<T, _>(&spec, &dims, out)?;
+
+    let a_slots = &spec.inputs[0];
+    let b_slots = &spec.inputs[1];
+    let out_slots = &spec.output();
+
+    assert_eq!(a_slots.len(), 2, "sparse-driven requires 2D inputs");
+    assert_eq!(b_slots.len(), 2, "sparse-driven requires 2D inputs");
+
+    let (a0, a1) = (a_slots[0], a_slots[1]);
+    let (_b0, b1) = (b_slots[0], b_slots[1]);
+
+    // Verify A's col index == B's row index.
+    // This restricts us to the matmul pattern (C = A × B) — the only
+    // structural arrangement where both inputs can be walked row-wise
+    // through CSR without transposing.
+    if a1 != b_slots[0] {
+        panic!(
+            "einsum_sparse_driven only supports matmul pattern \
+             (A's axis-1 == B's axis-0), got spec '{spec_str}'"
+        );
+    }
+
+    // Map output slot positions: which output axis gets which slot value
+    let out_pos_a0 = out_slots.iter().position(|&s| s == a0);
+    let out_pos_b1 = out_slots.iter().position(|&s| s == b1);
+    let n_out_dims = out_slots.len();
+
+    // Dense accumulator sized to the full output column range.
+    // Track which columns were touched to avoid clearing the entire vector.
+ let n_cols_out = dims[b1 as usize]; + let mut acc = vec![T::default(); n_cols_out]; + let mut nz_cols: Vec = Vec::new(); + let mut out_ix = vec![0usize; n_out_dims]; + + let n_rows_a = a.n_rows(); + + for i in 0..n_rows_a { + // Scatter: iterate A's row i, then B's matching rows + for ai in 0..a.row_nnz(i) { + let (k, a_val) = a.row_entry(i, ai); + for bi in 0..b.row_nnz(k) { + let (j, b_val) = b.row_entry(k, bi); + if acc[j] == T::default() { + nz_cols.push(j); + } + acc[j] += a_val * b_val; + } + } + + // Write touched columns to output, then clear them + if let Some(p) = out_pos_a0 { + out_ix[p] = i; + } + for &j in &nz_cols { + if let Some(p) = out_pos_b1 { + out_ix[p] = j; + } + out.set(&out_ix[..n_out_dims], acc[j]); + acc[j] = T::default(); + } + nz_cols.clear(); + } + + Ok(()) +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Approach 3: Custom VM +// ═══════════════════════════════════════════════════════════════════════════ + +/// VM operation — flat bytecode. `DenseLoop` and `SparseRowLoop` mark loop +/// starts; `LoopEnd` marks the end of the innermost enclosing loop. +/// Each loop-start stores the pc of its matching `LoopEnd` (and vice versa) +/// so the interpreter never needs to scan for matching brackets. +#[derive(Debug)] +pub enum VmOp { + /// Dense loop: iterate `slot` from 0 to `dim-1`. + /// `end_pc` points one past the matching `LoopEnd`. + /// When `fused`, the body is a single `MulAcc` — the loop runs the + /// multiply-accumulate inline without recursing into `exec_at`. + DenseLoop { slot: u8, dim: usize, end_pc: usize, fused: bool }, + /// Sparse row iteration: for each non-zero in `input[input_idx]` + /// at the row given by `vals[row_slot]`, set `vals[col_slot]` to the + /// column index. + /// `end_pc` points one past the matching `LoopEnd`. + /// When `fused`, the body is a single `MulAcc` — the loop runs the + /// multiply-accumulate inline without recursing into `exec_at`. 
+ SparseRowLoop { + input_idx: usize, + row_slot: u8, + col_slot: u8, + end_pc: usize, + fused: bool, + }, + /// End of the enclosing loop. `start_pc` points back to the loop-start. + LoopEnd { start_pc: usize }, + /// Initialize a dense accumulator of size `dim`, indexed by `acc_slot`. + /// Placed just before the loops that should accumulate. + AccStart { acc_slot: u8, acc_out_pos: u8, dim: usize }, + /// Flush the dense accumulator to the output (scatter-gather: only write + /// and clear touched entries), then reset. + AccFlush, + /// Read input values at current slot positions, multiply, and + /// accumulate into the output (or into the active accumulator). + MulAcc, +} + +/// Precompiled VM program for sparse einsum. +pub struct VmProgram { + /// Flat bytecode. + pub ops: Vec, + /// Input slot patterns: input_patterns[i] = list of slot indices for input i. + input_patterns: Vec>, + /// Output slot patterns: one per output in the spec. + output_patterns: Vec>, + /// For each input: `Some(loop_index)` of the SparseRowLoop that fully covers + /// it (both axes iterated by that one loop), or `None` if `get_opt` is needed. + /// When `Some`, `MulAcc` reads the cached sparse value instead of calling `get_opt`. + sparse_value_source: Vec>, +} + +/// Compile an einsum spec into a VM program. +/// +/// Accepts inputs of any dimensionality. The compiler uses a greedy strategy: +/// 1. For each 2D sparse input (`is_sparse_2d()`), its axis-1 slot can be +/// iterated sparsely when its axis-0 slot is already fixed. +/// 2. Non-sparse or higher-dimensional inputs use dense loops for all slots. +/// 3. When choosing dense loops, prefer axis-0 slots of sparse inputs first +/// (so their axis-1 can be sparse in the next level). +/// +/// The `inputs` slice is used to check dimensionality and `is_sparse_2d()`. +/// It is not retained — execution uses separately provided inputs. 
+pub fn compile_vm( + spec_str: &str, + inputs: &[&dyn NDIndex], +) -> Result { + let n_inputs = inputs.len(); + let spec = crate::parse_spec(spec_str, n_inputs)?; + + // Validate dimensions + let mut dims = [0usize; 26]; + let mut dim_set = [false; 26]; + for (pi, inp_spec) in spec.inputs.iter().enumerate() { + let arr = inputs[pi]; + if arr.ndim() != inp_spec.len() { + return Err(InvalidSpec::InputNdimMismatch { + input: pi, + array_ndim: arr.ndim(), + spec_ndim: inp_spec.len(), + }); + } + for (pos, &s) in inp_spec.iter().enumerate() { + let si = s as usize; + let d = arr.dim(pos); + if dim_set[si] { + if dims[si] != d { + return Err(InvalidSpec::DimensionMismatch { + index: (s + b'a') as char, + expected: dims[si], + got: d, + }); + } + } else { + dims[si] = d; + dim_set[si] = true; + } + } + } + + // Identify which inputs are sparse-2D eligible for SparseRowLoop. + // An input qualifies when it has exactly 2 spec indices AND is_sparse_2d(). + // For qualifying inputs, record (axis_0_slot, axis_1_slot). + let sparse_axes: Vec> = spec + .inputs + .iter() + .zip(inputs.iter()) + .map(|(inp_spec, arr)| { + if inp_spec.len() == 2 && arr.is_sparse_2d() { + Some((inp_spec[0], inp_spec[1])) + } else { + None + } + }) + .collect(); + + // Collect all unique slots in order of first appearance + let mut all_slots = Vec::new(); + let mut seen = [false; 26]; + for inp in &spec.inputs { + for &s in inp { + if !seen[s as usize] { + seen[s as usize] = true; + all_slots.push(s); + } + } + } + for out in &spec.outputs { + for &s in out { + if !seen[s as usize] { + seen[s as usize] = true; + all_slots.push(s); + } + } + } + + // Greedy scheduler: decide loop order, then emit flat bytecode. + let mut fixed = [false; 26]; + let mut loop_order: Vec = Vec::new(); + let mut n_fixed = 0usize; + + while n_fixed < all_slots.len() { + let mut found_sparse = false; + + // Try to find a slot that can be sparse-iterated: + // s is axis-1 of a sparse-2D input whose axis-0 is already fixed. 
+ for &s in &all_slots { + if fixed[s as usize] { + continue; + } + for (idx, axes) in sparse_axes.iter().enumerate() { + if let Some((ax0, ax1)) = axes { + if *ax1 == s && fixed[*ax0 as usize] { + loop_order.push(VmOp::SparseRowLoop { + input_idx: idx, + row_slot: *ax0, + col_slot: s, + end_pc: 0, // patched below + fused: false, // patched below + }); + fixed[s as usize] = true; + n_fixed += 1; + found_sparse = true; + break; + } + } + } + if found_sparse { + break; + } + } + + if !found_sparse { + // Dense loop. Prefer axis-0 slots of sparse inputs (enables future sparse). + let mut best = None; + for &s in &all_slots { + if fixed[s as usize] { + continue; + } + let is_sparse_ax0 = sparse_axes + .iter() + .any(|axes| matches!(axes, Some((ax0, _)) if *ax0 == s)); + if is_sparse_ax0 || best.is_none() { + best = Some(s); + if is_sparse_ax0 { + break; + } + } + } + let s = best.unwrap(); + loop_order.push(VmOp::DenseLoop { + slot: s, + dim: dims[s as usize], + end_pc: 0, // patched below + fused: false, // patched below + }); + fixed[s as usize] = true; + n_fixed += 1; + } + } + + // For each input, check if a single SparseRowLoop covers both its axes. + // If so, MulAcc can use the cached sparse value instead of get_opt(). + let sparse_value_source: Vec> = spec + .inputs + .iter() + .enumerate() + .map(|(inp_idx, inp_spec)| { + if inp_spec.len() != 2 { + return None; + } + // Find a SparseRowLoop in loop_order that iterates this input + for (loop_idx, op) in loop_order.iter().enumerate() { + if let VmOp::SparseRowLoop { input_idx, row_slot, col_slot, .. } = op { + if *input_idx == inp_idx { + // Check this loop covers both axes of this input + if sparse_axes[inp_idx] == Some((*row_slot, *col_slot)) { + return Some(loop_idx); + } + } + } + } + None + }) + .collect(); + + // Accumulator: if the innermost loop's slot appears in the output pattern + // and there's at least one other output slot, we can use a dense accumulator. 
+ // We emit AccStart just inside the outermost output-slot loop and AccFlush + // just before each iteration's end of that same loop. + // + // For "ab,bc->ac": output=[a,c], innermost loop is c. + // acc_slot=c, flush_loop_idx=0 (the 'a' loop). + // Bytecode: FOR a | AccStart | FOR b(sparse) | FOR c(sparse) | MulAcc | + // LoopEnd(c) | LoopEnd(b) | AccFlush | LoopEnd(a) + // Accumulator optimization: only for single-output specs. + // With multi-output, different outputs may have different index layouts, + // so the dense accumulator trick doesn't generalise cleanly. + let all_output_slots = spec.all_output_slots(); + let mut acc_info: Option<(u8, u8, usize, usize)> = None; // (acc_slot, acc_out_pos, acc_dim, flush_loop_idx) + if spec.outputs.len() == 1 { + if let Some(last_op) = loop_order.last() { + let inner_slot = match last_op { + VmOp::DenseLoop { slot, .. } => *slot, + VmOp::SparseRowLoop { col_slot, .. } => *col_slot, + _ => unreachable!(), + }; + if let Some(pos) = spec.output().iter().position(|&s| s == inner_slot) { + for (i, op) in loop_order.iter().enumerate().rev() { + let s = match op { + VmOp::DenseLoop { slot, .. } => *slot, + VmOp::SparseRowLoop { col_slot, .. } => *col_slot, + _ => unreachable!(), + }; + if s != inner_slot && all_output_slots.contains(&s) { + acc_info = Some((inner_slot, pos as u8, dims[inner_slot as usize], i)); + break; + } + } + } + } + } + + // Emit flat bytecode. + // Layout: loops... MulAcc LoopEnd... with AccStart/AccFlush injected. 
+ let n_loops = loop_order.len(); + let mut ops: Vec = Vec::with_capacity(n_loops * 2 + 4); + + // Emit loop-starts, injecting AccStart after the flush loop + for (i, op) in loop_order.into_iter().enumerate() { + ops.push(op); + if let Some((acc_slot, acc_out_pos, dim, flush_idx)) = acc_info { + if i == flush_idx { + ops.push(VmOp::AccStart { acc_slot, acc_out_pos, dim }); + } + } + } + ops.push(VmOp::MulAcc); + + // Emit LoopEnds (innermost first), injecting AccFlush before the flush loop's LoopEnd + for i in 0..n_loops { + let loop_idx = n_loops - 1 - i; + if let Some((_, _, _, flush_idx)) = acc_info { + if loop_idx == flush_idx { + ops.push(VmOp::AccFlush); + } + } + // Find this loop's start_pc by scanning back for the loop_idx-th loop-start + // The loop-starts are at varying positions due to AccStart injection. + // Track them: loop_idx corresponds to the loop_idx-th loop-start in ops. + let start_pc = ops.iter().enumerate() + .filter(|(_, op)| matches!(op, VmOp::DenseLoop{..} | VmOp::SparseRowLoop{..})) + .nth(loop_idx) + .unwrap().0; + let end_pc = ops.len() + 1; // one past the LoopEnd we're about to push + match &mut ops[start_pc] { + VmOp::DenseLoop { end_pc: ep, .. } => *ep = end_pc, + VmOp::SparseRowLoop { end_pc: ep, .. } => *ep = end_pc, + _ => unreachable!(), + } + ops.push(VmOp::LoopEnd { start_pc }); + } + + // Fusion: if a loop's body is just MulAcc (followed by LoopEnd), mark it fused. + // The loop will inline the multiply-accumulate without recursing. + for pc in 0..ops.len() { + let is_loop = matches!(&ops[pc], VmOp::DenseLoop{..} | VmOp::SparseRowLoop{..}); + if is_loop && matches!(&ops[pc + 1], VmOp::MulAcc) { + match &mut ops[pc] { + VmOp::DenseLoop { fused, .. } => *fused = true, + VmOp::SparseRowLoop { fused, .. 
} => *fused = true, + _ => unreachable!(), + } + } + } + + Ok(VmProgram { + ops, + input_patterns: spec.inputs.clone(), + output_patterns: spec.outputs.clone(), + sparse_value_source, + }) +} + +/// Accumulator state for the VM interpreter. +struct AccState { + acc: Vec, + nz_cols: Vec, + acc_slot: u8, + acc_out_pos: u8, +} + +impl VmProgram { + /// Execute this compiled VM program. + /// + /// Inputs can be any mix of dense and sparse `NDIndex` implementations. + /// `SparseRowLoop` ops call `sparse_row_nnz` / `sparse_row_entry` on the + /// relevant input — the compiler only emits these for inputs where + /// `is_sparse_2d()` returned true. + /// + /// Optimizations over naive interpretation: + /// - Sparse values from `row_entry()` are cached and reused in `MulAcc`, + /// avoiding redundant `get_opt()` binary searches. + /// - `AccStart`/`AccFlush` ops enable scatter-gather accumulation for the + /// innermost output dimension. + pub fn exec( + &self, + inputs: &[&dyn NDIndex], + outs: &mut [&mut dyn NDIndex], + ) where + T: Default + Copy + Add + AddAssign + Mul + PartialEq, + { + let mut vals = [0usize; 26]; + let mut buf = [0usize; 26]; + let mut sparse_vals: Vec = vec![T::default(); inputs.len()]; + let mut acc_state: Option> = None; + self.exec_at(0, &mut vals, &mut buf, &mut sparse_vals, &mut acc_state, inputs, outs); + } + + /// Execute bytecode starting at `pc`. Returns the pc after the + /// matching `LoopEnd` (or end of program). 
+ fn exec_at( + &self, + mut pc: usize, + vals: &mut [usize; 26], + buf: &mut [usize; 26], + sparse_vals: &mut [T], + acc_state: &mut Option>, + inputs: &[&dyn NDIndex], + outs: &mut [&mut dyn NDIndex], + ) -> usize + where + T: Default + Copy + Add + AddAssign + Mul + PartialEq, + { + let ops = &self.ops; + while pc < ops.len() { + match &ops[pc] { + VmOp::DenseLoop { slot, dim, end_pc, fused } => { + let s = *slot as usize; + if *fused { + for v in 0..*dim { + vals[s] = v; + self.mul_acc(vals, buf, sparse_vals, acc_state, inputs, outs); + } + } else { + for v in 0..*dim { + vals[s] = v; + self.exec_at(pc + 1, vals, buf, sparse_vals, acc_state, inputs, outs); + } + } + pc = *end_pc; + } + VmOp::SparseRowLoop { + input_idx, + row_slot, + col_slot, + end_pc, + fused, + } => { + let row = vals[*row_slot as usize]; + let cs = *col_slot as usize; + let input = inputs[*input_idx]; + let nnz = input.sparse_row_nnz(row); + if *fused { + for ei in 0..nnz { + let (col, val) = input.sparse_row_entry(row, ei); + vals[cs] = col; + sparse_vals[*input_idx] = val; + self.mul_acc(vals, buf, sparse_vals, acc_state, inputs, outs); + } + } else { + for ei in 0..nnz { + let (col, val) = input.sparse_row_entry(row, ei); + vals[cs] = col; + sparse_vals[*input_idx] = val; + self.exec_at(pc + 1, vals, buf, sparse_vals, acc_state, inputs, outs); + } + } + pc = *end_pc; + } + VmOp::LoopEnd { .. 
} => { + return pc + 1; + } + VmOp::AccStart { acc_slot, acc_out_pos, dim } => { + *acc_state = Some(AccState { + acc: vec![T::default(); *dim], + nz_cols: Vec::new(), + acc_slot: *acc_slot, + acc_out_pos: *acc_out_pos, + }); + pc += 1; + } + VmOp::AccFlush => { + // AccFlush only emitted for single-output specs + if let Some(st) = acc_state { + let pattern = &self.output_patterns[0]; + let len = pattern.len(); + for &j in st.nz_cols.iter() { + for (i, &s) in pattern.iter().enumerate() { + buf[i] = vals[s as usize]; + } + buf[st.acc_out_pos as usize] = j; + outs[0].set(&buf[..len], st.acc[j]); + st.acc[j] = T::default(); + } + st.nz_cols.clear(); + } + pc += 1; + } + VmOp::MulAcc => { + self.mul_acc(vals, buf, sparse_vals, acc_state, inputs, outs); + pc += 1; + } + } + } + pc + } + + /// Compute one multiply-accumulate step: read inputs, multiply, write to + /// accumulator or output(s). + #[inline] + fn mul_acc( + &self, + vals: &[usize; 26], + buf: &mut [usize; 26], + sparse_vals: &[T], + acc_state: &mut Option>, + inputs: &[&dyn NDIndex], + outs: &mut [&mut dyn NDIndex], + ) where + T: Default + Copy + Add + AddAssign + Mul + PartialEq, + { + let mut product = None::; + for (i, pattern) in self.input_patterns.iter().enumerate() { + let v = if self.sparse_value_source[i].is_some() { + Some(sparse_vals[i]) + } else { + let len = pattern.len(); + for (p, &s) in pattern.iter().enumerate() { + buf[p] = vals[s as usize]; + } + inputs[i].get_opt(&buf[..len]) + }; + if let Some(v) = v { + product = Some(match product { + Some(p) => p * v, + None => v, + }); + } else { + product = None; + break; + } + } + if let Some(p) = product { + if let Some(st) = acc_state { + // Single-output accumulator path + let idx = vals[st.acc_slot as usize]; + if st.acc[idx] == T::default() { + st.nz_cols.push(idx); + } + st.acc[idx] += p; + } else { + // Write to all outputs + for (oi, pattern) in self.output_patterns.iter().enumerate() { + let len = pattern.len(); + for (i, &s) in 
pattern.iter().enumerate() { + buf[i] = vals[s as usize]; + } + let cur = outs[oi].get(&buf[..len]); + outs[oi].set(&buf[..len], cur + p); + } + } + } + } +} + +/// Convenience: compile and execute in one call. +/// +/// Accepts any number of inputs and outputs of any dimensionality. +/// All inputs must be the same concrete type; for mixed types use +/// [`einsum_vm_oneshot_dyn`]. +/// +/// Multiple outputs are supported: `"ab,bc->ac,ca"` writes to two output +/// tensors simultaneously from a single loop nest. +pub fn einsum_vm_oneshot( + spec_str: &str, + inputs: &[&In], + outs: &mut [&mut Out], +) -> Result<(), InvalidSpec> +where + T: Copy + Default + Add + AddAssign + Mul + PartialEq, + In: NDIndex, + Out: NDIndex, +{ + let dyn_inputs: Vec<&dyn NDIndex> = inputs.iter().map(|&x| x as &dyn NDIndex).collect(); + let mut dyn_outs: Vec<&mut dyn NDIndex> = outs.iter_mut().map(|o| { + let r: &mut Out = *o; + r as &mut dyn NDIndex + }).collect(); + einsum_vm_oneshot_dyn(spec_str, &dyn_inputs, &mut dyn_outs) +} + +/// Like [`einsum_vm_oneshot`] but accepts trait-object inputs, allowing +/// mixed concrete types (e.g. one sparse and one dense input). +pub fn einsum_vm_oneshot_dyn( + spec_str: &str, + inputs: &[&dyn NDIndex], + outs: &mut [&mut dyn NDIndex], +) -> Result<(), InvalidSpec> +where + T: Copy + Default + Add + AddAssign + Mul + PartialEq, +{ + let program = compile_vm(spec_str, inputs)?; + program.exec(inputs, outs); + Ok(()) +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Approach 4: Sparse-driven with hash accumulator +// ═══════════════════════════════════════════════════════════════════════════ + +/// Sparse-driven binary einsum with hash accumulator per output row. +/// +/// Equivalent to matmul (C = A × B) — same restriction as +/// `einsum_sparse_driven`: only supports the `"ab,bc->ac"` pattern. +/// +/// Uses a `HashMap` per row instead of a dense `Vec`. 
+/// Better when the output is very sparse; worse for dense output.
+pub fn einsum_sparse_hash<T, S, Out>(
+    spec_str: &str,
+    a: &S,
+    b: &S,
+    out: &mut Out,
+) -> Result<(), InvalidSpec>
+where
+    T: Default + Copy + AddAssign + Mul<Output = T> + PartialEq,
+    S: Sparse2D<T>,
+    Out: NDIndex<T>,
+{
+    let spec = crate::parse_spec(spec_str, 2)?;
+    let dims = crate::validate_dims::<T, _>(&spec, &[a, b])?;
+    crate::validate_output::<T, _>(&spec, &dims, out)?;
+
+    let a_slots = &spec.inputs[0];
+    let b_slots = &spec.inputs[1];
+    let out_slots = &spec.output();
+
+    assert_eq!(a_slots.len(), 2, "sparse-hash requires 2D inputs");
+    assert_eq!(b_slots.len(), 2, "sparse-hash requires 2D inputs");
+
+    let (a0, a1) = (a_slots[0], a_slots[1]);
+    let b1 = b_slots[1];
+
+    if a1 != b_slots[0] {
+        panic!(
+            "einsum_sparse_hash requires A's axis-1 == B's axis-0, \
+             got spec '{spec_str}'"
+        );
+    }
+
+    let out_pos_a0 = out_slots.iter().position(|&s| s == a0);
+    let out_pos_b1 = out_slots.iter().position(|&s| s == b1);
+    let n_out_dims = out_slots.len();
+
+    let n_rows_a = a.n_rows();
+    let mut acc: HashMap<usize, T> = HashMap::new();
+    let mut out_ix = vec![0usize; n_out_dims];
+
+    for i in 0..n_rows_a {
+        acc.clear();
+
+        for ai in 0..a.row_nnz(i) {
+            let (k, a_val) = a.row_entry(i, ai);
+            for bi in 0..b.row_nnz(k) {
+                let (j, b_val) = b.row_entry(k, bi);
+                *acc.entry(j).or_insert_with(T::default) += a_val * b_val;
+            }
+        }
+
+        if let Some(p) = out_pos_a0 {
+            out_ix[p] = i;
+        }
+        for (&j, &v) in &acc {
+            if let Some(p) = out_pos_b1 {
+                out_ix[p] = j;
+            }
+            out.set(&out_ix[..n_out_dims], v);
+        }
+    }
+
+    Ok(())
+}
+
+// ═══════════════════════════════════════════════════════════════════════════
+// Display for VmOp (debugging / documentation)
+// ═══════════════════════════════════════════════════════════════════════════
+
+impl std::fmt::Display for VmProgram {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        writeln!(f, "VM Program:")?;
+        writeln!(
+            f,
+            " inputs: {:?}",
+            self.input_patterns
+                .iter()
.map(|p| p.iter().map(|&s| (s + b'a') as char).collect::()) + .collect::>() + )?; + writeln!( + f, + " outputs: {:?}", + self.output_patterns + .iter() + .map(|p| p.iter().map(|&s| (s + b'a') as char).collect::()) + .collect::>() + )?; + writeln!(f, " plan:")?; + let mut indent = 2usize; + for op in &self.ops { + let pad = " ".repeat(indent); + match op { + VmOp::DenseLoop { slot, dim, fused, .. } => { + let ch = (slot + b'a') as char; + let tag = if *fused { " [FUSED]" } else { "" }; + writeln!(f, "{pad}FOR {ch} IN 0..{dim}{tag}")?; + indent += 1; + } + VmOp::SparseRowLoop { + input_idx, + row_slot, + col_slot, + fused, + .. + } => { + let row_ch = (row_slot + b'a') as char; + let col_ch = (col_slot + b'a') as char; + let tag = if *fused { " [SPARSE,FUSED]" } else { " [SPARSE]" }; + writeln!( + f, + "{pad}FOR ({col_ch}, val) IN input[{input_idx}].row({row_ch}){tag}" + )?; + indent += 1; + } + VmOp::LoopEnd { .. } => { + indent -= 1; + } + VmOp::AccStart { acc_slot, dim, .. } => { + let ch = (acc_slot + b'a') as char; + writeln!(f, "{pad}ACC_START {ch}[0..{dim}]")?; + } + VmOp::AccFlush => { + writeln!(f, "{pad}ACC_FLUSH → output")?; + } + VmOp::MulAcc => { + writeln!(f, "{pad}MUL_ACC → acc")?; + } + } + } + Ok(()) + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Tests +// ═══════════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + use crate::einsum_binary; + + /// Minimal dense tensor for output. 
+ struct DenseMat { + data: Vec, + rows: usize, + cols: usize, + } + + impl DenseMat { + fn new(rows: usize, cols: usize) -> Self { + Self { + data: vec![0; rows * cols], + rows, + cols, + } + } + } + + impl NDIndex for DenseMat { + fn ndim(&self) -> usize { + 2 + } + fn dim(&self, axis: usize) -> usize { + if axis == 0 { + self.rows + } else { + self.cols + } + } + fn get(&self, ix: &[usize]) -> u32 { + self.data[ix[0] * self.cols + ix[1]] + } + fn set(&mut self, ix: &[usize], v: u32) { + self.data[ix[0] * self.cols + ix[1]] = v; + } + } + + /// Minimal sparse matrix for testing (COO-based, implements Sparse2D). + struct SparseMat { + n: usize, + /// Sorted by (row, col). Each row's entries are contiguous. + row_ptr: Vec, + col_idx: Vec, + values: Vec, + } + + impl SparseMat { + /// Build from (row, col, val) triplets. n = matrix dimension. + fn from_triplets(n: usize, trips: &[(usize, usize, u32)]) -> Self { + let mut sorted = trips.to_vec(); + sorted.sort_by_key(|&(r, c, _)| (r, c)); + + let mut row_ptr = vec![0; n + 1]; + let mut col_idx = Vec::new(); + let mut values = Vec::new(); + + for &(r, c, v) in &sorted { + if v == 0 { + continue; + } + col_idx.push(c); + values.push(v); + row_ptr[r + 1] = col_idx.len(); + } + // Fill forward + for i in 1..=n { + if row_ptr[i] == 0 && i > 0 { + row_ptr[i] = row_ptr[i - 1]; + } + } + // Fix: ensure monotonic + for i in 1..=n { + row_ptr[i] = row_ptr[i].max(row_ptr[i - 1]); + } + + Self { + n, + row_ptr, + col_idx, + values, + } + } + } + + impl NDIndex for SparseMat { + fn ndim(&self) -> usize { + 2 + } + fn dim(&self, _axis: usize) -> usize { + self.n + } + fn get(&self, ix: &[usize]) -> u32 { + let r = ix[0]; + let c = ix[1]; + let start = self.row_ptr[r]; + let end = self.row_ptr[r + 1]; + for i in start..end { + if self.col_idx[i] == c { + return self.values[i]; + } + } + 0 + } + fn set(&mut self, _ix: &[usize], _v: u32) { + panic!("SparseMat is read-only"); + } + fn get_opt(&self, ix: &[usize]) -> Option { + let r = 
ix[0]; + let c = ix[1]; + let start = self.row_ptr[r]; + let end = self.row_ptr[r + 1]; + for i in start..end { + if self.col_idx[i] == c { + return Some(self.values[i]); + } + } + None + } + fn is_sparse_2d(&self) -> bool { true } + fn sparse_row_nnz(&self, row: usize) -> usize { + self.row_ptr[row + 1] - self.row_ptr[row] + } + fn sparse_row_entry(&self, row: usize, idx: usize) -> (usize, u32) { + let start = self.row_ptr[row]; + (self.col_idx[start + idx], self.values[start + idx]) + } + } + + impl Sparse2D for SparseMat { + fn nnz(&self) -> usize { + self.values.len() + } + fn n_rows(&self) -> usize { + self.n + } + fn row_nnz(&self, row: usize) -> usize { + self.row_ptr[row + 1] - self.row_ptr[row] + } + fn row_entry(&self, row: usize, idx: usize) -> (usize, u32) { + let start = self.row_ptr[row]; + (self.col_idx[start + idx], self.values[start + idx]) + } + } + + /// Reference matmul for verification. + fn naive_matmul(a: &SparseMat, b: &SparseMat) -> Vec> { + let n = a.n; + let mut c = vec![vec![0u32; n]; n]; + for i in 0..n { + for k in 0..n { + let a_ik = a.get(&[i, k]); + if a_ik == 0 { + continue; + } + for j in 0..n { + c[i][j] += a_ik * b.get(&[k, j]); + } + } + } + c + } + + #[test] + fn test_approach1_baseline() { + let a = SparseMat::from_triplets(3, &[(0, 1, 1), (1, 2, 1), (2, 0, 1)]); + let b = SparseMat::from_triplets(3, &[(0, 1, 1), (1, 2, 1), (2, 0, 1)]); + let mut out = DenseMat::new(3, 3); + + einsum_binary("ab,bc->ac", &a, &b, &mut out).unwrap(); + + let expected = naive_matmul(&a, &b); + for i in 0..3 { + for j in 0..3 { + assert_eq!( + out.get(&[i, j]), + expected[i][j], + "baseline mismatch at ({i},{j})" + ); + } + } + } + + #[test] + fn test_approach2_sparse_driven() { + let a = SparseMat::from_triplets(4, &[(0, 1, 2), (0, 2, 3), (1, 3, 1), (2, 3, 4)]); + let b = SparseMat::from_triplets(4, &[(1, 0, 5), (2, 0, 6), (3, 1, 7)]); + let mut out = DenseMat::new(4, 4); + + einsum_sparse_driven("ab,bc->ac", &a, &b, &mut out).unwrap(); + + let 
expected = naive_matmul(&a, &b); + for i in 0..4 { + for j in 0..4 { + assert_eq!( + out.get(&[i, j]), + expected[i][j], + "sparse-driven mismatch at ({i},{j})" + ); + } + } + } + + #[test] + fn test_approach3_vm() { + let a = SparseMat::from_triplets(4, &[(0, 1, 2), (0, 2, 3), (1, 3, 1), (2, 3, 4)]); + let b = SparseMat::from_triplets(4, &[(1, 0, 5), (2, 0, 6), (3, 1, 7)]); + let mut out = DenseMat::new(4, 4); + + einsum_vm_oneshot("ab,bc->ac", &[&a, &b], &mut [&mut out]).unwrap(); + + let expected = naive_matmul(&a, &b); + for i in 0..4 { + for j in 0..4 { + assert_eq!( + out.get(&[i, j]), + expected[i][j], + "VM mismatch at ({i},{j})" + ); + } + } + } + + #[test] + fn test_approach3_vm_display() { + let a = SparseMat::from_triplets(4, &[(0, 1, 1)]); + let b = SparseMat::from_triplets(4, &[(1, 0, 1)]); + let program = compile_vm("ab,bc->ac", &[&a as &dyn NDIndex, &b]).unwrap(); + let display = format!("{program}"); + assert!(display.contains("SPARSE"), "VM should use sparse loops:\n{display}"); + println!("{display}"); + } + + #[test] + fn test_approach4_sparse_hash() { + let a = SparseMat::from_triplets(4, &[(0, 1, 2), (0, 2, 3), (1, 3, 1), (2, 3, 4)]); + let b = SparseMat::from_triplets(4, &[(1, 0, 5), (2, 0, 6), (3, 1, 7)]); + let mut out = DenseMat::new(4, 4); + + einsum_sparse_hash("ab,bc->ac", &a, &b, &mut out).unwrap(); + + let expected = naive_matmul(&a, &b); + for i in 0..4 { + for j in 0..4 { + assert_eq!( + out.get(&[i, j]), + expected[i][j], + "hash mismatch at ({i},{j})" + ); + } + } + } + + #[test] + fn test_all_approaches_agree_diamond() { + // Diamond graph: 0→1, 0→2, 1→3, 2→3 + let a = SparseMat::from_triplets( + 4, + &[(0, 1, 1), (0, 2, 1), (1, 3, 1), (2, 3, 1)], + ); + let b = a.clone_sparse(); + let expected = naive_matmul(&a, &b); + + let mut out1 = DenseMat::new(4, 4); + einsum_binary("ab,bc->ac", &a, &b, &mut out1).unwrap(); + + let mut out2 = DenseMat::new(4, 4); + einsum_sparse_driven("ab,bc->ac", &a, &b, &mut out2).unwrap(); + + let mut 
out3 = DenseMat::new(4, 4); + einsum_vm_oneshot("ab,bc->ac", &[&a, &b], &mut [&mut out3]).unwrap(); + + let mut out4 = DenseMat::new(4, 4); + einsum_sparse_hash("ab,bc->ac", &a, &b, &mut out4).unwrap(); + + for i in 0..4 { + for j in 0..4 { + let e = expected[i][j]; + assert_eq!(out1.get(&[i, j]), e, "baseline@({i},{j})"); + assert_eq!(out2.get(&[i, j]), e, "sparse-driven@({i},{j})"); + assert_eq!(out3.get(&[i, j]), e, "VM@({i},{j})"); + assert_eq!(out4.get(&[i, j]), e, "hash@({i},{j})"); + } + } + } + + #[test] + fn test_identity_matmul() { + // A × I = A + let a = SparseMat::from_triplets(3, &[(0, 1, 5), (1, 2, 3), (2, 0, 7)]); + let id = SparseMat::from_triplets(3, &[(0, 0, 1), (1, 1, 1), (2, 2, 1)]); + + let mut out = DenseMat::new(3, 3); + einsum_sparse_driven("ab,bc->ac", &a, &id, &mut out).unwrap(); + + assert_eq!(out.get(&[0, 1]), 5); + assert_eq!(out.get(&[1, 2]), 3); + assert_eq!(out.get(&[2, 0]), 7); + assert_eq!(out.get(&[0, 0]), 0); + } + + /// Dense N-dimensional tensor for testing higher-dimensional VM inputs. + struct DenseTensor { + data: Vec, + shape: Vec, + } + + impl DenseTensor { + fn new(shape: Vec) -> Self { + let n: usize = shape.iter().product(); + Self { data: vec![0; n], shape } + } + fn linear_index(&self, ix: &[usize]) -> usize { + let mut idx = 0; + let mut stride = 1; + for (&k, &dim) in ix.iter().rev().zip(self.shape.iter().rev()) { + idx += k * stride; + stride *= dim; + } + idx + } + } + + impl NDIndex for DenseTensor { + fn ndim(&self) -> usize { self.shape.len() } + fn dim(&self, axis: usize) -> usize { self.shape[axis] } + fn get(&self, ix: &[usize]) -> u32 { self.data[self.linear_index(ix)] } + fn set(&mut self, ix: &[usize], v: u32) { + let i = self.linear_index(ix); + self.data[i] = v; + } + } + + #[test] + fn test_vm_3d_dense_inputs() { + // "abc,cd->abd": batched matmul with 3D × 2D → 3D + // All dense — no sparse loops, but the VM should handle it. 
+ let mut a = DenseTensor::new(vec![2, 3, 4]); // batch=2, rows=3, inner=4 + let mut b = DenseTensor::new(vec![4, 5]); // inner=4, cols=5 + + // Fill with simple values + for i in 0..2 { + for j in 0..3 { + for k in 0..4 { + a.set(&[i, j, k], (i * 12 + j * 4 + k + 1) as u32); + } + } + } + for k in 0..4 { + for l in 0..5 { + b.set(&[k, l], (k * 5 + l + 1) as u32); + } + } + + let mut out = DenseTensor::new(vec![2, 3, 5]); + einsum_vm_oneshot( + "abc,cd->abd", + &[&a, &b], + &mut [&mut out], + ).unwrap(); + + // Verify against naive computation + for i in 0..2 { + for j in 0..3 { + for l in 0..5 { + let mut expected = 0u32; + for k in 0..4 { + expected += a.get(&[i, j, k]) * b.get(&[k, l]); + } + assert_eq!( + out.get(&[i, j, l]), expected, + "3D mismatch at ({i},{j},{l})" + ); + } + } + } + } + + #[test] + fn test_vm_mixed_2d_sparse_and_dense() { + // "ab,bc->ac": one sparse, one dense — VM should use sparse for the + // sparse input's axis and dense for the dense input. + let a = SparseMat::from_triplets(3, &[(0, 1, 2), (1, 2, 3)]); + let mut b = DenseTensor::new(vec![3, 3]); + for i in 0..3 { + for j in 0..3 { + b.set(&[i, j], (i * 3 + j + 1) as u32); + } + } + + let mut out = DenseMat::new(3, 3); + einsum_vm_oneshot_dyn( + "ab,bc->ac", + &[&a as &dyn NDIndex, &b as &dyn NDIndex], + &mut [&mut out as &mut dyn NDIndex], + ).unwrap(); + + // Verify: row 0 of A has only (1, 2), row 1 has only (2, 3) + // C[0, j] = A[0,1]*B[1,j] = 2 * B[1,j] + // C[1, j] = A[1,2]*B[2,j] = 3 * B[2,j] + for j in 0..3 { + assert_eq!(out.get(&[0, j]), 2 * b.get(&[1, j]), "mixed@(0,{j})"); + assert_eq!(out.get(&[1, j]), 3 * b.get(&[2, j]), "mixed@(1,{j})"); + assert_eq!(out.get(&[2, j]), 0, "mixed@(2,{j})"); + } + } + + #[test] + fn test_vm_4d_attention() { + // "bhqd,bhkd->bhqk": attention-style with all dense inputs + let (ba, h, q, k, d) = (1, 1, 2, 2, 3); + let mut qm = DenseTensor::new(vec![ba, h, q, d]); + let mut km = DenseTensor::new(vec![ba, h, k, d]); + + for qi in 0..q { + for 
di in 0..d { + qm.set(&[0, 0, qi, di], (qi * d + di + 1) as u32); + } + } + for ki in 0..k { + for di in 0..d { + km.set(&[0, 0, ki, di], (ki * d + di + 1) as u32); + } + } + + let mut out = DenseTensor::new(vec![ba, h, q, k]); + einsum_vm_oneshot( + "bhqd,bhkd->bhqk", + &[&qm, &km], + &mut [&mut out], + ).unwrap(); + + for qi in 0..q { + for ki in 0..k { + let mut expected = 0u32; + for di in 0..d { + expected += qm.get(&[0, 0, qi, di]) * km.get(&[0, 0, ki, di]); + } + assert_eq!( + out.get(&[0, 0, qi, ki]), expected, + "attention@(0,0,{qi},{ki})" + ); + } + } + } + + impl SparseMat { + fn clone_sparse(&self) -> Self { + Self { + n: self.n, + row_ptr: self.row_ptr.clone(), + col_idx: self.col_idx.clone(), + values: self.values.clone(), + } + } + } + + /// 0-dim scalar output tensor for VM scalar tests. + struct ScalarOut { + val: u32, + } + + impl ScalarOut { + fn new() -> Self { Self { val: 0 } } + } + + impl NDIndex for ScalarOut { + fn ndim(&self) -> usize { 0 } + fn dim(&self, _axis: usize) -> usize { panic!("0-dim has no axes") } + fn get(&self, _ix: &[usize]) -> u32 { self.val } + fn set(&mut self, _ix: &[usize], v: u32) { self.val = v; } + } + + #[test] + fn test_vm_scalar_dot_sparse() { + // "i,i->" dot product with sparse inputs + // a = [0, 2, 0, 3], b = [0, 5, 7, 0] + // dot = 0 + 2*5 + 0 + 0 = 10 + let a = SparseMat::from_triplets(1, &[(0, 1, 2), (0, 3, 3)]); + let b = SparseMat::from_triplets(1, &[(0, 1, 5), (0, 2, 7)]); + + // For dot product on 1D vectors stored as 1-row matrices, + // use spec "ab,ab->" (Frobenius inner product) since SparseMat is 2D + let mut out = ScalarOut::new(); + einsum_vm_oneshot("ab,ab->", &[&a, &b], &mut [&mut out]).unwrap(); + + // a[0,1]*b[0,1] + a[0,3]*b[0,3] = 2*5 + 0 = 10 + assert_eq!(out.val, 10); + } + + #[test] + fn test_vm_scalar_trace_sparse() { + // "aa->" trace of a sparse matrix + // Only diagonal entries matter: (0,0)=1, (1,1)=5, (2,2)=9 + let m = SparseMat::from_triplets(3, &[ + (0, 0, 1), (0, 2, 2), + (1, 
1, 5), (1, 2, 3), + (2, 2, 9), + ]); + + let mut out = ScalarOut::new(); + einsum_vm_oneshot("aa->", &[&m], &mut [&mut out]).unwrap(); + + assert_eq!(out.val, 15); // 1 + 5 + 9 + } + + #[test] + fn test_vm_scalar_frobenius_sparse() { + // "ab,ab->" Frobenius inner product (sum of element-wise products) + let a = SparseMat::from_triplets(3, &[ + (0, 1, 2), (1, 0, 3), (2, 2, 4), + ]); + let b = SparseMat::from_triplets(3, &[ + (0, 1, 5), (1, 0, 7), (1, 1, 99), (2, 2, 2), + ]); + + let mut out = ScalarOut::new(); + einsum_vm_oneshot("ab,ab->", &[&a, &b], &mut [&mut out]).unwrap(); + + // overlapping entries: (0,1): 2*5=10, (1,0): 3*7=21, (2,2): 4*2=8 + assert_eq!(out.val, 39); // 10 + 21 + 8 + } + + #[test] + fn test_vm_scalar_dense_dot() { + // "i,i->" with dense 1D inputs + let mut a = DenseTensor::new(vec![4]); + let mut b = DenseTensor::new(vec![4]); + for i in 0..4 { + a.set(&[i], (i + 1) as u32); // [1, 2, 3, 4] + b.set(&[i], (i + 5) as u32); // [5, 6, 7, 8] + } + + let mut out = ScalarOut::new(); + einsum_vm_oneshot("i,i->", &[&a, &b], &mut [&mut out]).unwrap(); + + // 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70 + assert_eq!(out.val, 70); + } + + #[test] + fn test_vm_scalar_three_input() { + // "i,i,i->" element-wise triple product summed to scalar + let mut a = DenseTensor::new(vec![3]); + let mut b = DenseTensor::new(vec![3]); + let mut c = DenseTensor::new(vec![3]); + for i in 0..3 { + a.set(&[i], (i + 1) as u32); // [1, 2, 3] + b.set(&[i], (i + 4) as u32); // [4, 5, 6] + c.set(&[i], (i + 7) as u32); // [7, 8, 9] + } + + let mut out = ScalarOut::new(); + einsum_vm_oneshot("i,i,i->", &[&a, &b, &c], &mut [&mut out]).unwrap(); + + // 1*4*7 + 2*5*8 + 3*6*9 = 28 + 80 + 162 = 270 + assert_eq!(out.val, 270); + } + + // --- Multi-output tests --- + + #[test] + fn test_vm_multi_output_matmul_and_transpose() { + // "ab,bc->ac,ca": matmul result written to both C and C^T simultaneously + let a = SparseMat::from_triplets(3, &[(0, 1, 2), (1, 2, 3), (2, 0, 1)]); + let b 
= SparseMat::from_triplets(3, &[(0, 1, 4), (1, 0, 5), (2, 2, 6)]); + + let mut out_ac = DenseMat::new(3, 3); + let mut out_ca = DenseMat::new(3, 3); + einsum_vm_oneshot("ab,bc->ac,ca", &[&a, &b], &mut [&mut out_ac, &mut out_ca]).unwrap(); + + // Verify against naive matmul + let expected = naive_matmul(&a, &b); + for i in 0..3 { + for j in 0..3 { + assert_eq!(out_ac.get(&[i, j]), expected[i][j], "ac@({i},{j})"); + assert_eq!(out_ca.get(&[j, i]), expected[i][j], "ca@({j},{i})"); + } + } + } + + #[test] + fn test_vm_multi_output_dense() { + // "ab,bc->ac,ca" with dense inputs + let mut a = DenseTensor::new(vec![3, 3]); + let mut b = DenseTensor::new(vec![3, 3]); + for i in 0..3 { + for j in 0..3 { + a.set(&[i, j], (i * 3 + j + 1) as u32); + b.set(&[i, j], (i * 3 + j + 10) as u32); + } + } + + let mut out_ac = DenseTensor::new(vec![3, 3]); + let mut out_ca = DenseTensor::new(vec![3, 3]); + einsum_vm_oneshot("ab,bc->ac,ca", &[&a, &b], &mut [&mut out_ac, &mut out_ca]).unwrap(); + + // Verify C[i,j] = sum_k A[i,k]*B[k,j] + for i in 0..3 { + for j in 0..3 { + let mut expected = 0u32; + for k in 0..3 { + expected += a.get(&[i, k]) * b.get(&[k, j]); + } + assert_eq!(out_ac.get(&[i, j]), expected, "ac@({i},{j})"); + assert_eq!(out_ca.get(&[j, i]), expected, "ca@({j},{i})"); + } + } + } + + #[test] + fn test_vm_multi_output_matmul_and_scalar() { + // "ab,bc->ac," : matmul + scalar (Frobenius-like total sum) + // Wait — we need different contracted indices for the scalar. 
+ // Actually "ab,ba->,ab" would be: scalar = sum_ab A[a,b]*B[b,a], + // and output ab = A[a,b]*B[b,a] (element-wise, no contraction for ab output) + // Let's test a simpler case: "ab,bc->ac,ac" (same output written twice) + let a = SparseMat::from_triplets(3, &[(0, 1, 2), (1, 2, 3)]); + let b = SparseMat::from_triplets(3, &[(1, 0, 5), (2, 2, 6)]); + + let mut out1 = DenseMat::new(3, 3); + let mut out2 = DenseMat::new(3, 3); + einsum_vm_oneshot("ab,bc->ac,ac", &[&a, &b], &mut [&mut out1, &mut out2]).unwrap(); + + let expected = naive_matmul(&a, &b); + for i in 0..3 { + for j in 0..3 { + assert_eq!(out1.get(&[i, j]), expected[i][j], "out1@({i},{j})"); + assert_eq!(out2.get(&[i, j]), expected[i][j], "out2@({i},{j})"); + } + } + } +} diff --git a/experiments/eval-examples/src/lib.rs b/experiments/eval-examples/src/lib.rs index 7aaa84c..816e58c 100644 --- a/experiments/eval-examples/src/lib.rs +++ b/experiments/eval-examples/src/lib.rs @@ -1,7 +1,7 @@ use eval_ffi::{ExprSink, ExprSource, EvalError, SourceItem}; #[unsafe(export_name = "ground_mul")] -pub extern "C" fn ground_mul(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn ground_mul(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"*")?; @@ -16,7 +16,7 @@ pub extern "C" fn ground_mul(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu } #[unsafe(export_name = "ground_sum")] -pub extern "C" fn ground_sum(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn ground_sum(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"+")?; diff --git a/experiments/eval-ffi/src/lib.rs b/experiments/eval-ffi/src/lib.rs index 6e2cff3..8655aa1 100644 --- 
a/experiments/eval-ffi/src/lib.rs +++ b/experiments/eval-ffi/src/lib.rs @@ -10,7 +10,7 @@ pub use sink::{ExprSink}; pub use source::{ExprSource}; pub use mork_expr::{Tag, SourceItem}; -pub type FuncPtr = extern "C" fn(*mut ExprSource, *mut ExprSink) -> Result<(), EvalError>; +pub type FuncPtr = extern "C" fn(*mut ExprSource, *mut ExprSink, *mut ()) -> Result<(), EvalError>; #[derive(Debug)] pub enum EvalError { diff --git a/experiments/eval/src/lib.rs b/experiments/eval/src/lib.rs index ba90304..202f845 100644 --- a/experiments/eval/src/lib.rs +++ b/experiments/eval/src/lib.rs @@ -7,11 +7,11 @@ use mork_expr::{item_source, SourceItem}; use eval_ffi::{EvalError, ExprSink, ExprSource, FuncPtr, Tag}; use log::trace; -extern "C" fn nothing(src: *mut ExprSource, snk: *mut ExprSink) -> Result<(), EvalError> { +extern "C" fn nothing(src: *mut ExprSource, snk: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { Ok(()) } -extern "C" fn quote(src: *mut ExprSource, snk: *mut ExprSink) -> Result<(), EvalError> { +extern "C" fn quote(src: *mut ExprSource, snk: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { unreachable!() } @@ -37,6 +37,8 @@ pub struct EvalScope { /// These are used to create ExprSinks for stack frames. /// This avoids repeated allocations during evaluation. alloc_pool: Vec>, + /// Opaque context pointer propagated to ExprSource for pure functions. + pub context: *mut (), } macro_rules! 
alloc { @@ -59,6 +61,7 @@ impl EvalScope { stack: Vec::new(), alloc_pool: Vec::new(), expr: ExprSource::new(core::ptr::null()), + context: core::ptr::null_mut(), } } pub fn add_func(&mut self, name: &str, func: FuncPtr, ty: FuncType) { @@ -136,7 +139,8 @@ impl EvalScope { let prev_frame = parent_frames.last_mut().unwrap(); let mut data = core::mem::take(&mut top_frame.sink).finish(); let offset = prev_frame.sink.as_ref().len(); - (top_frame.func)(&mut ExprSource::new(data.as_ptr()), &mut prev_frame.sink)?; + let mut src = ExprSource::new(data.as_ptr()); + (top_frame.func)(&mut src, &mut prev_frame.sink, self.context)?; trace!(target: "eval", "{:?} ==> {:?}", mork_expr::serialize(&data[..]), mork_expr::serialize(&prev_frame.sink.as_ref()[offset..])); let top = self.stack.pop().unwrap(); // return buffer to pool diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 6a42ac4..502394f 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -26,6 +26,8 @@ itertools = "0.14.0" base64 = "0.22.1" hex = "0.4.3" subprocess = { version = "0.2.13" } +einsum-dyn = { path = "../einsum-dyn" } +num-traits = "0.2" [features] default = ["grounding", "specialize_io"] diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index b566759..94bc701 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -8,3 +8,4 @@ pub mod space; mod sources; mod sinks; mod pure; +pub mod sparse; diff --git a/kernel/src/pure.rs b/kernel/src/pure.rs index e5ced75..d8114e7 100644 --- a/kernel/src/pure.rs +++ b/kernel/src/pure.rs @@ -9,7 +9,7 @@ use eval_ffi::{ExprSink, ExprSource, EvalError, Tag}; macro_rules! 
op { (num nary $name:ident($initial:expr, $t:ident: $tt:ty, $x:ident: $tx:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -23,7 +23,7 @@ macro_rules! op { } }; (num quaternary $name:ident($x:ident: $tx:ty, $y:ident: $ty:ty, $z:ident: $tz:ty, $w:ident: $tw:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -37,7 +37,7 @@ macro_rules! op { } }; (num ternary $name:ident($x:ident: $tx:ty, $y:ident: $ty:ty, $z:ident: $tz:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -50,7 +50,7 @@ macro_rules! op { } }; (num binary $name:ident($x:ident: $tx:ty, $y:ident: $ty:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -62,7 +62,7 @@ macro_rules! 
op { } }; (num unary $name:ident($x:ident: $tx:ty) => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -73,7 +73,7 @@ macro_rules! op { } }; (num nulary $name:ident() => $e:expr) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -83,7 +83,7 @@ macro_rules! op { } }; (num from_string $name:ident<$t:ty>) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -95,7 +95,7 @@ macro_rules! 
op { } }; (num to_string $name:ident<$t:ty>) => { - pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { + pub extern "C" fn $name(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(stringify!($name).as_bytes())?; @@ -703,7 +703,7 @@ op!(num nary min_f32(f32::INFINITY, t: f32, x: f32) => t.min(x)); op!(num from_string f32_from_string); op!(num to_string f32_to_string); -pub extern "C" fn encode_hex(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn encode_hex(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"encode_hex")?; @@ -716,7 +716,7 @@ pub extern "C" fn encode_hex(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu Ok(()) } -pub extern "C" fn decode_hex(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn decode_hex(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"decode_hex")?; @@ -729,7 +729,7 @@ pub extern "C" fn decode_hex(expr: *mut ExprSource, sink: *mut ExprSink) -> Resu Ok(()) } -pub extern "C" fn decode_base64url(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn decode_base64url(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"decode_base64url")?; @@ -742,7 +742,7 @@ pub extern "C" fn decode_base64url(expr: *mut ExprSource, sink: *mut ExprSink) - Ok(()) } -pub extern "C" fn encode_base64url(expr: *mut ExprSource, sink: *mut ExprSink) -> 
Result<(), EvalError> { +pub extern "C" fn encode_base64url(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"encode_base64url")?; @@ -755,7 +755,7 @@ pub extern "C" fn encode_base64url(expr: *mut ExprSource, sink: *mut ExprSink) - Ok(()) } -pub extern "C" fn hash_expr(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn hash_expr(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"hash_expr")?; @@ -767,7 +767,7 @@ pub extern "C" fn hash_expr(expr: *mut ExprSource, sink: *mut ExprSink) -> Resul Ok(()) } -pub extern "C" fn reverse_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn reverse_symbol(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"reverse_symbol")?; @@ -780,7 +780,7 @@ pub extern "C" fn reverse_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Ok(()) } -pub extern "C" fn collapse_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn collapse_symbol(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"collapse_symbol")?; @@ -800,7 +800,7 @@ pub extern "C" fn collapse_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Ok(()) } -pub extern "C" fn explode_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn explode_symbol(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut 
*expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"explode_symbol")?; @@ -813,7 +813,7 @@ pub extern "C" fn explode_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> Ok(()) } -// pub extern "C" fn nth_expr(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +// pub extern "C" fn nth_expr(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { // let expr = unsafe { &mut *expr }; // let sink = unsafe { &mut *sink }; // let items = expr.consume_head_check(b"nth_expr")?; @@ -829,7 +829,7 @@ pub extern "C" fn explode_symbol(expr: *mut ExprSource, sink: *mut ExprSink) -> // (ifnz then else ) // The condition may be of any length ( is always len >= 1), // but all bytes in the must be b'\0' in order for the condition to be considered `false` -pub extern "C" fn ifnz(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn ifnz(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"ifnz")?; @@ -853,7 +853,7 @@ pub extern "C" fn ifnz(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), } } -pub extern "C" fn tuple(expr: *mut ExprSource, sink: *mut ExprSink) -> Result<(), EvalError> { +pub extern "C" fn tuple(expr: *mut ExprSource, sink: *mut ExprSink, _ctx: *mut ()) -> Result<(), EvalError> { let expr = unsafe { &mut *expr }; let sink = unsafe { &mut *sink }; let items = expr.consume_head_check(b"tuple")?; diff --git a/kernel/src/sinks.rs b/kernel/src/sinks.rs index f5e2c62..50a6887 100644 --- a/kernel/src/sinks.rs +++ b/kernel/src/sinks.rs @@ -37,6 +37,7 @@ pub(crate) enum WriteResourceRequest { BTM(&'static [u8]), ACT(&'static str), Z3(&'static str), + TensorStore, } impl WriteResourceRequest { @@ -64,6 +65,12 @@ impl WriteResourceRequest { _ => { None } } } + WriteResourceRequest::TensorStore => { + match other { + 
WriteResourceRequest::TensorStore => { Some(WriteResourceRequest::TensorStore) } + _ => { None } + } + } } } } @@ -86,6 +93,11 @@ impl PartialOrd for WriteResourceRequest { if s == o { Some(Ordering::Equal) } else { None } } else { None } } + WriteResourceRequest::TensorStore => { + if let WriteResourceRequest::TensorStore = other { + Some(Ordering::Equal) + } else { None } + } } } } @@ -93,7 +105,8 @@ impl PartialOrd for WriteResourceRequest { pub(crate) enum WriteResource<'w, 'a, 'k> { BTM(&'w mut WriteZipperTracked<'a, 'k, ()>), ACT(()), - Z3(&'w mut subprocess::Popen) + Z3(&'w mut subprocess::Popen), + TensorStore(&'w mut std::collections::HashMap, crate::sparse::SparseTensorF64>), } // trait JoinLattice { @@ -1058,6 +1071,7 @@ impl Sink for PureSink { fn new(e: Expr) -> Self { let mut scope = EvalScope::new(); pure::register(&mut scope); + crate::sparse::register(&mut scope); PureSink { e, unique: PathMap::new(), scope } } fn request(&self) -> impl Iterator { @@ -1191,7 +1205,278 @@ impl Sink for Z3Sink { } +// ============================================================================ +// Tensor sinks — all tensor I/O goes through WriteResource::TensorStore +// ============================================================================ + +/// Helper: parse symbol args from an expression path, skipping a known header. +fn parse_symbol_args<'a>(path: &'a [u8], header_size: usize) -> Vec<&'a [u8]> { + let rest = &path[header_size..]; + let mut pos = 0; + let mut args = Vec::new(); + while pos < rest.len() { + match byte_item(rest[pos]) { + Tag::SymbolSize(len) => { + let len = len as usize; + pos += 1; + if pos + len <= rest.len() { + args.push(&rest[pos..pos+len]); + } + pos += len; + } + _ => { pos += 1; } + } + } + args +} + +/// TensorCollectSink — writes (indices, value) tuples directly into a named SparseTensorF64. +/// Syntax: (tensor_collect name $i0 $i1 ... 
$val) +/// +/// The first matching tuple replaces any existing tensor at `name` with a +/// fresh one; subsequent matches in the same exec invocation accumulate +/// into it. This assumes no concurrent tensor-as-source path is reading +/// `name` during the exec — if that ever becomes possible the writes must +/// be buffered until `finalize()` (see TensorFreeSink for that pattern). +pub struct TensorCollectSink { + e: Expr, + name: Vec, + rank: usize, + initialized: bool, +} + +impl TensorCollectSink { + const HEADER_SIZE: usize = 16; // Arity(N) + SymbolSize(14) + "tensor_collect" +} + +impl Sink for TensorCollectSink { + fn new(e: Expr) -> Self { + let arity = unsafe { + if let Tag::Arity(a) = byte_item(*e.ptr) { a as usize } else { panic!("tensor_collect: expected arity") } + }; + let rank = arity.saturating_sub(3); + let name = unsafe { + let name_tag = *e.ptr.add(Self::HEADER_SIZE); + if let Tag::SymbolSize(len) = byte_item(name_tag) { + std::slice::from_raw_parts(e.ptr.add(Self::HEADER_SIZE + 1), len as usize).to_vec() + } else { + panic!("tensor_collect: second arg must be a symbol (tensor name)") + } + }; + TensorCollectSink { e, name, rank, initialized: false } + } + + fn request(&self) -> impl Iterator { + std::iter::once(WriteResourceRequest::TensorStore) + } + + fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It, path: &[u8]) where 'a: 'w, 'k: 'w { + let name_len = self.name.len(); + let args = parse_symbol_args(path, Self::HEADER_SIZE + 1 + name_len); + + if args.len() < self.rank + 1 { return; } + + let mut indices = Vec::with_capacity(self.rank); + for i in 0..self.rank { + let Ok(s) = std::str::from_utf8(args[i]) else { return; }; + let Ok(idx) = s.parse::() else { return; }; + indices.push(idx); + } + let Ok(s) = std::str::from_utf8(args[self.rank]) else { return; }; + let Ok(val) = s.parse::() else { return; }; + + let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() }; + if !self.initialized { + 
store.insert(self.name.clone(), crate::sparse::SparseTensorF64::new(self.rank)); + self.initialized = true; + } + store.get_mut(&self.name).unwrap().set(&indices, val); + } + + fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It) -> bool where 'a: 'w, 'k: 'w { + self.initialized + } +} + +/// TensorEinsumSink — runs einsum on named tensors. +/// Syntax: (tensor_einsum "spec" input1 input2 ... output) +/// Variables are bound from pattern matching; actual values arrive in sink(). +pub struct TensorEinsumSink { + e: Expr, + spec: Vec, + input_names: Vec>, + output_name: Vec, + parsed: bool, +} + +impl TensorEinsumSink { + // "tensor_einsum" is 13 chars → header = Arity + SymbolSize(13) + "tensor_einsum" = 15 + const HEADER_SIZE: usize = 15; +} + +impl Sink for TensorEinsumSink { + fn new(e: Expr) -> Self { + let span = unsafe { e.span().as_ref().unwrap() }; + let args = parse_symbol_args(span, Self::HEADER_SIZE); + let spec = args.first().map(|a| a.to_vec()).unwrap_or_default(); + let mut names: Vec> = args.get(1..).unwrap_or(&[]).iter().map(|a| a.to_vec()).collect(); + let output_name = names.pop().unwrap_or_default(); + TensorEinsumSink { e, spec, input_names: names, output_name, parsed: !args.is_empty() } + } + fn request(&self) -> impl Iterator { + std::iter::once(WriteResourceRequest::TensorStore) + } + fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It, _path: &[u8]) where 'a: 'w, 'k: 'w {} + fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It) -> bool where 'a: 'w, 'k: 'w { + if !self.parsed { return false; } + let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() }; + // Strip quotes from spec if present (MORK parser includes them) + let spec_bytes = if self.spec.starts_with(b"\"") && self.spec.ends_with(b"\"") { + &self.spec[1..self.spec.len()-1] + } else { + &self.spec[..] 
+ }; + let spec_str = match std::str::from_utf8(spec_bytes) { + Ok(s) => s, + Err(_) => { log::error!(target: "tensor", "tensor_einsum: invalid spec"); return false; } + }; + + let inputs: Vec<&crate::sparse::SparseTensorF64> = self.input_names.iter() + .filter_map(|name| store.get(name)) + .collect(); + if inputs.len() != self.input_names.len() { + log::error!(target: "tensor", "tensor_einsum: missing input tensor(s)"); + return false; + } + + // Parse spec to determine output shape + let arrow = match spec_str.find("->") { + Some(a) => a, + None => { log::error!(target: "tensor", "tensor_einsum: spec missing '->'"); return false; } + }; + let output_spec = &spec_str[arrow+2..]; + let input_specs: Vec<&str> = spec_str[..arrow].split(',').collect(); + + let mut dim_map = std::collections::HashMap::new(); + for (i, ispec) in input_specs.iter().enumerate() { + if let Some(inp) = inputs.get(i) { + for (axis, ch) in ispec.bytes().enumerate() { + let d = inp.dims.get(axis).copied().unwrap_or(0); + dim_map.entry(ch).or_insert(d); + } + } + } + let out_dims: Vec = output_spec.bytes() + .map(|ch| dim_map.get(&ch).copied().unwrap_or(1)) + .collect(); + + let mut output = crate::sparse::SparseTensorF64::with_dims(out_dims); + let dyn_inputs: Vec<&dyn einsum_dyn::NDIndex> = inputs.iter() + .map(|t| *t as &dyn einsum_dyn::NDIndex) + .collect(); + let mut dyn_out: &mut dyn einsum_dyn::NDIndex = &mut output; + if let Err(e) = einsum_dyn::sparse::einsum_vm_oneshot_dyn(spec_str, &dyn_inputs, &mut [dyn_out]) { + log::error!(target: "tensor", "tensor_einsum failed: {}", e); + return false; + } + store.insert(self.output_name.clone(), output); + true + } +} + +/// TensorBinopSink — element-wise add or mul on named tensors. 
+/// Syntax: (tensor_add A B C) or (tensor_mul A B C)
+///
+/// Operand names are parsed eagerly in `new_with_op` from the template
+/// expression (mirroring `TensorEinsumSink::new`), so the sink also works for
+/// source-less execs like `(exec F (,) (O (tensor_add A B C)))`, in which the
+/// engine never invokes `sink()`. When an operand is a pattern variable the
+/// eager parse is incomplete and `sink()` fills the names in per match from
+/// the instantiated path.
+pub struct TensorBinopSink {
+    e: Expr,
+    op: TensorBinop,
+    header_size: usize,
+    a_name: Vec<u8>,
+    b_name: Vec<u8>,
+    c_name: Vec<u8>,
+    parsed: bool,
+}
+
+enum TensorBinop { Add, Mul }
+
+impl TensorBinopSink {
+    fn new_with_op(e: Expr, op: TensorBinop, header_size: usize) -> Self {
+        // Eager parse: with an empty source pattern `(,)` only new()/finalize()
+        // run (cf. the tensor_einsum end-to-end test, which works that way).
+        // Waiting for sink() to learn the names would leave `parsed == false`
+        // and finalize() would silently do nothing.
+        let span = unsafe { e.span().as_ref().unwrap() };
+        let args = parse_symbol_args(span, header_size);
+        let a_name = args.first().map(|a| a.to_vec()).unwrap_or_default();
+        let b_name = args.get(1).map(|a| a.to_vec()).unwrap_or_default();
+        let c_name = args.get(2).map(|a| a.to_vec()).unwrap_or_default();
+        // Only treat the parse as complete when all three operands were
+        // literal symbols; variables are resolved later in sink().
+        let parsed = args.len() >= 3;
+        TensorBinopSink { e, op, header_size, a_name, b_name, c_name, parsed }
+    }
+}
+
+impl Sink for TensorBinopSink {
+    fn new(_e: Expr) -> Self { panic!("use new_with_op") }
+    fn request(&self) -> impl Iterator<Item = WriteResourceRequest> {
+        std::iter::once(WriteResourceRequest::TensorStore)
+    }
+    fn sink<'w, 'a, 'k, It: Iterator<Item = WriteResource<'w, 'a, 'k>>>(&mut self, _it: It, path: &[u8]) where 'a: 'w, 'k: 'w {
+        // Fallback for variable operands: resolve names from the matched path.
+        if self.parsed { return; }
+        let args = parse_symbol_args(path, self.header_size);
+        self.a_name = args.first().map(|a| a.to_vec()).unwrap_or_default();
+        self.b_name = args.get(1).map(|a| a.to_vec()).unwrap_or_default();
+        self.c_name = args.get(2).map(|a| a.to_vec()).unwrap_or_default();
+        self.parsed = true;
+    }
+    fn finalize<'w, 'a, 'k, It: Iterator<Item = WriteResource<'w, 'a, 'k>>>(&mut self, mut it: It) -> bool where 'a: 'w, 'k: 'w {
+        if !self.parsed { return false; }
+        let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() };
+        let (a, b) = match (store.get(&self.a_name), store.get(&self.b_name)) {
+            (Some(a), Some(b)) => (a, b),
+            _ => { log::error!(target: "tensor", "tensor binop: missing input"); return false; }
+        };
+        let c = match self.op {
+            TensorBinop::Add => a.add(b),
+            TensorBinop::Mul => a.mul(b),
+        };
+        store.insert(self.c_name.clone(), c);
+        true
+    }
+}
+
+/// TensorFreeSink — removes named tensors.
+/// Syntax: (tensor_free A)
+///
+/// Names are buffered during `sink()` and the actual removals happen in
+/// `finalize()`.
We cannot free in `sink()` because a future source type +/// (tensor-as-source) could be iterating the tensor store while pattern +/// matches drive these calls — mutating the store mid-iteration would +/// invalidate the upstream zipper. +pub struct TensorFreeSink { + e: Expr, + names: Vec>, +} + +impl Sink for TensorFreeSink { + fn new(e: Expr) -> Self { + TensorFreeSink { e, names: Vec::new() } + } + fn request(&self) -> impl Iterator { + std::iter::once(WriteResourceRequest::TensorStore) + } + fn sink<'w, 'a, 'k, It: Iterator>>(&mut self, _it: It, path: &[u8]) where 'a: 'w, 'k: 'w { + // "tensor_free" is 11 chars → header = 13 + let args = parse_symbol_args(path, 13); + if let Some(name) = args.first() { + self.names.push(name.to_vec()); + } + } + fn finalize<'w, 'a, 'k, It: Iterator>>(&mut self, mut it: It) -> bool where 'a: 'w, 'k: 'w { + let WriteResource::TensorStore(store) = it.next().unwrap() else { unreachable!() }; + let mut any = false; + for name in self.names.drain(..) { + if store.remove(&name).is_some() { + any = true; + } + } + any + } +} + pub enum ASink { AddSink(AddSink), RemoveSink(RemoveSink), HeadSink(HeadSink), CountSink(CountSink), HashSink(HashSink), SumSink(SumSink), AndSink(AndSink), ACTSink(ACTSink), + TensorCollectSink(TensorCollectSink), + TensorEinsumSink(TensorEinsumSink), + TensorBinopSink(TensorBinopSink), + TensorFreeSink(TensorFreeSink), #[cfg(feature = "wasm")] WASMSink(WASMSink), #[cfg(feature = "grounding")] @@ -1211,6 +1496,14 @@ impl ASink { pub fn compat(e: Expr) -> Self { ASink::CompatSink(CompatSink::new(e)) } + /// Set an opaque context pointer on any PureSink's EvalScope. + /// This allows pure functions to access external state (e.g. tensor store). 
+ pub fn set_context(&mut self, ctx: *mut ()) { + #[cfg(feature = "grounding")] + if let ASink::PureSink(s) = self { + s.scope.context = ctx; + } + } } impl Sink for ASink { @@ -1271,6 +1564,35 @@ impl Sink for ASink { return ASink::Z3Sink(Z3Sink::new(e)); #[cfg(not(feature = "z3"))] panic!("MORK was not built with the z3 feature, yet trying to call {:?}", e); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(14)) && + *e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'c' && + *e.ptr.offset(10) == b'o' && *e.ptr.offset(11) == b'l' && *e.ptr.offset(12) == b'l' && *e.ptr.offset(13) == b'e' && + *e.ptr.offset(14) == b'c' && *e.ptr.offset(15) == b't' } { + return ASink::TensorCollectSink(TensorCollectSink::new(e)); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(13)) && + *e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'e' && + *e.ptr.offset(10) == b'i' && *e.ptr.offset(11) == b'n' && *e.ptr.offset(12) == b's' && *e.ptr.offset(13) == b'u' && + *e.ptr.offset(14) == b'm' } { + return ASink::TensorEinsumSink(TensorEinsumSink::new(e)); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(10)) && + *e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'a' && + *e.ptr.offset(10) == b'd' && *e.ptr.offset(11) == b'd' } { + // "tensor_add" = 10 chars, header = 12 + return ASink::TensorBinopSink(TensorBinopSink::new_with_op(e, TensorBinop::Add, 12)); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(10)) && + 
*e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'm' && + *e.ptr.offset(10) == b'u' && *e.ptr.offset(11) == b'l' } { + // "tensor_mul" = 10 chars, header = 12 + return ASink::TensorBinopSink(TensorBinopSink::new_with_op(e, TensorBinop::Mul, 12)); + } else if unsafe { *e.ptr.offset(1) == item_byte(Tag::SymbolSize(11)) && + *e.ptr.offset(2) == b't' && *e.ptr.offset(3) == b'e' && *e.ptr.offset(4) == b'n' && *e.ptr.offset(5) == b's' && + *e.ptr.offset(6) == b'o' && *e.ptr.offset(7) == b'r' && *e.ptr.offset(8) == b'_' && *e.ptr.offset(9) == b'f' && + *e.ptr.offset(10) == b'r' && *e.ptr.offset(11) == b'e' && *e.ptr.offset(12) == b'e' } { + return ASink::TensorFreeSink(TensorFreeSink::new(e)); } else { panic!("unrecognized sink") } @@ -1300,6 +1622,10 @@ impl Sink for ASink { ASink::FMinSink(s) => { for i in s.request().into_iter() { yield i } } ASink::FMaxSink(s) => { for i in s.request().into_iter() { yield i } } ASink::FProdSink(s) => { for i in s.request().into_iter() { yield i } } + ASink::TensorCollectSink(s) => { for i in s.request().into_iter() { yield i } } + ASink::TensorEinsumSink(s) => { for i in s.request().into_iter() { yield i } } + ASink::TensorBinopSink(s) => { for i in s.request().into_iter() { yield i } } + ASink::TensorFreeSink(s) => { for i in s.request().into_iter() { yield i } } } } } @@ -1326,6 +1652,10 @@ impl Sink for ASink { ASink::FMinSink(s) => { s.sink(it, path) } ASink::FMaxSink(s) => { s.sink(it, path) } ASink::FProdSink(s) => { s.sink(it, path) } + ASink::TensorCollectSink(s) => { s.sink(it, path) } + ASink::TensorEinsumSink(s) => { s.sink(it, path) } + ASink::TensorBinopSink(s) => { s.sink(it, path) } + ASink::TensorFreeSink(s) => { s.sink(it, path) } } } @@ -1352,6 +1682,10 @@ impl Sink for ASink { ASink::FMinSink(s) => { s.finalize(it) } ASink::FMaxSink(s) => { 
s.finalize(it) } ASink::FProdSink(s) => { s.finalize(it) } + ASink::TensorCollectSink(s) => { s.finalize(it) } + ASink::TensorEinsumSink(s) => { s.finalize(it) } + ASink::TensorBinopSink(s) => { s.finalize(it) } + ASink::TensorFreeSink(s) => { s.finalize(it) } } } } diff --git a/kernel/src/space.rs b/kernel/src/space.rs index f954638..4811a9d 100644 --- a/kernel/src/space.rs +++ b/kernel/src/space.rs @@ -40,6 +40,7 @@ pub struct Space { pub sm: SharedMappingHandle, pub mmaps: HashMap>, pub z3s: HashMap>, + pub tensors: HashMap, crate::sparse::SparseTensorF64>, pub last_merkleize: Instant, pub timing: bool } @@ -444,7 +445,7 @@ macro_rules! sexpr { impl Space { pub fn new() -> Self { - Self { btm: PathMap::new(), sm: SharedMapping::new(), mmaps: HashMap::new(), z3s: HashMap::new(), last_merkleize: Instant::now(), timing: false } + Self { btm: PathMap::new(), sm: SharedMapping::new(), mmaps: HashMap::new(), z3s: HashMap::new(), tensors: HashMap::new(), last_merkleize: Instant::now(), timing: false } } pub fn parse_sexpr(&mut self, r: &[u8], buf: *mut u8) -> Result<(Expr, usize), ParserError> { @@ -1078,6 +1079,7 @@ impl Space { unsafe fn write_handler<'w, 'a, 'k>(zh_wzs: (*mut ZipperHead<'w, 'a, ()>, *mut Vec>), mmaps: *mut HashMap>, z3s: *mut HashMap>, + tensors: *mut HashMap, crate::sparse::SparseTensorF64>, request: &WriteResourceRequest) -> WriteResource<'w, 'a, 'k> where 'w : 'a { match *request { WriteResourceRequest::BTM(p) => { @@ -1103,6 +1105,9 @@ impl Space { }).as_mut(); WriteResource::Z3(instance) } + WriteResourceRequest::TensorStore => { + WriteResource::TensorStore(tensors.as_mut().unwrap()) + } } } @@ -1488,6 +1493,7 @@ impl Space { #[cfg(feature="specialize_io")] pub fn transform_multi_multi_o(&mut self, pat_expr: Expr, tpl_expr: Expr, add: Expr) -> (usize, bool) { use crate::sinks::*; + let tensors_ptr = (&self.tensors as *const HashMap, crate::sparse::SparseTensorF64>).cast_mut(); let mut buffer = Vec::with_capacity(1 << 32); unsafe { 
buffer.set_len(1 << 32); } let mut tpl_args = Vec::with_capacity(64); @@ -1511,7 +1517,7 @@ impl Space { template_prefixes.iter().enumerate().for_each(|(i, request)| { if subsumption[i] == i { placements[i] = template_resources.len(); - template_resources.push(unsafe { Self::write_handler((zh_ptr, outstanding_wzs_ptr), acts_ptr, z3s_ptr, request) }); + template_resources.push(unsafe { Self::write_handler((zh_ptr, outstanding_wzs_ptr), acts_ptr, z3s_ptr, tensors_ptr, request) }); } }); for i in 0..subsumption.len() { @@ -1523,7 +1529,7 @@ impl Space { let mut assignments: Vec<(u8, u8)> = vec![]; let mut trace: Vec<(u8, u8)> = vec![]; - + let mut ass = Vec::with_capacity(64); let mut astack = Vec::with_capacity(64); @@ -1569,6 +1575,7 @@ impl Space { }); for (i, s) in sinks.iter_mut().enumerate() { + s.set_context(tensors_ptr.cast()); let wz = unsafe { std::ptr::read(&template_resources[subsumption[i]]) }; any_new |= s.finalize(std::iter::once(wz)); } @@ -1581,6 +1588,8 @@ impl Space { pub fn transform_multi_multi_io(&mut self, pat_expr: Expr, tpl_expr: Expr, add: Expr, no_source: bool, no_sink: bool) -> (usize, bool) { use crate::sinks::*; + // Set tensor store pointer for sinks/pure functions that need it + let tensors_ptr = (&self.tensors as *const HashMap, crate::sparse::SparseTensorF64>).cast_mut(); let mut buffer = Vec::with_capacity(1 << 32); unsafe { buffer.set_len(1 << 32); } let mut tpl_args = Vec::with_capacity(64); @@ -1604,7 +1613,7 @@ impl Space { template_prefixes.iter().enumerate().for_each(|(i, request)| { if subsumption[i] == i { placements[i] = template_resources.len(); - template_resources.push(unsafe { Self::write_handler((zh_ptr, outstanding_wzs_ptr), acts_ptr, z3s_ptr, request) }); + template_resources.push(unsafe { Self::write_handler((zh_ptr, outstanding_wzs_ptr), acts_ptr, z3s_ptr, tensors_ptr, request) }); } }); for i in 0..subsumption.len() { @@ -1663,6 +1672,7 @@ impl Space { }); for (i, s) in sinks.iter_mut().enumerate() { + 
s.set_context(tensors_ptr.cast()); let wz = unsafe { std::ptr::read(&template_resources[subsumption[i]]) }; any_new |= s.finalize(std::iter::once(wz)); } @@ -1672,7 +1682,7 @@ impl Space { (touched, any_new) } - + // (exec (, ) // (, )) pub fn interpret(&mut self, rt: Expr) -> Result<(), &'static str> { diff --git a/kernel/src/sparse.rs b/kernel/src/sparse.rs new file mode 100644 index 0000000..77768c3 --- /dev/null +++ b/kernel/src/sparse.rs @@ -0,0 +1,475 @@ +use std::collections::HashMap; +use std::hash::Hasher; +use num_traits::Zero; +use pathmap::PathMap; +use pathmap::ring::{AlgebraicResult, Lattice}; +use pathmap::utils::ints::indices_to_bob; + +// ============================================================================ +// FAddMulF64 — Lattice wrapper for f64 (join=add, meet=mul) +// ============================================================================ + +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct FAddMulF64(pub f64); + +impl std::ops::Deref for FAddMulF64 { + type Target = f64; + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl std::hash::Hash for FAddMulF64 { + fn hash(&self, state: &mut H) { + self.0.to_bits().hash(state); + } +} + +impl Lattice for FAddMulF64 { + fn pjoin(&self, other: &Self) -> AlgebraicResult { + if self.0.is_zero() { return AlgebraicResult::Identity(1) } + if other.0.is_zero() { return AlgebraicResult::Identity(2) } + let s = self.0 + other.0; + if self.0 * other.0 < 0f64 && s.abs() < 1e-15 { return AlgebraicResult::None } + AlgebraicResult::Element(FAddMulF64(s)) + } + + fn pmeet(&self, other: &Self) -> AlgebraicResult { + let s = self.0 * other.0; + if s.abs() < 1e-15 { return AlgebraicResult::None } + AlgebraicResult::Element(FAddMulF64(s)) + } +} + +// ============================================================================ +// SparseTensorF64 — PathMap-backed sparse tensor with BOB encoding +// ============================================================================ + +pub struct 
SparseTensorF64 { + pub m: PathMap, + pub d: usize, + pub dims: Vec, + p: Vec, +} + +impl SparseTensorF64 { + pub fn new(rank: usize) -> Self { + Self { m: PathMap::new(), d: rank, dims: vec![0; rank], p: Vec::new() } + } + + pub fn with_dims(dims: Vec) -> Self { + let d = dims.len(); + Self { m: PathMap::new(), d, dims, p: Vec::new() } + } + + pub fn set(&mut self, ix: &[usize], v: f64) { + for (i, &idx) in ix.iter().enumerate() { + if idx >= self.dims[i] { + self.dims[i] = idx + 1; + } + } + let path = Self::index_to_path_static(ix); + self.m.insert(&path[..], v); + } + + pub fn get(&self, ix: &[usize]) -> Option { + let path = Self::index_to_path_static(ix); + self.m.get(&path[..]).copied() + } + + pub fn remove(&mut self, ix: &[usize]) -> Option { + let path = Self::index_to_path_static(ix); + self.m.remove(&path[..]) + } + + pub fn nnz(&self) -> usize { + self.m.val_count() + } + + fn index_to_path_static(ix: &[usize]) -> Vec { + let mut p = Vec::new(); + let len = indices_to_bob(ix, &mut vec![]); + p.extend(std::iter::repeat_n(0u8, 64 - len)); + indices_to_bob(ix, &mut p); + p + } + + // Safety: FAddMulF64 has the same layout as f64 + fn vf(&self) -> &PathMap { + unsafe { (&self.m as *const PathMap as *const PathMap).as_ref().unwrap_unchecked() } + } + + fn from_vf(m: PathMap, d: usize, dims: Vec) -> Self { + unsafe { Self { m: std::mem::transmute::, PathMap>(m), d, dims, p: Vec::new() } } + } + + pub fn add(&self, other: &Self) -> Self { + let dims = self.dims.iter().zip(&other.dims).map(|(&a, &b)| a.max(b)).collect(); + Self::from_vf(self.vf().join(other.vf()), self.d, dims) + } + + pub fn mul(&self, other: &Self) -> Self { + let dims = self.dims.iter().zip(&other.dims).map(|(&a, &b)| a.max(b)).collect(); + Self::from_vf(self.vf().meet(other.vf()), self.d, dims) + } +} + +// ============================================================================ +// NDIndex implementation for einsum-dyn compatibility +// 
============================================================================ + +impl einsum_dyn::NDIndex for SparseTensorF64 { + fn ndim(&self) -> usize { self.d } + fn dim(&self, axis: usize) -> usize { self.dims[axis] } + + fn get(&self, indices: &[usize]) -> f64 { + self.get(indices).unwrap_or(0.0) + } + + fn set(&mut self, indices: &[usize], val: f64) { + if val.abs() < 1e-15 { + self.remove(indices); + } else { + self.set(indices, val); + } + } + + fn get_opt(&self, indices: &[usize]) -> Option { + self.get(indices) + } + + fn is_sparse_2d(&self) -> bool { false } +} + +// ============================================================================ +// Pure functions (access tensor store via ctx arg) +// ============================================================================ + +use eval_ffi::{ExprSource, ExprSink, EvalError}; +use mork_expr::SourceItem; +use eval::{EvalScope, FuncType}; + +/// Reinterpret the opaque context pointer as a tensor store reference. +/// The context is set by PureSink via ASink::set_context before eval. +unsafe fn tensor_store_from_context(ctx: *mut ()) -> Option<&'static HashMap, SparseTensorF64>> { + (ctx as *const HashMap, SparseTensorF64>).as_ref() +} + +/// (tensor_get name i0 i1 ... 
iN) -> f64 value at that index +pub extern "C" fn tensor_get(expr: *mut ExprSource, sink: *mut ExprSink, ctx: *mut ()) -> Result<(), EvalError> { + let expr = unsafe { &mut *expr }; + let sink = unsafe { &mut *sink }; + let items = expr.consume_head_check(b"tensor_get")?; + if items < 2 { return Err(EvalError::from("tensor_get requires name and indices")) } + + let SourceItem::Symbol(name) = expr.read() else { + return Err(EvalError::from("tensor_get: first arg must be tensor name")) + }; + let name = name.to_vec(); + + let mut indices: Vec = Vec::new(); + for _ in 0..(items - 1) { + let SourceItem::Symbol(idx_bytes) = expr.read() else { + return Err(EvalError::from("tensor_get: index must be a symbol")) + }; + let idx = std::str::from_utf8(idx_bytes) + .map_err(|_| EvalError::from("tensor_get: index not utf8"))? + .parse::() + .map_err(|_| EvalError::from("tensor_get: index not a number"))?; + indices.push(idx); + } + + let store = unsafe { tensor_store_from_context(ctx) }; + let val = store + .and_then(|s| s.get(&name)) + .and_then(|t| t.get(&indices)) + .unwrap_or(0.0); + + sink.write(SourceItem::Symbol(&val.to_be_bytes()[..]))?; + Ok(()) +} + +/// (tensor_nnz name) -> u64 count of non-zeros +pub extern "C" fn tensor_nnz(expr: *mut ExprSource, sink: *mut ExprSink, ctx: *mut ()) -> Result<(), EvalError> { + let expr = unsafe { &mut *expr }; + let sink = unsafe { &mut *sink }; + let items = expr.consume_head_check(b"tensor_nnz")?; + if items != 1 { return Err(EvalError::from("tensor_nnz takes one argument")) } + + let SourceItem::Symbol(name) = expr.read() else { + return Err(EvalError::from("tensor_nnz: arg must be tensor name")) + }; + let name = name.to_vec(); + + let store = unsafe { tensor_store_from_context(ctx) }; + let nnz = store + .and_then(|s| s.get(&name)) + .map(|t| t.nnz()) + .unwrap_or(0) as u64; + + sink.write(SourceItem::Symbol(&nnz.to_be_bytes()[..]))?; + Ok(()) +} + +pub fn register(scope: &mut EvalScope) { + scope.add_func("tensor_get", 
tensor_get, FuncType::Pure); + scope.add_func("tensor_nnz", tensor_nnz, FuncType::Pure); +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sparse_tensor_basic() { + let mut t = SparseTensorF64::new(2); + t.set(&[0, 1], 3.0); + t.set(&[1, 2], 5.0); + t.set(&[2, 0], 7.0); + + assert_eq!(t.get(&[0, 1]), Some(3.0)); + assert_eq!(t.get(&[1, 2]), Some(5.0)); + assert_eq!(t.get(&[2, 0]), Some(7.0)); + assert_eq!(t.get(&[0, 0]), None); + assert_eq!(t.nnz(), 3); + assert_eq!(t.dims, vec![3, 3]); + } + + #[test] + fn test_sparse_tensor_4d() { + let mut t = SparseTensorF64::new(4); + t.set(&[1, 2, 3, 4], 42.0); + t.set(&[0, 0, 0, 0], 1.0); + + assert_eq!(t.get(&[1, 2, 3, 4]), Some(42.0)); + assert_eq!(t.get(&[0, 0, 0, 0]), Some(1.0)); + assert_eq!(t.get(&[1, 1, 1, 1]), None); + assert_eq!(t.nnz(), 2); + } + + #[test] + fn test_sparse_tensor_add() { + let mut a = SparseTensorF64::new(2); + a.set(&[0, 0], 1.0); + a.set(&[0, 1], 2.0); + + let mut b = SparseTensorF64::new(2); + b.set(&[0, 0], 10.0); + b.set(&[1, 0], 20.0); + + let c = a.add(&b); + assert_eq!(c.get(&[0, 0]), Some(11.0)); + assert_eq!(c.get(&[0, 1]), Some(2.0)); + assert_eq!(c.get(&[1, 0]), Some(20.0)); + } + + #[test] + fn test_sparse_tensor_mul() { + let mut a = SparseTensorF64::new(2); + a.set(&[0, 0], 3.0); + a.set(&[0, 1], 5.0); + + let mut b = SparseTensorF64::new(2); + b.set(&[0, 0], 2.0); + b.set(&[1, 1], 4.0); + + let c = a.mul(&b); + assert_eq!(c.get(&[0, 0]), Some(6.0)); + assert_eq!(c.get(&[0, 1]), None); + assert_eq!(c.get(&[1, 1]), None); + } + + #[test] + fn test_ndindex_impl() { + use einsum_dyn::NDIndex; + let mut t = SparseTensorF64::with_dims(vec![3, 3]); + >::set(&mut t, &[0, 1], 5.0); + assert_eq!(>::get(&t, &[0, 1]), 5.0); + assert_eq!(>::get(&t, &[0, 0]), 0.0); + assert_eq!(t.get_opt(&[0, 1]), 
Some(5.0)); + assert_eq!(t.get_opt(&[0, 0]), None); + } + + #[test] + fn test_einsum_matmul() { + use einsum_dyn::NDIndex; + let mut a = SparseTensorF64::with_dims(vec![2, 2]); + a.set(&[0, 0], 1.0); a.set(&[0, 1], 2.0); + a.set(&[1, 0], 3.0); a.set(&[1, 1], 4.0); + + let mut b = SparseTensorF64::with_dims(vec![2, 2]); + b.set(&[0, 0], 5.0); b.set(&[0, 1], 6.0); + b.set(&[1, 0], 7.0); b.set(&[1, 1], 8.0); + + let mut c = SparseTensorF64::with_dims(vec![2, 2]); + + let inputs: Vec<&dyn einsum_dyn::NDIndex> = vec![&a, &b]; + let mut out: &mut dyn einsum_dyn::NDIndex = &mut c; + einsum_dyn::sparse::einsum_vm_oneshot_dyn("ab,bc->ac", &inputs, &mut [out]).unwrap(); + + assert_eq!(>::get(&c, &[0, 0]), 19.0); + assert_eq!(>::get(&c, &[0, 1]), 22.0); + assert_eq!(>::get(&c, &[1, 0]), 43.0); + assert_eq!(>::get(&c, &[1, 1]), 50.0); + } + + #[test] + fn test_end_to_end_collect_einsum() { + use crate::space::Space; + + let mut s = Space::new(); + + // Load matrix A = [[1, 2], [3, 4]] as (a row col val) triples + // Load matrix B = [[5, 6], [7, 8]] as (b row col val) triples + // Matrix A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]] + // Store as (a row col val) triples, collect into tensors, einsum, check result + s.add_all_sexpr(r#" + (a 0 0 1.0) + (a 0 1 2.0) + (a 1 0 3.0) + (a 1 1 4.0) + + (b 0 0 5.0) + (b 0 1 6.0) + (b 1 0 7.0) + (b 1 1 8.0) + + (exec P1 (, (a $r $c $v)) (O (tensor_collect A $r $c $v))) + (exec P2 (, (b $r $c $v)) (O (tensor_collect B $r $c $v))) + (exec P3 (,) (O (tensor_einsum "ab,bc->ac" A B C))) + "#.as_bytes()).unwrap(); + + s.metta_calculus(100); + + // Verify tensors were collected + assert!(s.tensors.contains_key(b"A".as_slice()), "tensor A should exist"); + assert!(s.tensors.contains_key(b"B".as_slice()), "tensor B should exist"); + assert!(s.tensors.contains_key(b"C".as_slice()), "tensor C should exist"); + + let a = s.tensors.get(b"A".as_slice()).unwrap(); + assert_eq!(a.nnz(), 4); + assert_eq!(a.get(&[0, 0]), Some(1.0)); + assert_eq!(a.get(&[1, 
1]), Some(4.0)); + + let b = s.tensors.get(b"B".as_slice()).unwrap(); + assert_eq!(b.nnz(), 4); + + // C = A * B = [[19, 22], [43, 50]] + let c = s.tensors.get(b"C".as_slice()).unwrap(); + assert_eq!(c.get(&[0, 0]), Some(19.0)); + assert_eq!(c.get(&[0, 1]), Some(22.0)); + assert_eq!(c.get(&[1, 0]), Some(43.0)); + assert_eq!(c.get(&[1, 1]), Some(50.0)); + + // Phase 2: test tensor_get directly through EvalScope + { + use eval_ffi::ExprSource; + let mut scope = eval::EvalScope::new(); + crate::sparse::register(&mut scope); + scope.context = (&s.tensors as *const HashMap, SparseTensorF64>).cast_mut().cast(); + + // Build expression: (tensor_get C 0 0) + let expr_bytes = mork_expr::construct!("tensor_get" "C" "0" "0").unwrap(); + let result = scope.eval(ExprSource::new(expr_bytes.as_ptr())).unwrap(); + // Result is SymbolSize(8) + 8 bytes of f64 + assert_eq!(result.len(), 9, "tensor_get should return 9 bytes (tag + f64)"); + let val = f64::from_be_bytes(result[1..9].try_into().unwrap()); + assert!((val - 19.0).abs() < 1e-10, "tensor_get C 0 0 = {} expected 19.0", val); + + let expr_bytes = mork_expr::construct!("tensor_get" "C" "1" "1").unwrap(); + let result = scope.eval(ExprSource::new(expr_bytes.as_ptr())).unwrap(); + let val = f64::from_be_bytes(result[1..9].try_into().unwrap()); + assert!((val - 50.0).abs() < 1e-10, "tensor_get C 1 1 = {} expected 50.0", val); + } + } + + #[test] + fn test_tensor_add_sink() { + use crate::space::Space; + + let mut s = Space::new(); + + let mut a = SparseTensorF64::new(2); + a.set(&[0, 0], 1.0); + a.set(&[0, 1], 2.0); + s.tensors.insert(b"A".to_vec(), a); + + let mut b = SparseTensorF64::new(2); + b.set(&[0, 0], 10.0); + b.set(&[1, 0], 20.0); + s.tensors.insert(b"B".to_vec(), b); + + s.add_all_sexpr(r#" + (exec F (,) (O (tensor_add A B C))) + "#.as_bytes()).unwrap(); + + s.metta_calculus(100); + + let c = s.tensors.get(b"C".as_slice()).expect("C should exist"); + assert_eq!(c.get(&[0, 0]), Some(11.0)); + assert_eq!(c.get(&[0, 1]), 
Some(2.0));
+        assert_eq!(c.get(&[1, 0]), Some(20.0));
+    }
+
+    #[test]
+    fn test_tensor_mul_sink() {
+        use crate::space::Space;
+
+        let mut s = Space::new();
+
+        let mut a = SparseTensorF64::new(2);
+        a.set(&[0, 0], 3.0);
+        a.set(&[0, 1], 5.0);
+        s.tensors.insert(b"A".to_vec(), a);
+
+        let mut b = SparseTensorF64::new(2);
+        b.set(&[0, 0], 2.0);
+        b.set(&[1, 1], 4.0);
+        s.tensors.insert(b"B".to_vec(), b);
+
+        s.add_all_sexpr(r#"
+        (exec F (,) (O (tensor_mul A B C)))
+        "#.as_bytes()).unwrap();
+
+        s.metta_calculus(100);
+
+        let c = s.tensors.get(b"C".as_slice()).expect("C should exist");
+        assert_eq!(c.get(&[0, 0]), Some(6.0));
+        assert_eq!(c.get(&[0, 1]), None);
+        assert_eq!(c.get(&[1, 1]), None);
+    }
+
+    #[test]
+    fn test_tensor_free_multi_match() {
+        use crate::space::Space;
+
+        let mut s = Space::new();
+
+        for name in [b"A".as_slice(), b"B", b"C", b"D"] {
+            let mut t = SparseTensorF64::new(1);
+            t.set(&[0], 1.0);
+            s.tensors.insert(name.to_vec(), t);
+        }
+
+        s.add_all_sexpr(r#"
+        (to_free A)
+        (to_free B)
+        (to_free C)
+
+        (exec F (, (to_free $name)) (O (tensor_free $name)))
+        "#.as_bytes()).unwrap();
+
+        s.metta_calculus(100);
+
+        assert!(!s.tensors.contains_key(b"A".as_slice()), "A should be freed");
+        assert!(!s.tensors.contains_key(b"B".as_slice()), "B should be freed");
+        assert!(!s.tensors.contains_key(b"C".as_slice()), "C should be freed");
+        assert!(s.tensors.contains_key(b"D".as_slice()), "D should survive");
+    }
+}
+/*
+KNOWN ISSUE: tensor_add / tensor_mul fail when driven by a source-less exec
+such as `(exec F (,) (O (tensor_add A B C)))`. TensorBinopSink only learns its
+operand names inside sink(); the working tensor_einsum test with the same `(,)`
+pattern relies solely on new()/finalize(), which suggests sink() is never
+invoked when the source pattern yields no bindings — so finalize() returns
+early with `parsed == false` and C is never created. Likely fix: parse the
+names eagerly in new_with_op() from e.span(), as TensorEinsumSink::new already
+does. (Diagnosis inferred from code structure — confirm against the
+transform_multi_multi_* sink dispatch before relying on it.)
+*/
\ No newline at end of file