diff --git a/cuthbert/factorial/README.md b/cuthbert/factorial/README.md new file mode 100644 index 0000000..8a6698c --- /dev/null +++ b/cuthbert/factorial/README.md @@ -0,0 +1,133 @@ +# Factorial State-Space Models + +A factorial state-space model is a state-space model where the dynamics distribution +factors into a product of independent distributions across factors + +$$ +p(x_t \mid x_{t-1}) = \prod_{f=1}^F p(x_t^f \mid x_{t-1}^f), +$$ + +for factorial index $f \in \{1, \ldots, F\}$. We additionally assume that observations +act locally on some subset of factors $S_t \subseteq \{1, \ldots, F\}$. + +$$ +p(y_t \mid x_t) = p(y_t \mid x_t^{S_t}). +$$ + +This motivates a factored approximation of filtering and smoothing distributions, e.g. + +$$ +p(x_t \mid y_{1:t}) = \prod_{f=1}^F p(x_t^f \mid y_{1:t}). +$$ + +A tutorial on factorial state-space models can be found in [Duffield et al](https://doi.org/10.1093/jrsssc/qlae035). + +The factorial approximation allows us to exploit significant benefits in terms of +memory, compute and parallelization. + +Note that although the dynamics are factorized, `cuthbert` does not differentiate +between `predict` and `update` (instead favouring a unified filter operation +via `filter_prepare` and `filter_combine`). Thus the dynamics and model inputs +should be specified to act on the joint local state (i.e. block diagonal +where appropriate). + + +## Factorial filtering with `cuthbert` + +Filtering in a factorial state-space model is similar to standard filtering, but with +an additional step before the filtering operation to extract the relevant +factors as well as an additional step after the filtering operation to insert the +updated factors back into the factorial state. + + +```python +from jax import tree +import cuthbert + +# Define model_inputs +model_inputs = ... 
+ +# Define function to extract the factorial indices from model inputs +# Here we assume model_inputs is a NamedTuple with a field `factorial_inds` +get_factorial_indices = lambda mi: mi.factorial_inds + +# Build factorializer for the inference method +factorializer = cuthbert.factorial.gaussian.build_factorializer(get_factorial_indices) + +# Load inference method, with parameter extraction functions defined for factorial inference +kalman_filter = cuthbert.gaussian.kalman.build_filter( + get_init_params=get_init_params, # Init specified to generate factorial state + get_dynamics_params=get_dynamics_params, # Dynamics specified to act on joint local state + get_observation_params=get_observation_params, # Observation specified to act on joint local state +) + +# Online inference +factorial_state = kalman_filter.init_prepare(tree.map(lambda x: x[0], model_inputs)) + +for t in range(1, T): + model_inputs_t = tree.map(lambda x: x[t], model_inputs) + local_state = factorializer.extract_and_join(factorial_state, model_inputs_t) + prepare_state = kalman_filter.filter_prepare(model_inputs_t) + local_joint_filtered_state = kalman_filter.filter_combine(local_state, prepare_state) + factorial_state = factorializer.marginalize_and_insert( + local_joint_filtered_state, factorial_state, model_inputs_t + ) +``` + +You can also use `cuthbert.factorial.filter` for convenient offline filtering. +Note that associative/parallel filtering is not supported for factorial filtering. + +```python +init_factorial_state, local_filter_states = cuthbert.factorial.filter( + kalman_filter, factorializer, model_inputs, output_factorial=False +) +``` + +## Factorial smoothing with `cuthbert` + +Smoothing in factorial state-space models can be performed in an embarrassingly parallel fashion +across factors since the dynamics and factorial approximation are independent +across factors (the observations are fully absorbed in the filtering and +are not accessed during smoothing).
+ +The model inputs and filter states require some preprocessing to convert from being +single sequence with each state containing all factors into a sequence or multiple +sequences with each state corresponding to a single factor. This can be +fiddly but is left to the user for maximum freedom. Oftentimes, it is easiest to +specify different parameter functions for smoothing than filtering. + +After this preprocessing, smoothing can be performed as usual: + +```python +# Define model_inputs for a single factor +model_inputs_single_factor = ... + +# Similarly, we need to extract the filter states for the single factor we're smoothing. +filter_states_single_factor = ... + +# Load smoother, with parameter extraction functions defined for a single factor +kalman_smoother = cuthbert.gaussian.kalman.build_smoother( + get_dynamics_params=get_dynamics_params, # Dynamics specified to act on a single factor +) + +smoother_state = kalman_smoother.convert_filter_to_smoother_state( + tree.map(lambda x: x[-1], filter_states_single_factor), + model_inputs=tree.map(lambda x: x[-1], model_inputs_single_factor), +) + +for t in range(T - 1, -1, -1): + model_inputs_single_factor_t = tree.map(lambda x: x[t], model_inputs_single_factor) + filter_state_single_factor_t = tree.map(lambda x: x[t], filter_states_single_factor) + prepare_state = kalman_smoother.smoother_prepare( + filter_state_single_factor_t, model_inputs_single_factor_t + ) + smoother_state = kalman_smoother.smoother_combine(prepare_state, smoother_state) +``` + +Or directly using the `cuthbert.smoother`: + +```python +smoother_states = cuthbert.smoother( + kalman_smoother, filter_states_single_factor, model_inputs_single_factor +) +``` \ No newline at end of file diff --git a/cuthbert/factorial/__init__.py b/cuthbert/factorial/__init__.py new file mode 100644 index 0000000..ffbcef4 --- /dev/null +++ b/cuthbert/factorial/__init__.py @@ -0,0 +1,11 @@ +from cuthbert.factorial import gaussian +from 
cuthbert.factorial.filtering import filter +from cuthbert.factorial.types import ( + Extract, + Factorializer, + GetFactorialIndices, + Insert, + Join, + Marginalize, +) +from cuthbert.factorial.utils import serial_to_factorial, serial_to_single_factor diff --git a/cuthbert/factorial/filtering.py b/cuthbert/factorial/filtering.py new file mode 100644 index 0000000..1d0804f --- /dev/null +++ b/cuthbert/factorial/filtering.py @@ -0,0 +1,110 @@ +"""cuthbert factorial filtering interface.""" + +from jax import numpy as jnp +from jax import random, tree +from jax.lax import scan + +from cuthbert.factorial.types import Factorializer +from cuthbert.inference import Filter +from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray + + +def filter( + filter_obj: Filter, + factorializer: Factorializer, + model_inputs: ArrayTreeLike, + output_factorial: bool = False, + key: KeyArray | None = None, +) -> ( + ArrayTree | tuple[ArrayTree, ArrayTree] +): # TODO: Can overload this function so the type checker knows that the output is a ArrayTree if output_factorial is True and a tuple[ArrayTree, ArrayTree] if output_factorial is False + """Applies offline factorial filtering for given model inputs. + + `model_inputs` should have leading temporal dimension of length T + 1, + where T is the number of time steps excluding the initial state. + + Parallel associative filtering is not supported for factorial filtering. + + Note that if output_factorial is True, this function will output a factorial state + with first temporal dimension of length T + 1 and second factorial dimension of + length F. Many of the factors will be unchanged across timesteps where they aren't + relevant. + + Args: + filter_obj: The filter inference object. + factorializer: The factorializer object for the inference method. + model_inputs: The model inputs (with leading temporal dimension of length T + 1). 
+ output_factorial: If True, return a single state with first temporal dimension + of length T + 1 and second factorial dimension of length F. + If False, return a tuple of states. The first being the initial state + with first dimension of length F and no temporal dimension. + The second being the local states for each time step, i.e. first + dimension of length T and local factorial dimension of length len(factorial_inds). + key: The key for the random number generator. + + Returns: + The filtered states: a single factorial state with leading temporal dimension of length T + 1 if output_factorial is True, otherwise a tuple of (initial factorial state, per-step local states). + """ + T = tree.leaves(model_inputs)[0].shape[0] - 1 + + if key is None: + # This will throw error if used as a key, which is desired behavior + # (albeit not a useful error, we could improve this) + prepare_keys = jnp.empty(T + 1) + else: + prepare_keys = random.split(key, T + 1) + + init_model_input = tree.map(lambda x: x[0], model_inputs) + init_factorial_state = filter_obj.init_prepare( + init_model_input, key=prepare_keys[0] + ) + + prep_model_inputs = tree.map(lambda x: x[1:], model_inputs) + + def body_local(prev_factorial_state, prep_inp_and_k): + prep_inp, k = prep_inp_and_k + factorial_inds = factorializer.get_factorial_indices(prep_inp) + factorial_inds = jnp.asarray(factorial_inds) + + # Extract and join local factors into joint local state + local_state = factorializer.extract_and_join(prev_factorial_state, prep_inp) + + # Filter the joint local state + prep_state = filter_obj.filter_prepare(prep_inp, key=k) + filtered_joint_state = filter_obj.filter_combine(local_state, prep_state) + + # Marginalize and insert filtered joint local state into factorial state + local_factorial_filtered_state = factorializer.marginalize( + filtered_joint_state, len(factorial_inds) + ) + factorial_state = factorializer.insert( + local_factorial_filtered_state, prev_factorial_state, factorial_inds + ) + return factorial_state, local_factorial_filtered_state + + if output_factorial: + + def body_factorial(prev_factorial_state, prep_inp_and_k):
factorial_state, _ = body_local(prev_factorial_state, prep_inp_and_k) + return factorial_state, factorial_state + + _, factorial_states = scan( + body_factorial, + init_factorial_state, + (prep_model_inputs, prepare_keys[1:]), + ) + factorial_states = tree.map( + lambda x, y: jnp.concatenate([x[None], y]), + init_factorial_state, + factorial_states, + ) + + return factorial_states + + else: + _, local_states = scan( + body_local, + init_factorial_state, + (prep_model_inputs, prepare_keys[1:]), + ) + return init_factorial_state, local_states diff --git a/cuthbert/factorial/gaussian.py b/cuthbert/factorial/gaussian.py new file mode 100644 index 0000000..f8c5f17 --- /dev/null +++ b/cuthbert/factorial/gaussian.py @@ -0,0 +1,218 @@ +"""Factorial utilities for Kalman states.""" + +from typing import TypeVar + +from jax import numpy as jnp +from jax import tree +from jax.scipy.linalg import block_diag + +from cuthbert.factorial.types import Factorializer, GetFactorialIndices +from cuthbert.gaussian.kalman import KalmanFilterState +from cuthbert.gaussian.types import LinearizedKalmanFilterState +from cuthbertlib.linalg import block_marginal_sqrt_cov +from cuthbertlib.types import Array, ArrayLike + +KalmanState = TypeVar("KalmanState", KalmanFilterState, LinearizedKalmanFilterState) + + +## TODO: If factorial_inds is just an integer, i.e. shape (,) then factorial dimension +# should be removed in extract + + +def build_factorializer( + get_factorial_indices: GetFactorialIndices, +) -> Factorializer: + """Build a factorializer for Kalman states. + + Args: + get_factorial_indices: Function to extract the factorial indices + from model inputs. + + Returns: + Factorializer object for Kalman states with functions to extract and join + the relevant factors and marginalize and insert the updated factors. 
+ """ + return Factorializer( + get_factorial_indices=get_factorial_indices, + extract=extract, + join=join, + marginalize=marginalize, + insert=insert, + ) + + +def extract(factorial_state: KalmanState, factorial_inds: ArrayLike) -> KalmanState: + """Extract the relevant factors from a factorial Kalman state. + + Single dimensional arrays will be treated as scalars e.g. log normalizing constants. + This means univariate problems still need to be stored with a dimension array + (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). + Multidimensional arrays will be treated as arrays with shape (F, *). + In this case the factorial_inds indices will be extracted from the first + dimension and then the remaining dimensions will be preserved. + + Here F is the number of factors and d is the dimension of the state. + + Args: + factorial_state: Factorial Kalman state storing means and chol_covs + with shape (F, d) and (F, d, d) respectively. + factorial_inds: Indices of the factors to extract. Integer array. + + Returns: + Factorial Kalman state storing means and chol_covs + with shape (len(factorial_inds), d) and (len(factorial_inds), d, d). + """ + factorial_inds = jnp.asarray(factorial_inds) + new_elem = tree.map(lambda x: _extract_arr(x, factorial_inds), factorial_state.elem) + new_state = factorial_state._replace(elem=new_elem) + + if isinstance(factorial_state, LinearizedKalmanFilterState): + new_mean_prev = _extract_arr(factorial_state.mean_prev, factorial_inds) + new_state = new_state._replace(mean_prev=new_mean_prev) + + return new_state + + +def _extract_arr(arr: Array, factorial_inds: Array) -> Array: + if arr.ndim == 0 or arr.ndim == 1: + return arr + else: + return arr[factorial_inds] + + +def join(local_factorial_state: KalmanState) -> KalmanState: + """Convert a factorial Kalman state into a joint local Kalman state. + + Single dimensional arrays will be treated as scalars e.g. log normalizing constants. 
+ This means univariate problems still need to be stored with a dimension array + (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). + Two dimensional arrays will be treated as means with shape (F, d). + In this case the factors along the first dimension will be + flattened into a single array. + Three dimensional arrays will be treated as chol_covs with shape (F, d, d). + In this case the factors along the first dimension will be + combined into a block diagonal array. + + Here F is the number of factors and d is the dimension of the state. + + Args: + local_factorial_state: Factorial Kalman state storing means and chol_covs + with shape (F, d) and (F, d, d) respectively. + Typically the output of an `extract` call. + + Returns: + Joint local Kalman state with no factorial index dimension. + """ + new_elem = tree.map(_join_arr, local_factorial_state.elem) + new_state = local_factorial_state._replace(elem=new_elem) + + if isinstance(local_factorial_state, LinearizedKalmanFilterState): + new_mean_prev = _join_arr(local_factorial_state.mean_prev) + new_state = new_state._replace(mean_prev=new_mean_prev) + + return new_state + + +def _join_arr(arr: Array) -> Array: + if arr.ndim == 0 or arr.ndim == 1: + return arr + elif arr.ndim == 2: # means + return arr.reshape(-1) + elif arr.ndim == 3: # chol_covs + return block_diag(*arr) + else: + raise ValueError(f"Array must be 3D or lower, got {arr.ndim}D") + + +def marginalize( + local_state: KalmanState, + num_factors: int, +) -> KalmanState: + """Marginalize a joint local Kalman state into a factorial Kalman state. + + Args: + local_state: Joint local Kalman state to marginalize. + With means and chol_covs with shape (d * len(factorial_inds),) + and (d * len(factorial_inds), d * len(factorial_inds)) respectively. + num_factors: Number of factors to split the joint state into. Integer.
+ + Returns: + Factorial Kalman state with factorial index as the first dimension (length num_factors). + """ + new_elem = tree.map( + lambda loc: _marginalize_arr(loc, num_factors), + local_state.elem, + ) + new_state = local_state._replace(elem=new_elem) + if isinstance(local_state, LinearizedKalmanFilterState): + new_mean_prev = _marginalize_arr(local_state.mean_prev, num_factors) + new_state = new_state._replace(mean_prev=new_mean_prev) + + return new_state + + +def _marginalize_arr(arr: Array, num_factors: int) -> Array: + if arr.ndim == 0: + return arr + elif arr.ndim == 1: # means + return arr.reshape(num_factors, -1) + elif arr.ndim == 2: # chol_covs + local_dim = arr.shape[-1] // num_factors + return block_marginal_sqrt_cov(arr, local_dim) + else: + raise ValueError(f"Array must be 1D (means) or 2D (chol_covs), got {arr.ndim}D") + + +def insert( + local_factorial_state: KalmanState, + factorial_state: KalmanState, + factorial_inds: ArrayLike, +) -> KalmanState: + """Insert a local factorial Kalman state into a factorial Kalman state. + + Single dimensional arrays will be treated as scalars e.g. log normalizing constants. + This means univariate problems still need to be stored with a dimension array + (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). + Multidimensional arrays will be treated as arrays with shape (F, *). + In this case the factorial_inds indices will be inserted into the first + dimension and then the remaining dimensions will be preserved. + + Here F is the number of factors and d is the dimension of the state. + + Args: + local_factorial_state: Local factorial Kalman state to insert. + With means and chol_covs with shape (len(factorial_inds), d) + and (len(factorial_inds), d, d) respectively. + factorial_state: Factorial Kalman state storing means and chol_covs + with shape (F, d) and (F, d, d) respectively. + factorial_inds: Indices of the factors to insert. Integer array.
+ + Returns: + Factorial Kalman state with factorial index as the first dimension; the factors at factorial_inds are overwritten and the remaining factors are left unchanged. + """ + factorial_inds = jnp.asarray(factorial_inds) + new_elem = tree.map( + lambda loc, glob: _insert_arr(loc, glob, factorial_inds), + local_factorial_state.elem, + factorial_state.elem, + ) + new_state = factorial_state._replace(elem=new_elem) + + if isinstance(local_factorial_state, LinearizedKalmanFilterState) and isinstance( + factorial_state, LinearizedKalmanFilterState + ): + new_mean_prev = _insert_arr( + local_factorial_state.mean_prev, factorial_state.mean_prev, factorial_inds + ) + new_state = new_state._replace(mean_prev=new_mean_prev) + + return new_state + + +def _insert_arr( + local_factorial_arr: Array, factorial_arr: Array, factorial_inds: Array +) -> Array: + if local_factorial_arr.ndim == 0 or local_factorial_arr.ndim == 1: + return local_factorial_arr + else: + return factorial_arr.at[factorial_inds].set(local_factorial_arr) diff --git a/cuthbert/factorial/types.py b/cuthbert/factorial/types.py new file mode 100644 index 0000000..4889d20 --- /dev/null +++ b/cuthbert/factorial/types.py @@ -0,0 +1,192 @@ +"""Provides types for factorial state-space models.""" + +from typing import NamedTuple, Protocol + +from jax import numpy as jnp + +from cuthbertlib.types import ArrayLike, ArrayTree, ArrayTreeLike + + +class GetFactorialIndices(Protocol): + """Protocol for getting the factorial indices.""" + + def __call__(self, model_inputs: ArrayTreeLike) -> ArrayLike: + """Extract the factorial indices from model inputs. + + Args: + model_inputs: Model inputs. + + Returns: + Indices of the factors to extract. Integer array. + """ + ... + + +class Extract(Protocol): + """Protocol for extracting the relevant factors.""" + + def __call__( + self, + factorial_state: ArrayTreeLike, + factorial_inds: ArrayLike, + ) -> ArrayTree: + """Extract factors from factorial state. + + E.g. factorial_state might encode factorial `means` with shape (F, d) and + `chol_covs` with shape (F, d, d).
Then `model_inputs` tells us factors `i` and + `j` are relevant, so we extract `means[i]` and `means[j]` and `chol_covs[i]` and + `chol_covs[j]`. Thus we return `means` with shape (2, d) and `chol_covs` with + shape (2, d, d). + + Args: + factorial_state: Factorial state with factorial index as the first dimension. + factorial_inds: Indices of the factors to extract. Integer array. + + Returns: + Local factorial state with factorial dimension of length len(factorial_inds). + """ + ... + + +class Join(Protocol): + """Protocol for combining factorial states into a joint state.""" + + def __call__( + self, + local_factorial_state: ArrayTreeLike, + ) -> ArrayTree: + """Extract factors from factorial state and combine into a joint local state. + + E.g. local_factorial_state might encode factorial `means` with shape (2, d) and + `chol_covs` with shape (2, d, d). + Which is then combined into a joint state with shape (2 * d,) + and block diagonal `joint_chol_cov` with shape (2 * d, 2 * d). + + Args: + local_factorial_state: Factorial state with factorial index as the first + dimension. Typically contains only a small number of factors, as it's + applied after an `Extract` operation. + + Returns: + Joint state with no factorial index dimension. + """ + ... + + +class Marginalize(Protocol): + """Protocol for marginalizing a joint state into a factored state.""" + + def __call__( + self, + local_state: ArrayTree, + num_factors: int, + ) -> ArrayTree: + """Marginalize joint state into factored state. + + E.g. `local_state` might have shape (2 * d,) and `joint_chol_cov` + with shape (2 * d, 2 * d). Then we marginalize out the joint local state into + two factorial `means` with shape (2, d) and `chol_covs` with shape (2, d, d). + + Args: + local_state: Joint local state with no factorial index dimension. + num_factors: Number of factors to marginalize out. Integer. + This is typically equal to len(factorial_inds). 
+ + Returns: + Factorial state with factorial index as the first dimension and + `num_factors` factors (length of first dimension). + """ + ... + + +class Insert(Protocol): + """Protocol for inserting a local factorial state into a factorial state.""" + + def __call__( + self, + local_factorial_state: ArrayTree, + factorial_state: ArrayTree, + factorial_inds: ArrayLike, + ) -> ArrayTree: + """Marginalize joint state into factored state and insert into factorial state. + + E.g. `local_factorial_state` might have shape (2, d) and `joint_chol_cov` + with shape (2, d, d). Then we insert `means[0]` and `means[1]` into + `state[i]` and `state[j]` respectively. Similarly, we insert `chol_covs[0]` and + `chol_covs[1]`. In both cases, we overwrite the existing factors in the + factorial state for `i` and `j`, leaving the other factors unchanged. + Here `i` and `j` are determined from `factorial_inds`. + + Args: + local_factorial_state: Local factorial state with factorial index as the first + dimension and `len(factorial_inds)` factors (length of first dimension). + factorial_state: Factorial state with factorial index as the first dimension. + factorial_inds: Indices of the factors to insert. Integer array. + + Returns: + Factorial state with factorial index as the first dimension. + The updated factors are inserted into the factorial state. + The remaining factors are left unchanged. + """ + ... + + +class Factorializer(NamedTuple): + """Factorializer object. + + All functions are inference method dependent (e.g. Gaussian/SMC etc), + aside from the `get_factorial_indices` function which acts purely on `model_inputs`. + + Attributes: + get_factorial_indices: Function to extract factorial indices from model inputs. + extract: Function to extract the relevant factors. + join: Function to combine factorial states into a joint state. + marginalize: Function to marginalize a joint state into a factored state. 
+ insert: Function to insert a local factorial state into a factorial state. + """ + + get_factorial_indices: GetFactorialIndices + extract: Extract + join: Join + marginalize: Marginalize + insert: Insert + + def extract_and_join( + self, factorial_state: ArrayTreeLike, model_inputs: ArrayTreeLike + ) -> ArrayTree: + """Extract and join the relevant factors into a joint local state. + + Args: + factorial_state: Factorial state with factorial index as the first dimension. + model_inputs: Model inputs, from which the factorial indices are extracted. + + Returns: + Joint local state with no factorial index dimension. + """ + factorial_inds = self.get_factorial_indices(model_inputs) + local_factorial_state = self.extract(factorial_state, factorial_inds) + return self.join(local_factorial_state) + + def marginalize_and_insert( + self, + local_state: ArrayTree, + factorial_state: ArrayTree, + model_inputs: ArrayTreeLike, + ) -> ArrayTree: + """Marginalize and insert the relevant factors into a factorial state. + + Args: + local_state: Joint local state with no factorial index dimension. + factorial_state: Factorial state with factorial index as the first dimension. + model_inputs: Model inputs, from which the factorial indices are extracted. + + Returns: + Factorial state with factorial index as the first dimension. + The updated factors are inserted into the factorial state. + The remaining factors are left unchanged. 
+ """ + factorial_inds = self.get_factorial_indices(model_inputs) + factorial_inds = jnp.asarray(factorial_inds) + num_factors = len(factorial_inds) + local_factorial_state = self.marginalize(local_state, num_factors) + return self.insert(local_factorial_state, factorial_state, factorial_inds) diff --git a/cuthbert/factorial/utils.py b/cuthbert/factorial/utils.py new file mode 100644 index 0000000..353da0b --- /dev/null +++ b/cuthbert/factorial/utils.py @@ -0,0 +1,133 @@ +"""Utility functions to convert between serial and factorial trees.""" + +from jax import numpy as jnp +from jax import tree, vmap + +from cuthbert.factorial.types import Extract +from cuthbertlib.types import ArrayLike, ArrayTree, ArrayTreeLike + +### TODO: Add support for an init factorial state + + +def serial_to_factorial( + extract: Extract, + serial_tree: ArrayTreeLike, + factorial_inds: ArrayLike, + init_factorial_tree: ArrayTree = None, +) -> list[ArrayTree]: + """Convert a serial tree into a list of trees, one for each factor. + + Args: + extract: Function to extract the relevant factors from the serial tree. + serial_tree: The serial tree to convert. + Each leaf of the tree should have shape (T, F, ...) where T is the number of + time steps and F is the number of factors. + Although some leaves may not have the factorial dimension F, as controlled + by the `extract` function. + factorial_inds: The indices of the factors used in each element of the serial + tree. Shape (T, F). + init_factorial_tree: Optional initial factorial tree to use, as the first + elements of the returned list. + Leaves with shape (F, ...) + + Returns: + A list of trees, one for each factor. + Length max(factorial_inds) + 1. + Each element has shape (T_i, ...) where T_i is the number of occurrences of + index i in factorial_inds (which may be zero). + """ + # TODO: This function is not very JAX-like or efficient, we may want to improve it in time. 
+ + factorial_inds = jnp.asarray(factorial_inds) + num_factors = jnp.max(factorial_inds) + 1 + T = tree.leaves(serial_tree)[0].shape[0] + + if init_factorial_tree is None: + # Initialize factorial trees with empty tree of correct shape (for later concat) + # This can probably be improved + init_state = tree.map(lambda x: x[0], serial_tree) + init_single_factor_state = extract(init_state, jnp.array([0])) + factorial_trees = [ + tree.map(lambda x: jnp.zeros((0,) + x.shape[1:]), init_single_factor_state) + for _ in range(num_factors) + ] + else: + factorial_trees = [ + extract(init_factorial_tree, jnp.array([i])) for i in range(num_factors) + ] + # Add temporal dimension to init factorial trees + factorial_trees = [tree.map(lambda x: x[None], tr) for tr in factorial_trees] + + for t in range(T): + joint_factor_t = tree.map(lambda x: x[t], serial_tree) + local_factors_t = vmap(extract, in_axes=(None, 0))( + joint_factor_t, jnp.arange(len(factorial_inds[t])) + ) + + for j, ind in enumerate(factorial_inds[t]): + factorial_trees[ind] = tree.map( + lambda x, y: jnp.concatenate([x, y[j][None]]), + factorial_trees[ind], + local_factors_t, + ) + + return factorial_trees + + +def serial_to_single_factor( + extract: Extract, + serial_tree: ArrayTreeLike, + factorial_inds: ArrayLike, + factorial_index: int, + init_factorial_tree: ArrayTree = None, +) -> ArrayTree: + """Convert a serial tree into a single factor tree. + + Args: + extract: Function to extract the relevant factors from the serial tree. + serial_tree: The serial tree to convert. + Each leaf of the tree should have shape (T, F, ...) where T is the number of + time steps and F is the number of factors. + factorial_inds: The indices of the factors used in each element of the serial + tree. Shape (T, F). + factorial_index: Single integer index of the factor to extract. + init_factorial_tree: Optional initial factorial tree to use, as the first + elements of the returned list. + Leaves with shape (F, ...) 
of which only the factorial_index element will be + used. + + Returns: + A single ArrayTree with shape (T_i, ...) where T_i is the number of occurrences of + the factorial index in factorial_inds. + """ + # TODO: As above, we can improve this and make it more JAX-like + efficient. + factorial_inds = jnp.asarray(factorial_inds) + T = tree.leaves(serial_tree)[0].shape[0] + + if init_factorial_tree is None: + # Initialize factorial tree with empty tree of correct shape (for later concat) + # This can probably be improved + init_state = tree.map(lambda x: x[0], serial_tree) + init_single_factor_state = extract(init_state, jnp.array([0])) + factorial_tree = tree.map( + lambda x: jnp.zeros((0,) + x.shape[1:]), init_single_factor_state + ) + else: + factorial_tree = extract(init_factorial_tree, jnp.array([factorial_index])) + factorial_tree = tree.map(lambda x: x[None], factorial_tree) + + for t in range(T): + joint_factor_t = tree.map(lambda x: x[t], serial_tree) + local_factors_t = vmap(extract, in_axes=(None, 0))( + joint_factor_t, jnp.arange(len(factorial_inds[t])) + ) + + for j, ind in enumerate(factorial_inds[t]): + if ind == factorial_index: + factorial_tree = tree.map( + lambda x, y: jnp.concatenate([x, y[j][None]]), + factorial_tree, + local_factors_t, + ) + + return factorial_tree diff --git a/tests/cuthbert/factorial/__init__.py b/tests/cuthbert/factorial/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cuthbert/factorial/gaussian_utils.py b/tests/cuthbert/factorial/gaussian_utils.py new file mode 100644 index 0000000..e357911 --- /dev/null +++ b/tests/cuthbert/factorial/gaussian_utils.py @@ -0,0 +1,44 @@ +import jax.numpy as jnp +from jax import random, vmap + +from cuthbertlib.kalman import generate + + +def generate_factorial_kalman_model( + seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps +): + # T = num_time_steps, F = num_factors + + key = random.key(seed) + init_key, factorial_indices_key = random.split(key, 2) 
+ + # m0 with shape (F, x_dim) + # chol_P0 with shape (F, x_dim, x_dim) + init_keys_factorial = random.split(init_key, num_factors) + m0s, chol_P0s = vmap(generate.generate_init_model, in_axes=(0, None))( + init_keys_factorial, x_dim + ) + + # Fs with shape (T, num_factors_local * x_dim, num_factors_local * x_dim) + # cs with shape (T, num_factors_local * x_dim) + # chol_Qs with shape (T, num_factors_local * x_dim, num_factors_local * x_dim) + # Hs with shape (T, d_y, num_factors_local * x_dim) + # ds with shape (T, y_dim) + # chol_Rs with shape (T, num_factors_local * y_dim, num_factors_local * y_dim) + # ys with shape (T, d_y) + _, _, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys = generate.generate_lgssm( + seed + 1, num_factors_local * x_dim, y_dim, num_time_steps + ) + + # factorial_indices with shape (T, num_factors_local) + # Each entry is a random integer in {0, ..., num_factors - 1} + # But each row must have unique entries + def rand_unique_indices(key): + indices = random.choice( + key, jnp.arange(num_factors), (num_factors_local,), replace=False + ) + return indices + + factorial_indices_keys = random.split(factorial_indices_key, num_time_steps) + factorial_indices = vmap(rand_unique_indices)(factorial_indices_keys) + return m0s, chol_P0s, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys, factorial_indices diff --git a/tests/cuthbert/factorial/test_kalman.py b/tests/cuthbert/factorial/test_kalman.py new file mode 100644 index 0000000..1f9414f --- /dev/null +++ b/tests/cuthbert/factorial/test_kalman.py @@ -0,0 +1,279 @@ +import itertools +from typing import cast + +import chex +import jax +import jax.numpy as jnp +import pytest +from jax import Array, tree +from jax.scipy.linalg import block_diag + +import cuthbert +from cuthbert import factorial +from cuthbert.factorial.utils import serial_to_single_factor +from cuthbert.gaussian import kalman +from cuthbert.inference import Filter, Smoother +from cuthbertlib.linalg import block_marginal_sqrt_cov +from cuthbertlib.types 
def load_kalman_pairwise_factorial_inference(
    m0: Array,  # (F, d)
    chol_P0: Array,  # (F, d, d)
    Fs: Array,  # (T, 2 * d, 2 * d)
    cs: Array,  # (T, 2 * d)
    chol_Qs: Array,  # (T, 2 * d, 2 * d)
    Hs: Array,  # (T, d_y, 2 * d)
    ds: Array,  # (T, d_y)
    chol_Rs: Array,  # (T, d_y, d_y)
    ys: Array,  # (T, d_y)
    factorial_indices: Array,  # (T, 2)
    smoother_factorial_index: int,
) -> tuple[Filter, factorial.Factorializer, Array, Smoother, Array]:
    """Builds factorial Kalman filter and smoother objects and model_inputs for a linear-Gaussian SSM.

    "Pairwise" means exactly two factors interact at each step: every row of
    `factorial_indices` has length 2, and the joint local state has dimension
    2 * d. `model_inputs` is a 1-based time index throughout — index 0 is
    reserved for initialization, hence the `- 1` offsets below.

    Returns:
        (filter, factorializer, filter_model_inputs, smoother,
        smoother_model_inputs) where the smoother covers only the factor
        `smoother_factorial_index`.
    """

    def get_init_params(model_inputs: int) -> tuple[Array, Array]:
        # Factorial initial state: one (mean, chol_cov) pair per factor.
        return m0, chol_P0

    def get_dynamics_params(model_inputs: int) -> tuple[Array, Array, Array]:
        return Fs[model_inputs - 1], cs[model_inputs - 1], chol_Qs[model_inputs - 1]

    def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Array]:
        return (
            Hs[model_inputs - 1],
            ds[model_inputs - 1],
            chol_Rs[model_inputs - 1],
            ys[model_inputs - 1],
        )

    # NOTE(review): `filter` shadows the builtin; consider renaming.
    filter = kalman.build_filter(
        get_init_params, get_dynamics_params, get_observation_params
    )

    factorializer = factorial.gaussian.build_factorializer(
        get_factorial_indices=lambda model_inputs: factorial_indices[model_inputs - 1]
    )
    filter_model_inputs = jnp.arange(len(ys) + 1)

    # Some processing to get smoothing for a single factor:
    # split the joint 2d-dimensional dynamics into per-factor d-dimensional
    # pieces, bucketed by which factor they act on.
    num_factors = len(m0)
    d_x = m0.shape[1]
    Fs_per_factor = [[] for _ in range(num_factors)]
    cs_per_factor = [[] for _ in range(num_factors)]
    chol_Qs_per_factor = [[] for _ in range(num_factors)]

    for i in range(1, len(ys) + 1):
        # h = factor occupying the head (first d_x) slots, a = the tail slots.
        h, a = factorial_indices[i - 1]

        # Diagonal blocks only; the off-diagonal coupling blocks of F are
        # discarded here. NOTE(review): this assumes the single-factor
        # smoother should see block-diagonal dynamics — confirm against the
        # factorial approximation being tested.
        F_h = Fs[i - 1][:d_x, :d_x]
        F_a = Fs[i - 1][-d_x:, -d_x:]
        c_h = cs[i - 1][:d_x]
        c_a = cs[i - 1][-d_x:]
        chol_Q_h, chol_Q_a = block_marginal_sqrt_cov(chol_Qs[i - 1], d_x)
        Fs_per_factor[h].append(F_h)
        cs_per_factor[h].append(c_h)
        chol_Qs_per_factor[h].append(chol_Q_h)

        Fs_per_factor[a].append(F_a)
        cs_per_factor[a].append(c_a)
        chol_Qs_per_factor[a].append(chol_Q_a)

    def get_dynamics_params_single_factor(
        model_inputs: int,
    ) -> tuple[Array, Array, Array]:
        # NOTE(review): these are Python lists, so `model_inputs` must be a
        # concrete int (not a traced array) when the smoother calls this —
        # verify cuthbert.smoother passes concrete indices.
        return (
            Fs_per_factor[smoother_factorial_index][model_inputs - 1],
            cs_per_factor[smoother_factorial_index][model_inputs - 1],
            chol_Qs_per_factor[smoother_factorial_index][model_inputs - 1],
        )

    smoother = kalman.build_smoother(
        get_dynamics_params_single_factor,
        store_gain=True,
        store_chol_cov_given_next=True,
    )
    # One input per occurrence of the chosen factor, plus the init slot.
    smoother_model_inputs = jnp.arange(len(Fs_per_factor[smoother_factorial_index]) + 1)

    return filter, factorializer, filter_model_inputs, smoother, smoother_model_inputs
fac_means_t_all = [fac_means] + fac_covs_t_all = [fac_covs] + for i in model_inputs[1:]: + F, c, chol_Q = ( + model_params[2][i - 1], + model_params[3][i - 1], + model_params[4][i - 1], + ) + H, d, chol_R, y = ( + model_params[5][i - 1], + model_params[6][i - 1], + model_params[7][i - 1], + model_params[8][i - 1], + ) + fac_inds = model_params[9][i - 1] + + joint_mean = fac_means[fac_inds].reshape(-1) + joint_cov = block_diag(*fac_covs[fac_inds]) + Q = chol_Q @ chol_Q.T + R = chol_R @ chol_R.T + pred_mean, pred_cov = std_predict(joint_mean, joint_cov, F, c, Q) + upd_mean, upd_cov, upd_ell = std_update(pred_mean, pred_cov, H, d, R, y) + marginal_means = upd_mean.reshape(len(fac_inds), -1) + marginal_covs = jnp.array( + [ + upd_cov[i * x_dim : (i + 1) * x_dim, i * x_dim : (i + 1) * x_dim] + for i in range(len(fac_inds)) + ] + ) + ell += upd_ell + local_means.append(marginal_means) + local_covs.append(marginal_covs) + ells.append(ell) + fac_means = fac_means.at[fac_inds].set(marginal_means) + fac_covs = fac_covs.at[fac_inds].set(marginal_covs) + fac_means_t_all.append(fac_means) + fac_covs_t_all.append(fac_covs) + + local_means = jnp.stack(local_means) + local_covs = jnp.stack(local_covs) + ells = jnp.stack(ells) + fac_means_t_all = jnp.stack(fac_means_t_all) + fac_covs_t_all = jnp.stack(fac_covs_t_all) + + # Check output_factorial = False + init_state, local_filter_states = factorial.filter( + filter_obj, factorializer, model_inputs, output_factorial=False + ) + local_filter_covs = ( + local_filter_states.chol_cov + @ local_filter_states.chol_cov.transpose(0, 1, 3, 2) + ) + chex.assert_trees_all_close( + (init_state.mean, init_state.chol_cov), (model_params[0], model_params[1]) + ) + chex.assert_trees_all_close( + (local_means, local_covs, ells), + ( + local_filter_states.mean, + local_filter_covs, + local_filter_states.log_normalizing_constant, + ), + ) + + # Check output_factorial = True + factorial_filtering_states = factorial.filter( + filter_obj, factorializer, 
# Factor indices to smooth over, crossed with every filter configuration.
smoother_indices = [0, 1, 5]

common_smoother_params = [
    (*params, smoother_idx)
    for params, smoother_idx in itertools.product(common_params, smoother_indices)
]


@pytest.mark.parametrize(
    "seed,x_dim,y_dim,num_factors,num_factors_local,num_time_steps,smoother_factorial_index",
    common_smoother_params,
)
def test_smoother(
    seed,
    x_dim,
    y_dim,
    num_factors,
    num_factors_local,
    num_time_steps,
    smoother_factorial_index,
):
    """Run the single-factor smoothing pipeline end to end (no value checks yet)."""
    params = generate_factorial_kalman_model(
        seed, x_dim, y_dim, num_factors, num_factors_local, num_time_steps
    )
    (
        filt,
        factorializer,
        filter_inputs,
        smoother,
        smoother_inputs,
    ) = load_kalman_pairwise_factorial_inference(
        *params, smoother_factorial_index=smoother_factorial_index
    )

    # Serial (non-factorial) filtering pass.
    init_state, serial_filter_states = factorial.filter(
        filt, factorializer, filter_inputs, output_factorial=False
    )

    # Restrict the serial filter output to the factor being smoothed.
    factor_inds = params[-1]
    single_factor_states = serial_to_single_factor(
        factorializer.extract,
        serial_filter_states,
        factor_inds,
        smoother_factorial_index,
        init_factorial_tree=init_state,
    )

    # Smoothing pass for that factor alone.
    smoother_states = cuthbert.smoother(
        smoother, single_factor_states, smoother_inputs
    )

    ### TODO: Finish test — compare `smoother_states` against a reference.
import chex
import jax.numpy as jnp
from jax import tree

from cuthbert.factorial.utils import serial_to_factorial, serial_to_single_factor

### TODO: test init_state


def extract(factorial_state, factorial_inds):
    """Gather the given factor indices from every leaf of the state tree."""
    return tree.map(lambda leaf: leaf[factorial_inds], factorial_state)


def test_serial_to_factorial_groups_values_by_index_in_order():
    """Each factor's tree collects its (t, j) slices in serial traversal order."""
    serial_tree = {
        "x": jnp.array(
            [
                [[10.0], [11.0]],
                [[20.0], [21.0]],
                [[30.0], [31.0]],
            ]
        ),
        "y": jnp.array(
            [
                [[1.0, 2.0], [3.0, 4.0]],
                [[5.0, 6.0], [7.0, 8.0]],
                [[9.0, 10.0], [11.0, 12.0]],
            ]
        ),
    }
    factorial_inds = jnp.array([[0, 2], [1, 0], [2, 1]])

    # For every factor, list the (t, j) positions assigned to it by walking
    # factorial_inds in time-major order, then gather those slices per leaf.
    expected_trees = []
    for factor in range(3):
        positions = [
            (t, j)
            for t, row in enumerate(factorial_inds)
            for j, ind in enumerate(row)
            if ind == factor
        ]
        expected_trees.append(
            {
                key: (
                    jnp.stack([leaf[t, j] for t, j in positions])
                    if positions
                    else jnp.zeros((0,) + leaf.shape[2:])
                )
                for key, leaf in serial_tree.items()
            }
        )

    factorial_trees = serial_to_factorial(extract, serial_tree, factorial_inds)

    assert len(factorial_trees) == 3
    for factor, actual_tree in enumerate(factorial_trees):
        chex.assert_trees_all_close(actual_tree, expected_trees[factor])


def test_serial_to_single_factor_matches_corresponding_factorial_tree():
    """serial_to_single_factor(f) agrees with serial_to_factorial(...)[f]."""
    serial_tree = jnp.array(
        [
            [[1.0], [2.0]],
            [[3.0], [4.0]],
        ]
    )
    factorial_inds = jnp.array([[1, 0], [0, 1]])

    every_factor = serial_to_factorial(extract, serial_tree, factorial_inds)
    single_factor = serial_to_single_factor(
        extract, serial_tree, factorial_inds, factorial_index=1
    )

    chex.assert_trees_all_close(single_factor, every_factor[1])