Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
0369682
wip
DanWaxman Mar 5, 2026
2bb8f63
wip formatted
DanWaxman Mar 5, 2026
205665f
removing unneeded validations since they are done at handler level. s…
mattlevine22 Mar 10, 2026
e606fd1
add wip filtered predictions for sde
DanWaxman Mar 10, 2026
2570629
fmt
DanWaxman Mar 10, 2026
b01ff1f
Merge main
DanWaxman Mar 10, 2026
8aca61e
wip (use dynamicalmodel.t0 to start simulations)
mattlevine22 Mar 10, 2026
ca22296
dealing with length-1 simulations
mattlevine22 Mar 10, 2026
710f2a1
updating tests to use predict times and map onto new sites f_states, …
mattlevine22 Mar 10, 2026
2eaa3aa
more test restrictions and fix the shape of deltas
mattlevine22 Mar 10, 2026
61d2507
small bugfix/simplification
mattlevine22 Mar 12, 2026
47d78c6
giant/bold cursor-based commit. a lot of new predict_times functionality
mattlevine22 Mar 13, 2026
559e930
update simulator predict-mode to get conntrol segments (needed in dis…
mattlevine22 Mar 15, 2026
b46a1d6
final linter fixes and clean up notebooks more
mattlevine22 Mar 15, 2026
a640eed
fixing tests with missing predict_times
mattlevine22 Mar 16, 2026
3874cb6
remove unneeded script
mattlevine22 Mar 16, 2026
7c59c72
add interpolation between obs_times via predict_times
mattlevine22 Mar 16, 2026
b8a77ad
make sure all table cases work
mattlevine22 Mar 16, 2026
c89a033
support num_samples>1 and n_simulations>1 and update notebooks
mattlevine22 Mar 16, 2026
971c80d
re-run MLL nb
mattlevine22 Mar 16, 2026
45ee1bd
re-ran notebook
mattlevine22 Mar 16, 2026
50940db
Merge branch 'main' into dw-ml-predict-times
mattlevine22 Mar 16, 2026
93463cd
quickstart uses more data for better MCMC (slower now); also added qu…
mattlevine22 Mar 16, 2026
41d506a
simplify code and add trailing dim for HMMs
mattlevine22 Mar 16, 2026
d77331d
remove unneeded fallback
mattlevine22 Mar 16, 2026
6563842
please the lint
mattlevine22 Mar 16, 2026
62aab00
simplifying simulators.py
mattlevine22 Mar 16, 2026
b1d4821
modularizing _step in DiscreteTimeSimulator
mattlevine22 Mar 16, 2026
bb9def6
remove unused function
mattlevine22 Mar 16, 2026
99f0577
only check n_sim once;
mattlevine22 Mar 16, 2026
d164210
simplify BaseSimulator _sample_ds
mattlevine22 Mar 16, 2026
0ebccc1
update api reference documentation
mattlevine22 Mar 16, 2026
837abcc
Merge branch 'main' into dw-ml-predict-times
DanWaxman Mar 17, 2026
31cda7b
Update dynestyx/inference/integrations/cd_dynamax/discrete.py
mattlevine22 Mar 18, 2026
2cb67e0
Update dynestyx/inference/integrations/cd_dynamax/discrete.py
mattlevine22 Mar 18, 2026
2faa0e9
add back args comment and fix indent bug in cd-dynamax/discrete.py
mattlevine22 Mar 18, 2026
3ec4f14
t1->T1 and particle_to_deltas utility
mattlevine22 Mar 18, 2026
3666c41
improving comments in simulators.py/utils.py and updating api reference
mattlevine22 Mar 18, 2026
e3d1fea
Merge branch 'main' into dw-ml-predict-times
mattlevine22 Mar 18, 2026
1e2be25
Merge branch 'main' into dw-ml-predict-times
mattlevine22 Mar 18, 2026
1473c48
merge with main and update blackjax tests for predict_times compatibi…
mattlevine22 Mar 18, 2026
a4f2e88
improve readability of simulators.py's predict_times segmentation str…
mattlevine22 Mar 18, 2026
6b83ee4
Merge branch 'main' into dw-ml-predict-times
DanWaxman Mar 18, 2026
628292f
Merge branch 'main' into dw-ml-predict-times
mattlevine22 Mar 20, 2026
9f54fb8
t1 -> T1
mattlevine22 Mar 20, 2026
057409b
Faster tests (#162)
mattlevine22 Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions docs/api_reference/developer/simulators.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
# Overview

Simulators (also called *unrollers*) turn a `DynamicalModel` into explicit NumPyro
`sample` sites for latent states and observations on a provided time grid.
`sample` sites for latent states and observations on a chosen time grid.

!!! note "Context"
!!! note "When to use each time argument"
- **`obs_times` and `obs_values` must be provided together**:
- `obs_times` defines where observation sample sites (`y_t`) live.
- `obs_values` provides conditioning values for those sites via `obs=...`.
- Typical use: observed-data simulation/inference on a known observation grid.
- **`predict_times`**: use this when you want rollout trajectories at specific
times for simulation and/or post-filter rollout.
- In filter-rollout mode, predictions are generated at `predict_times` from
filtered posteriors.
- Typical use: forward simulation, forecasting, or dense trajectories for
visualization.
- **If both are provided**:
- `obs_times` controls filtering/conditioning points.
- `predict_times` controls where predicted trajectories are reported.
- **If both are omitted**: simulator does not run and adds no deterministic sites.

!!! note "Context and caveats"
- **NumPyro context required**: simulators call `numpyro.sample(...)` and draw
randomness via NumPyro PRNG keys, so they must run inside a NumPyro model
(or a `numpyro.handlers.seed(...)` context).
- **`obs_times` is required**: simulators only run when observation times are
provided (e.g. `dsx.sample(..., DynamicalModel(...), obs_times=...)`), because
those times define the trajectory grid.
- **Conditioning is optional**: if `obs_values` is provided (e.g.
`dsx.sample(..., DynamicalModel(...), obs_times=..., obs_values=...)`),
simulators pass these values as `obs=...` to the observation `numpyro.sample`
Expand All @@ -20,13 +33,25 @@ Simulators (also called *unrollers*) turn a `DynamicalModel` into explicit NumPy
`SDESimulator` is usually a poor inference strategy.

!!! note "Deterministic sites"
When a simulator runs (i.e., when `obs_times` is provided), it records:
- `"times"`: the observation-time grid used for unrolling,
- `"states"`: the latent trajectory on that grid,
- `"observations"`: sampled (or conditioned) emissions on that grid.
When simulator trajectories are produced, sites are recorded as `"{name}_{key}"`
where `name` is the first
argument to `dsx.sample(name, dynamics, ...)` (conventionally `"f"`):

- `"f_times"`: trajectory time grid, shape `(n_sim, T)`,
- `"f_states"`: latent trajectory, shape `(n_sim, T, state_dim)`,
- `"f_observations"`: sampled or conditioned emissions, shape `(n_sim, T, obs_dim)`.

In filter-rollout mode (`predict_times` with filtered posteriors), additional
keys `"f_predicted_states"`, `"f_predicted_times"`, and
`"f_predicted_observations"` are recorded.

Under `numpyro.infer.Predictive(model, num_samples=N)`, NumPyro prepends a leading
`num_samples` axis, giving final shapes `(num_samples, n_sim, T, dim)`.
Use `dynestyx.flatten_draws` to collapse the `(num_samples, n_sim)` prefix into one
axis for plotting or downstream analysis.

If `obs_times` is omitted, no simulation is performed and these deterministic
sites are not added.
If both `obs_times` and `predict_times` are omitted, no simulation is performed
and these sites are not added.

## Simulators

Expand Down
12 changes: 6 additions & 6 deletions docs/api_reference/public/simulators/discrete_time_simulator.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@
with DiscreteTimeSimulator():
prior_pred = Predictive(model, num_samples=5)(
jr.PRNGKey(0),
obs_times=obs_times,
predict_times=obs_times,
)
print("Predictive keys:", sorted(prior_pred.keys())) # e.g. ['f', 'observations', 'phi', 'states', 'times', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()}) # e.g. first axis is num_samples=5
print("Predictive keys:", sorted(prior_pred.keys())) # e.g. ['f_observations', 'f_states', 'f_times', 'phi', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()}) # trajectory arrays: (num_samples, n_sim, T, dim); here num_samples=5, n_sim=1
```

??? example "NUTS with DiscreteTimeSimulator"
Expand All @@ -66,11 +66,11 @@
print("Posterior sample keys:", sorted(posterior.keys())) # stochastic sites (often includes latent x_* and parameters like 'phi')
print("Posterior sample shapes:", {k: v.shape for k, v in posterior.items()})

# Deterministic trajectory keys like 'states'/'observations' are in posterior predictive output.
# Deterministic trajectory keys like 'f_states'/'f_observations' are in posterior predictive output.
with DiscreteTimeSimulator():
post_pred = Predictive(model, posterior_samples=posterior)(
jr.PRNGKey(2), obs_times=obs_times
jr.PRNGKey(2), predict_times=obs_times
)
print("Posterior predictive keys:", sorted(post_pred.keys())) # includes 'states', 'observations', 'times'
print("Posterior predictive keys:", sorted(post_pred.keys())) # includes 'f_states', 'f_observations', 'f_times'
print("Posterior predictive shapes:", {k: v.shape for k, v in post_pred.items()})
```
12 changes: 6 additions & 6 deletions docs/api_reference/public/simulators/ode_simulator.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@

obs_times = jnp.linspace(0.0, 5.0, 51)
with ODESimulator():
prior_pred = Predictive(model, num_samples=5)(jr.PRNGKey(0), obs_times=obs_times)
print("Predictive keys:", sorted(prior_pred.keys())) # e.g. ['f', 'observations', 'sigma_y', 'states', 'theta', 'times', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()}) # e.g. first axis is num_samples=5
prior_pred = Predictive(model, num_samples=5)(jr.PRNGKey(0), predict_times=obs_times)
print("Predictive keys:", sorted(prior_pred.keys())) # e.g. ['f_observations', 'f_states', 'f_times', 'sigma_y', 'theta', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()}) # trajectory arrays: (num_samples, n_sim, T, dim); here num_samples=5, n_sim=1
```

??? example "NUTS with ODESimulator"
Expand All @@ -63,11 +63,11 @@
print("Posterior sample keys:", sorted(posterior.keys())) # stochastic sites (typically parameters and x_0)
print("Posterior sample shapes:", {k: v.shape for k, v in posterior.items()})

# Deterministic trajectories are exposed as 'states'/'observations' in posterior predictive output.
# Deterministic trajectories are exposed as 'f_states'/'f_observations' in posterior predictive output.
with ODESimulator():
post_pred = Predictive(model, posterior_samples=posterior)(
jr.PRNGKey(2), obs_times=obs_times
jr.PRNGKey(2), predict_times=obs_times
)
print("Posterior predictive keys:", sorted(post_pred.keys())) # includes 'states', 'observations', 'times'
print("Posterior predictive keys:", sorted(post_pred.keys())) # includes 'f_states', 'f_observations', 'f_times'
print("Posterior predictive shapes:", {k: v.shape for k, v in post_pred.items()})
```
45 changes: 35 additions & 10 deletions docs/api_reference/public/simulators/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,26 @@
Simulators (also called *unrollers*) turn a `DynamicalModel` into explicit NumPyro
`sample` sites for latent states and observations on a provided time grid.

!!! note "Context"
!!! note "When to use each time argument"
- **`obs_times` and `obs_values` must be provided together**:
- `obs_times` defines where observation sample sites (`y_t`) live.
- `obs_values` provides conditioning values for those sites via `obs=...`.
- Typical use: observed-data simulation/inference on a known observation grid.
- **`predict_times`**: use this when you want rollout trajectories at specific
times for simulation and/or post-filter rollout.
- In filter-rollout mode, predictions are generated at `predict_times` from
filtered posteriors.
- Typical use: forward simulation, forecasting, or dense trajectories for
visualization.
- **If both are provided**:
- `obs_times` controls filtering/conditioning points.
- `predict_times` controls where predicted trajectories are reported.
- **If both are omitted**: simulator does not run and adds no deterministic sites.

!!! note "Context and caveats"
- **NumPyro context required**: simulators call `numpyro.sample(...)` and draw
randomness via NumPyro PRNG keys, so they must run inside a NumPyro model
(or a `numpyro.handlers.seed(...)` context).
- **`obs_times` is required**: simulators only run when observation times are
provided (e.g. `dsx.sample(..., DynamicalModel(...), obs_times=...)`), because
those times define the trajectory grid.
- **Conditioning is optional**: if `obs_values` is provided (e.g.
`dsx.sample(..., DynamicalModel(...), obs_times=..., obs_values=...)`),
simulators pass these values as `obs=...` to the observation `numpyro.sample`
Expand All @@ -20,13 +33,25 @@ Simulators (also called *unrollers*) turn a `DynamicalModel` into explicit NumPy
`SDESimulator` is usually a poor inference strategy.

!!! note "Deterministic sites"
When a simulator runs (i.e., when `obs_times` is provided), it records:
- `"times"`: the observation-time grid used for unrolling,
- `"states"`: the latent trajectory on that grid,
- `"observations"`: sampled (or conditioned) emissions on that grid.
When simulator trajectories are produced, sites are recorded as `"{name}_{key}"`
where `name` is the first
argument to `dsx.sample(name, dynamics, ...)` (conventionally `"f"`):

- `"f_times"`: trajectory time grid, shape `(n_sim, T)`,
- `"f_states"`: latent trajectory, shape `(n_sim, T, state_dim)`,
- `"f_observations"`: sampled or conditioned emissions, shape `(n_sim, T, obs_dim)`.

In filter-rollout mode (`predict_times` with filtered posteriors), additional
keys `"f_predicted_states"`, `"f_predicted_times"`, and
`"f_predicted_observations"` are recorded.

Under `numpyro.infer.Predictive(model, num_samples=N)`, NumPyro prepends a leading
`num_samples` axis, giving final shapes `(num_samples, n_sim, T, dim)`.
Use `dynestyx.flatten_draws` to collapse the `(num_samples, n_sim)` prefix into one
axis for plotting or downstream analysis.

If `obs_times` is omitted, no simulation is performed and these deterministic
sites are not added.
If both `obs_times` and `predict_times` are omitted, no simulation is performed
and these sites are not added.

## BaseSimulator

Expand Down
12 changes: 6 additions & 6 deletions docs/api_reference/public/simulators/sde_simulator.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@

obs_times = jnp.linspace(0.0, 5.0, 51)
with SDESimulator():
prior_pred = Predictive(model, num_samples=5)(jr.PRNGKey(0), obs_times=obs_times)
print("Predictive keys:", sorted(prior_pred.keys())) # e.g. ['f', 'observations', 'sigma_x', 'sigma_y', 'states', 'theta', 'times', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()}) # e.g. first axis is num_samples=5
prior_pred = Predictive(model, num_samples=5)(jr.PRNGKey(0), predict_times=obs_times)
print("Predictive keys:", sorted(prior_pred.keys())) # e.g. ['f_observations', 'f_states', 'f_times', 'sigma_x', 'sigma_y', 'theta', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()}) # trajectory arrays: (num_samples, n_sim, T, dim); here num_samples=5, n_sim=1
```

??? example "NUTS with SDESimulator (small demonstration)"
Expand All @@ -67,11 +67,11 @@
print("Posterior sample keys:", sorted(posterior.keys())) # stochastic sites (typically parameters and x_0)
print("Posterior sample shapes:", {k: v.shape for k, v in posterior.items()})

# Deterministic trajectories are exposed as 'states'/'observations' in posterior predictive output.
# Deterministic trajectories are exposed as 'f_states'/'f_observations' in posterior predictive output.
with SDESimulator():
post_pred = Predictive(model, posterior_samples=posterior)(
jr.PRNGKey(2), obs_times=obs_times
jr.PRNGKey(2), predict_times=obs_times
)
print("Posterior predictive keys:", sorted(post_pred.keys())) # includes 'states', 'observations', 'times'
print("Posterior predictive keys:", sorted(post_pred.keys())) # includes 'f_states', 'f_observations', 'f_times'
print("Posterior predictive shapes:", {k: v.shape for k, v in post_pred.items()})
```
12 changes: 6 additions & 6 deletions docs/api_reference/public/simulators/simulator_wrapper.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@
with Simulator():
prior_pred = Predictive(model, num_samples=5)(
jr.PRNGKey(0),
obs_times=obs_times,
predict_times=obs_times,
)
print("Predictive keys:", sorted(prior_pred.keys())) # e.g. ['f', 'observations', 'phi', 'states', 'times', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()}) # e.g. first axis is num_samples=5
print("Predictive keys:", sorted(prior_pred.keys())) # e.g. ['f_observations', 'f_states', 'f_times', 'phi', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()}) # trajectory arrays: (num_samples, n_sim, T, dim); here num_samples=5, n_sim=1
```

??? example "NUTS inference with auto-routing"
Expand All @@ -67,11 +67,11 @@
print("Posterior sample keys:", sorted(posterior.keys())) # stochastic sites (e.g. parameters, and possibly latent x_* sites)
print("Posterior sample shapes:", {k: v.shape for k, v in posterior.items()}) # each shape starts with num_samples (here 100)

# Deterministic trajectory keys like 'states'/'observations' are in posterior predictive output.
# Deterministic trajectory keys like 'f_states'/'f_observations' are in posterior predictive output.
with Simulator():
post_pred = Predictive(model, posterior_samples=posterior)(
jr.PRNGKey(2), obs_times=obs_times
jr.PRNGKey(2), predict_times=obs_times
)
print("Posterior predictive keys:", sorted(post_pred.keys())) # includes 'states', 'observations', 'times'
print("Posterior predictive keys:", sorted(post_pred.keys())) # includes 'f_states', 'f_observations', 'f_times'
print("Posterior predictive shapes:", {k: v.shape for k, v in post_pred.items()})
```
49 changes: 29 additions & 20 deletions docs/deep_dives/discrete_time_lti_profile_likelihood.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions docs/deep_dives/fhn_sparse_id.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,10 @@
"\n",
"predictive = Predictive(model_with_true_drift, num_samples=1, exclude_deterministic=False)\n",
"with SDESimulator():\n",
" synthetic = predictive(k_data, obs_times=obs_times)\n",
" synthetic = predictive(k_data, predict_times=obs_times)\n",
"\n",
"obs_values = synthetic[\"observations\"][0] # y_k shape (n_obs, observation_dim)\n",
"states = synthetic[\"states\"][0] # x_{t_k} shape (n_obs, state_dim)\n",
"obs_values = synthetic[\"f_observations\"][0] # y_k shape (n_obs, observation_dim)\n",
"states = synthetic[\"f_states\"][0] # x_{t_k} shape (n_obs, state_dim)\n",
"times_1d = jnp.asarray(obs_times).squeeze()"
]
},
Expand Down Expand Up @@ -761,7 +761,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "dynestyx (3.12.11)",
"display_name": ".venv (3.12.11)",
"language": "python",
"name": "python3"
},
Expand Down
21 changes: 12 additions & 9 deletions docs/deep_dives/gp_drift.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@
"outputs": [],
"source": [
"# True system: pipe in drift, **kwargs the rest\n",
"def model_with_true_drift(obs_times=None, obs_values=None):\n",
"def model_with_true_drift(obs_times=None, obs_values=None, predict_times=None):\n",
" return dsx.sample(\n",
" \"f\",\n",
" DynamicalModel(\n",
Expand All @@ -225,18 +225,20 @@
" ),\n",
" obs_times=obs_times,\n",
" obs_values=obs_values,\n",
" predict_times=predict_times,\n",
" )\n",
"\n",
"\n",
"predictive = Predictive(\n",
" model_with_true_drift, num_samples=1, exclude_deterministic=False\n",
")\n",
"with SDESimulator():\n",
" synthetic = predictive(k_data, obs_times=obs_times)\n",
" synthetic = predictive(k_data, predict_times=obs_times)\n",
"\n",
"obs_values = synthetic[\"observations\"][0]\n",
"states = synthetic[\"states\"][0]\n",
"times_1d = jnp.asarray(obs_times).squeeze()"
"# f_observations / f_states: (num_samples, n_sim, T, ...)\n",
"obs_values = synthetic[\"f_observations\"][0, 0]\n",
"states = synthetic[\"f_states\"][0, 0]\n",
"times_1d = jnp.asarray(obs_times).squeeze()\n"
]
},
{
Expand Down Expand Up @@ -401,7 +403,7 @@
"from dynestyx.inference.filter_configs import ContinuousTimeEnKFConfig\n",
"\n",
"\n",
"def model_with_gp_drift(obs_times=None, obs_values=None):\n",
"def model_with_gp_drift(obs_times=None, obs_values=None, predict_times=None):\n",
" beta = numpyro.sample(\n",
" \"beta\", dist.Normal(0.0, 1.0).expand((MSTAR, state_dim)).to_event(2)\n",
" )\n",
Expand All @@ -411,6 +413,7 @@
" DynamicalModel(state_evolution=make_state_evolution(drift), **dynamics_kwargs),\n",
" obs_times=obs_times,\n",
" obs_values=obs_values,\n",
" predict_times=predict_times,\n",
" )\n",
"\n",
"\n",
Expand All @@ -429,7 +432,7 @@
" k_svi, num_steps=num_steps, obs_times=obs_times, obs_values=obs_values\n",
" )\n",
"\n",
"beta_map = guide.median(svi_result.params)[\"beta\"]"
"beta_map = guide.median(svi_result.params)[\"beta\"]\n"
]
},
{
Expand Down Expand Up @@ -1331,7 +1334,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": ".venv (3.12.11)",
"language": "python",
"name": "python3"
},
Expand All @@ -1345,7 +1348,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
"version": "3.12.11"
}
},
"nbformat": 4,
Expand Down
50 changes: 25 additions & 25 deletions docs/deep_dives/l63_speedup_dirac_vs_enkf.ipynb

Large diffs are not rendered by default.

Loading
Loading