Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
584 changes: 584 additions & 0 deletions docs/deep_dives/l63_forecast_examples.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dynestyx/discretizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def _sample_ds(
observation_model=dynamics.observation_model,
control_model=dynamics.control_model,
control_dim=dynamics.control_dim,
t0=dynamics.t0,
)
return fwd(
name,
Expand Down
2 changes: 2 additions & 0 deletions dynestyx/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def sample(
*,
obs_times: jax.Array,
obs_values: jax.Array | None = None,
predict_times: jax.Array | None = None,
ctrl_times: jax.Array | None = None,
ctrl_values: jax.Array | None = None,
**kwargs,
Expand All @@ -42,6 +43,7 @@ def sample(
dynamics: Dynamical model to sample from.
obs_times: Times at which to sample the observations.
obs_values: Values of the observations at the given times.
predict_times: Optional forecasting times, strictly after all observation times.
ctrl_times: Times at which to sample the controls.
ctrl_values: Values of the controls at the given times.
**kwargs: Additional keyword arguments.
Expand Down
19 changes: 19 additions & 0 deletions dynestyx/inference/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from dynestyx.models import DynamicalModel
from dynestyx.types import FunctionOfTime
from dynestyx.utils import _validate_predict_times, _validate_t0_alignment

type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM

Expand All @@ -50,6 +51,7 @@ def _sample_ds(
*,
obs_times=None,
obs_values=None,
predict_times=None,
ctrl_times=None,
ctrl_values=None,
**kwargs,
Expand All @@ -59,6 +61,7 @@ def _sample_ds(
dynamics,
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
**kwargs,
Expand All @@ -70,6 +73,7 @@ def _sample_ds(
dynamics,
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
**kwargs,
Expand All @@ -82,6 +86,7 @@ def _add_log_factors(
*,
obs_times=None,
obs_values=None,
predict_times=None,
ctrl_times=None,
ctrl_values=None,
**kwargs,
Expand Down Expand Up @@ -162,6 +167,7 @@ def _add_log_factors(
*,
obs_times: jax.Array | None = None,
obs_values: jax.Array | None = None,
predict_times: jax.Array | None = None,
ctrl_times=None,
ctrl_values=None,
**kwargs,
Expand All @@ -180,6 +186,11 @@ def _add_log_factors(
if obs_times is None or obs_values is None:
raise ValueError("obs_times and obs_values are required for filtering.")

_validate_t0_alignment(
dynamics, obs_times, predict_times, require_obs_t0_match=True
)
_validate_predict_times(obs_times, predict_times)

config = (
self.filter_config
if self.filter_config is not None
Expand All @@ -202,6 +213,7 @@ def _add_log_factors(
key=key,
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
**kwargs,
Expand All @@ -214,6 +226,7 @@ def _add_log_factors(
config, # type: ignore[arg-type]
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
**kwargs,
Expand All @@ -226,6 +239,7 @@ def _add_log_factors(
key=key,
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
**kwargs,
Expand All @@ -246,6 +260,7 @@ def _filter_discrete_time(
*,
obs_times: jax.Array,
obs_values: jax.Array,
predict_times: jax.Array | None = None,
ctrl_times=None,
ctrl_values=None,
**kwargs,
Expand All @@ -272,6 +287,7 @@ def _filter_discrete_time(
filter_config,
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
**kwargs,
Expand All @@ -284,6 +300,7 @@ def _filter_discrete_time(
key=key,
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
**kwargs,
Expand All @@ -300,6 +317,7 @@ def _filter_continuous_time(
*,
obs_times: jax.Array,
obs_values: jax.Array,
predict_times: jax.Array | None = None,
ctrl_times=None,
ctrl_values=None,
**kwargs,
Expand All @@ -324,6 +342,7 @@ def _filter_continuous_time(
key=key,
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
**kwargs,
Expand Down
97 changes: 88 additions & 9 deletions dynestyx/inference/integrations/cd_dynamax/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
ContDiscreteNonlinearGaussianSSM,
ContDiscreteNonlinearSSM,
)
from cd_dynamax.src.continuous_discrete_linear_gaussian_ssm.cdlgssm_utils import (
GSSMForecast,
)
from cd_dynamax.src.continuous_discrete_linear_gaussian_ssm.models import (
PosteriorGSSMFiltered,
)
Expand All @@ -29,6 +32,7 @@
_should_record_field,
_validate_control_dim,
_validate_controls,
_validate_predict_times,
)

type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM
Expand Down Expand Up @@ -142,29 +146,60 @@ def _add_filter_sites(
numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov)


def _to_filter_and_forecast_kwargs(
filter_kwargs: dict, predict_times_arr: jax.Array
) -> dict:
"""Translate cd_dynamax `filter(...)` kwargs into `filter_and_forecast(...)` kwargs."""
shared = dict(filter_kwargs)
emissions_filter = shared.pop("emissions")
t_emissions_filter = shared.pop("t_emissions")
inputs_filter = shared.pop("inputs")
return {
"emissions_filter": emissions_filter,
"t_emissions_filter": t_emissions_filter,
"t_emissions_forecast": predict_times_arr,
"inputs_filter": inputs_filter,
"inputs_forecast": None,
**shared,
}


def _run_linear_kf(
name: str,
dynamics: DynamicalModel,
obs_times,
obs_values,
ctrl_values,
predict_times,
filter_config: ContinuousTimeKFConfig,
) -> PosteriorGSSMFiltered:
) -> tuple[PosteriorGSSMFiltered, GSSMForecast | None]:
"""Run exact continuous-discrete KF (AffineLinearDrift + constant diffusion + LinearGaussianObservation)."""
params = dsx_to_cdlgssm_params(dynamics)
cd_model = ContDiscreteLinearGaussianSSM(
state_dim=dynamics.state_dim,
emission_dim=dynamics.observation_dim,
input_dim=dynamics.control_dim,
)
if predict_times is not None and len(predict_times) > 0:
filtered, forecasted = cd_model.filter_and_forecast(
params=params,
emissions_filter=obs_values,
t_emissions_filter=obs_times,
t_emissions_forecast=predict_times,
inputs_filter=ctrl_values,
inputs_forecast=None,
warn=filter_config.warn,
)
return filtered, forecasted

filtered = cd_model.filter(
params=params,
emissions=obs_values,
t_emissions=obs_times,
inputs=ctrl_values,
warn=filter_config.warn,
)
return filtered
return filtered, None


def run_continuous_filter(
Expand All @@ -175,6 +210,7 @@ def run_continuous_filter(
*,
obs_times: jax.Array,
obs_values: jax.Array,
predict_times: jax.Array | None = None,
ctrl_times=None,
ctrl_values=None,
**kwargs,
Expand All @@ -195,6 +231,16 @@ def run_continuous_filter(
obs_times_arr = jnp.asarray(obs_times)
if obs_times_arr.ndim == 1:
obs_times_arr = obs_times_arr[:, None]
predict_times_arr = None
if predict_times is not None:
predict_times_arr = jnp.asarray(predict_times)
if predict_times_arr.ndim == 1:
predict_times_arr = predict_times_arr[:, None]

_validate_predict_times(
jnp.ravel(obs_times_arr),
None if predict_times_arr is None else jnp.ravel(predict_times_arr),
)
_validate_controls(jnp.ravel(obs_times_arr), ctrl_times, ctrl_values)
_validate_control_dim(dynamics, ctrl_values)

Expand All @@ -206,8 +252,14 @@ def run_continuous_filter(
)

if isinstance(filter_config, ContinuousTimeKFConfig):
filtered = _run_linear_kf(
name, dynamics, obs_times_arr, obs_values, ctrl_vals, filter_config
filtered, forecasted = _run_linear_kf(
name,
dynamics,
obs_times_arr,
obs_values,
ctrl_vals,
predict_times_arr,
filter_config,
)
else:
if isinstance(
Expand All @@ -233,10 +285,37 @@ def run_continuous_filter(
)

params, _ = dsx_to_cd_dynamax(dynamics, cd_model=cd_dynamax_model)
filter_kwargs = _config_to_cd_dynamax_filter_kwargs(
filter_config, params, obs_values, obs_times_arr, ctrl_vals, key
)

filtered = cd_dynamax_model.filter(**filter_kwargs) # type: ignore
if predict_times_arr is not None and len(predict_times_arr) > 0:
if not hasattr(cd_dynamax_model, "filter_and_forecast"):
raise ValueError(
"predict_times is not supported for this CD-Dynamax model "
f"({type(cd_dynamax_model).__name__}). "
"Only CDNLGSSM/CDLGSSM backends currently support forecasting."
)
filter_kwargs = _config_to_cd_dynamax_filter_kwargs(
filter_config, params, obs_values, obs_times_arr, ctrl_vals, key
)
forecast_kwargs = _to_filter_and_forecast_kwargs(
filter_kwargs, predict_times_arr
)
filtered, forecasted = cd_dynamax_model.filter_and_forecast( # type: ignore[attr-defined]
**forecast_kwargs
)
else:
filter_kwargs = _config_to_cd_dynamax_filter_kwargs(
filter_config, params, obs_values, obs_times_arr, ctrl_vals, key
)
filtered = cd_dynamax_model.filter(**filter_kwargs) # type: ignore
forecasted = None

_add_filter_sites(name, filter_config, filtered)
if forecasted is not None:
if forecasted.forecasted_state_means is not None:
numpyro.deterministic(
f"{name}_forecasted_state_means", forecasted.forecasted_state_means
)
if forecasted.forecasted_state_covariances is not None:
numpyro.deterministic(
f"{name}_forecasted_state_covs",
forecasted.forecasted_state_covariances,
)
Loading