Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions dynestyx/simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import diffrax as dfx
import jax.numpy as jnp
import numpy as np
import numpyro
from effectful.ops.semantics import fwd
from effectful.ops.syntax import ObjectInterpretation, implements
Expand Down Expand Up @@ -409,8 +410,43 @@ def _simulate(

# DiracIdentityObservation with observed values: y_t = x_t, so we use plating
# instead of scan. state_evolution returns a dist; call it with batched inputs.
if isinstance(dynamics.observation_model, DiracIdentityObservation) and (
obs_values is not None
#
# When there are missing rows (entire-row NaN), we filter them out
# using numpy (concrete indexing, no tracers) before entering the
# plate. state_evolution handles non-unit dt from skipped rows.
has_no_obs = obs_values is None
has_missing_data = not has_no_obs and np.isnan(np.asarray(obs_values)).any()

if has_missing_data:
# Only entire-row missingness is supported; raise on partial.
obs_np = np.asarray(obs_values)
nan_per_row = np.isnan(obs_np).any(axis=1)
all_nan_per_row = np.isnan(obs_np).all(axis=1)
has_partial = (nan_per_row & ~all_nan_per_row).any()
if has_partial:
raise ValueError(
"Partial missingness (some but not all components NaN in a "
"row) is not yet supported. Only entire-row NaN is allowed."
)
# Filter to observed rows using numpy (concrete indexing, no
# tracers). state_evolution handles non-unit dt from skipped
# rows. Both the Dirac plate path and the default scan path
# operate on the filtered arrays.
observed_mask = ~all_nan_per_row
obs_values = jnp.array(obs_np[observed_mask])
obs_times = jnp.array(np.asarray(obs_times)[observed_mask])
if ctrl_values is not None:
ctrl_values = jnp.array(np.asarray(ctrl_values)[observed_mask])
T = len(obs_times)
if T < 1:
raise ValueError(
"obs_times must contain at least one timepoint after "
"removing missing data"
)

if (
isinstance(dynamics.observation_model, DiracIdentityObservation)
and not has_no_obs
):
numpyro.sample("x_0", dynamics.initial_condition, obs=obs_values[0])
numpyro.deterministic("y_0", obs_values[0])
Expand Down
46 changes: 46 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,49 @@ def jumpy_controls_model_ode(
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
)


def particle_sde_gaussian_potential_model(
N=3,
D=2,
K=2,
sigma=0.5,
obs_times=None,
obs_values=None,
):
"""N particles in D dimensions with drift = -grad(V), V = sum of weighted Gaussians.

Learnable parameters: centers (K, D) and strengths (K,) of the Gaussian components.
Diffusion is diagonal with known sigma.
"""
centers = numpyro.sample(
"centers", dist.Normal(0.0, 3.0).expand([K, D]).to_event(2)
)
strengths = numpyro.sample(
"strengths", dist.LogNormal(0.0, 1.0).expand([K]).to_event(1)
)

def potential(x, u, t):
particles = x.reshape(N, D)
V = 0.0
for k in range(K):
diff = particles - centers[k]
V = V - strengths[k] * jnp.sum(jnp.exp(-0.5 * jnp.sum(diff**2, axis=-1)))
return V

state_dim = N * D
dynamics = DynamicalModel(
control_dim=0,
initial_condition=dist.MultivariateNormal(
loc=jnp.zeros(state_dim),
covariance_matrix=2.0**2 * jnp.eye(state_dim),
),
state_evolution=ContinuousTimeStateEvolution(
potential=potential,
use_negative_gradient=True,
diffusion_coefficient=lambda x, u, t: sigma * jnp.eye(state_dim),
),
observation_model=DiracIdentityObservation(),
)

dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)
Loading
Loading