Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions cuthbert/discrete/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def init_prepare(
model_inputs: ArrayTreeLike,
get_init_dist: GetInitDist,
get_obs_lls: GetObsLogLikelihoods,
init_likelihood: bool = True,
key: KeyArray | None = None,
) -> DiscreteFilterState:
"""Prepare the initial state for the filter.
Expand All @@ -81,18 +82,25 @@ def init_prepare(
model_inputs: Model inputs.
get_init_dist: Function to get initial state probabilities m_i = p(x_0 = i).
get_obs_lls: Function to get observation log likelihoods b_i = log p(y_t | x_t = i).
init_likelihood: Whether to do a Bayesian update on the initial state.
I.e. whether an observation is included at the first time point.
key: JAX random key - not used.

Returns:
Prepared state for the filter.
"""
model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
init_dist = get_init_dist(model_inputs)
obs_lls = get_obs_lls(model_inputs)
f, log_g = filtering.condition_on_obs(init_dist, obs_lls)
N = init_dist.shape[0]
f *= jnp.ones((N, N))
log_g *= jnp.ones(N)
if init_likelihood:
obs_lls = get_obs_lls(model_inputs)
f, log_g = filtering.condition_on_obs(init_dist, obs_lls)
N = init_dist.shape[0]
f *= jnp.ones((N, N))
log_g *= jnp.ones(N)
else:
f = init_dist
log_g = jnp.zeros(init_dist.shape[0])

