Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions tests/cuthbert/gaussian/test_kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,39 @@ def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Arra
return filter, smoother, model_inputs


def generate_simple_LTI_system(key: jax.Array, num_steps: int = 50):
x_dim, y_dim = 2, 1

m0 = jnp.zeros(x_dim)
chol_P0 = jnp.eye(x_dim)

F = jnp.array([[0.4, 0.1], [0.1, 0.6]])
c = jnp.zeros(x_dim)
chol_Q = 0.1**0.5 * jnp.eye(x_dim)

H = jnp.array([[1.0, 0.0]])
d = jnp.zeros(y_dim)
chol_R = 0.5 * jnp.eye(y_dim)

xs = []
ys = []

key, state_key = jax.random.split(key)
x = m0 + chol_P0 @ jax.random.normal(state_key, (x_dim,))
xs.append(x)

for _ in range(num_steps):
key, state_key, obs_key = jax.random.split(key, 3)
x = F @ x + chol_Q @ jax.random.normal(state_key, (x_dim,))
y = H @ x + chol_R @ jax.random.normal(obs_key, (y_dim,))
xs.append(x)
ys.append(y)

ys = jnp.stack(ys)

return m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys


seeds = [1, 43, 99, 123, 456]
x_dims = [3]
y_dims = [1, 2]
Expand Down Expand Up @@ -288,6 +321,55 @@ def construct_joint_chol_cov(chol_cov_t_plus_1, gain_t, chol_cov_given_next_t):
)


@pytest.mark.parametrize("seed", seeds)
@pytest.mark.parametrize("num_time_steps", num_time_steps)
@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("param", ["F", "c", "chol_Q", "H", "d", "chol_R"])
def test_simple_LTI_grad_mll_non_nan(seed, num_time_steps, parallel, param):
key = jax.random.key(seed)
m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys = generate_simple_LTI_system(
key, num_steps=num_time_steps
)

def get_mll_from_param(param_value):
params = {
"m0": m0,
"chol_P0": chol_P0,
"F": F,
"c": c,
"chol_Q": chol_Q,
"H": H,
"d": d,
"chol_R": chol_R,
}
params[param] = param_value

def get_init_params(model_inputs: int):
return params["m0"], params["chol_P0"]

def get_dynamics_params(model_inputs: int):
return params["F"], params["c"], params["chol_Q"]

def get_observation_params(model_inputs: int):
return params["H"], params["d"], params["chol_R"], ys[model_inputs - 1]

filter_obj = kalman.build_filter(
get_init_params, get_dynamics_params, get_observation_params
)

states = filter(filter_obj, jnp.arange(len(ys) + 1), parallel=parallel)
return states.log_normalizing_constant[-1]

grad_mll_at_param = jax.grad(get_mll_from_param)(eval(param))
(
chex.assert_tree_all_finite(grad_mll_at_param),
f"Gradient of MLL with respect to {param} is infinite/NaN",
)
assert not jnp.all(grad_mll_at_param == 0.0), (
f"Gradient of MLL with respect to {param} is zero"
)


# @pytest.mark.parametrize("seed,x_dim,y_dim,num_time_steps", common_params)
# @pytest.mark.parametrize("parallel", [False, True])
# def test_sampler(seed, x_dim, y_dim, num_time_steps, parallel):
Expand Down
Loading