return DiscreteFilterState(
elem=filtering.FilterScanElement(f, log_g), model_inputs=model_inputs
)
Expand Down
7 changes: 6 additions & 1 deletion cuthbert/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def filter(
filter_obj: Filter,
model_inputs: ArrayTreeLike,
parallel: bool = False,
init_likelihood: bool = True,
key: KeyArray | None = None,
) -> ArrayTree:
"""Applies offline filtering given a filter object and model inputs.
Expand All @@ -26,6 +27,8 @@ def filter(
model_inputs: The model inputs (with leading temporal dimension of length T + 1).
parallel: Whether to run the filter in parallel.
Requires `filter.associative_filter` to be `True`.
init_likelihood: Whether to do a Bayesian update on the initial state.
I.e. whether an observation is included at the first time point.
key: The key for the random number generator.

Returns:
Expand All @@ -46,7 +49,9 @@ def filter(
prepare_keys = random.split(key, T + 1)

init_model_input = tree.map(lambda x: x[0], model_inputs)
init_state = filter_obj.init_prepare(init_model_input, key=prepare_keys[0])
init_state = filter_obj.init_prepare(
init_model_input, init_likelihood=init_likelihood, key=prepare_keys[0]
)

prep_model_inputs = tree.map(lambda x: x[1:], model_inputs)

Expand Down
13 changes: 10 additions & 3 deletions cuthbert/gaussian/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def init_prepare(
model_inputs: ArrayTreeLike,
get_init_params: GetInitParams,
get_observation_params: GetObservationParams,
init_likelihood: bool = True,
key: KeyArray | None = None,
) -> KalmanFilterState:
"""Prepare the initial state for the Kalman filter.
Expand All @@ -144,17 +145,23 @@ def init_prepare(
model_inputs: Model inputs.
get_init_params: Function to get m0, chol_P0 from model inputs.
get_observation_params: Function to get observation parameters, H, d, chol_R, y.
init_likelihood: Whether to do a Bayesian update on the initial state.
I.e. whether an observation is included at the first time point.
key: JAX random key - not used.

Returns:
State for the Kalman filter.
Contains mean and chol_cov (generalised Cholesky factor of covariance).
"""
model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
m0, chol_P0 = get_init_params(model_inputs)
H, d, chol_R, y = get_observation_params(model_inputs)
m, chol_P = get_init_params(model_inputs)
ell = jnp.array(0.0)

if init_likelihood:
H, d, chol_R, y = get_observation_params(model_inputs)

(m, chol_P), ell = filtering.update(m, chol_P, H, d, chol_R, y)

(m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y)
elem = filtering.FilterScanElement(
A=jnp.zeros_like(chol_P),
b=m,
Expand Down
61 changes: 2 additions & 59 deletions cuthbert/gaussian/moments/associative_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from jax import numpy as jnp

from cuthbert.gaussian.kalman import GetInitParams
from cuthbert.gaussian.moments import non_associative_filter
from cuthbert.gaussian.moments.types import GetDynamicsMoments, GetObservationMoments
from cuthbert.gaussian.types import (
LinearizedKalmanFilterState,
Expand All @@ -16,65 +17,7 @@
KeyArray,
)


def init_prepare(
model_inputs: ArrayTreeLike,
get_init_params: GetInitParams,
get_observation_params: GetObservationMoments,
key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
"""Prepare the initial state for the linearized moments Kalman filter.

Args:
model_inputs: Model inputs.
get_init_params: Function to get m0, chol_P0 from model inputs.
get_observation_params: Function to get observation conditional mean,
(generalised) Cholesky covariance function, linearization point and
observation.
key: JAX random key - not used.

Returns:
State for the linearized moments Kalman filter.
Contains mean, chol_cov (generalised Cholesky factor of covariance)
and log_normalizing_constant.
"""
model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
m0, chol_P0 = get_init_params(model_inputs)

prior_state = LinearizedKalmanFilterState(
elem=filtering.FilterScanElement(
A=jnp.zeros_like(chol_P0),
b=m0,
U=chol_P0,
eta=jnp.zeros_like(m0),
Z=jnp.zeros_like(chol_P0),
ell=jnp.array(0.0),
),
model_inputs=model_inputs,
mean_prev=dummy_tree_like(m0),
)

mean_and_chol_cov_func, linearization_point, y = get_observation_params(
prior_state, model_inputs
)

H, d, chol_R = linearize_moments(mean_and_chol_cov_func, linearization_point)
(m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y)

elem = filtering.FilterScanElement(
A=jnp.zeros_like(chol_P),
b=m,
U=chol_P,
eta=jnp.zeros_like(m),
Z=jnp.zeros_like(chol_P),
ell=ell,
)

return LinearizedKalmanFilterState(
elem=elem,
model_inputs=model_inputs,
mean_prev=dummy_tree_like(m),
)
init_prepare = non_associative_filter.init_prepare


def filter_prepare(
Expand Down
45 changes: 27 additions & 18 deletions cuthbert/gaussian/moments/non_associative_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def init_prepare(
model_inputs: ArrayTreeLike,
get_init_params: GetInitParams,
get_observation_params: GetObservationMoments,
init_likelihood: bool = True,
key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
"""Prepare the initial state for the linearized moments Kalman filter.
Expand All @@ -26,6 +27,8 @@ def init_prepare(
get_observation_params: Function to get observation conditional mean,
(generalised) Cholesky covariance function, linearization point and
observation.
init_likelihood: Whether to do a Bayesian update on the initial state.
I.e. whether an observation is included at the first time point.
key: JAX random key - not used.

Returns:
Expand All @@ -36,25 +39,31 @@ def init_prepare(
model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
m0, chol_P0 = get_init_params(model_inputs)

prior_state = LinearizedKalmanFilterState(
elem=filtering.FilterScanElement(
A=jnp.zeros_like(chol_P0),
b=m0,
U=chol_P0,
eta=jnp.zeros_like(m0),
Z=jnp.zeros_like(chol_P0),
ell=jnp.array(0.0),
),
model_inputs=model_inputs,
mean_prev=dummy_tree_like(m0),
)

mean_and_chol_cov_func, linearization_point, y = get_observation_params(
prior_state, model_inputs
)
if init_likelihood:
prior_state = LinearizedKalmanFilterState(
elem=filtering.FilterScanElement(
A=jnp.zeros_like(chol_P0),
b=m0,
U=chol_P0,
eta=jnp.zeros_like(m0),
Z=jnp.zeros_like(chol_P0),
ell=jnp.array(0.0),
),
model_inputs=model_inputs,
mean_prev=dummy_tree_like(m0),
)

mean_and_chol_cov_func, linearization_point, y = get_observation_params(
prior_state, model_inputs
)

H, d, chol_R = linearize_moments(mean_and_chol_cov_func, linearization_point)
(m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y)
else:
m = m0
chol_P = chol_P0
ell = jnp.array(0.0)

H, d, chol_R = linearize_moments(mean_and_chol_cov_func, linearization_point)
(m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y)
return linearized_kalman_filter_state_dummy_elem(
mean=m,
chol_cov=chol_P,
Expand Down
85 changes: 2 additions & 83 deletions cuthbert/gaussian/taylor/associative_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from jax import eval_shape, tree
from jax import numpy as jnp

from cuthbert.gaussian.taylor import non_associative_filter
from cuthbert.gaussian.taylor.non_associative_filter import process_observation
from cuthbert.gaussian.taylor.types import (
GetDynamicsLogDensity,
Expand All @@ -20,89 +21,7 @@
KeyArray,
)


def init_prepare(
model_inputs: ArrayTreeLike,
get_init_log_density: GetInitLogDensity,
get_observation_func: GetObservationFunc,
rtol: float | None = None,
ignore_nan_dims: bool = False,
key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
"""Prepare the initial state for the linearized Taylor Kalman filter.

Args:
model_inputs: Model inputs.
get_init_log_density: Function that returns log density log p(x_0)
and linearization point.
get_observation_func: Function that returns either
- An observation log density
function log p(y_0 | x_0) as well as points x_0 and y_0
to linearize around.
- A log potential function log G(x_0) and a linearization point x_0.
rtol: The relative tolerance for the singular values of precision matrices
when passed to `symmetric_inv_sqrt` during linearization.
Cutoff for small singular values; singular values smaller than
`rtol * largest_singular_value` are treated as zero.
The default is determined based on the floating point precision of the dtype.
See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
precision matrices (found via linearization) as missing and ignore all rows
and columns associated with them.
key: JAX random key - not used.

Returns:
State for the linearized Taylor Kalman filter.
Contains mean, chol_cov (generalised Cholesky factor of covariance)
and log_normalizing_constant.
"""
model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
init_log_density, linearization_point = get_init_log_density(model_inputs)

_, m0, chol_P0 = linearize_log_density(
lambda _, x: init_log_density(x),
linearization_point,
linearization_point,
rtol=rtol,
ignore_nan_dims=ignore_nan_dims,
)

prior_state = LinearizedKalmanFilterState(
elem=filtering.FilterScanElement(
A=jnp.zeros_like(chol_P0),
b=m0,
U=chol_P0,
eta=jnp.zeros_like(m0),
Z=jnp.zeros_like(chol_P0),
ell=jnp.array(0.0),
),
model_inputs=model_inputs,
mean_prev=dummy_tree_like(m0),
)

observation_output = get_observation_func(prior_state, model_inputs)
H, d, chol_R, observation = process_observation(
observation_output,
rtol=rtol,
ignore_nan_dims=ignore_nan_dims,
)

(m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, observation)

elem = filtering.FilterScanElement(
A=jnp.zeros_like(chol_P),
b=m,
U=chol_P,
eta=jnp.zeros_like(m),
Z=jnp.zeros_like(chol_P),
ell=ell,
)

return LinearizedKalmanFilterState(
elem=elem,
model_inputs=model_inputs,
mean_prev=dummy_tree_like(m),
)
init_prepare = non_associative_filter.init_prepare


def filter_prepare(
Expand Down
Loading
Loading