diff --git a/tests/cuthbert/gaussian/test_kalman.py b/tests/cuthbert/gaussian/test_kalman.py index 93c3235..658797a 100644 --- a/tests/cuthbert/gaussian/test_kalman.py +++ b/tests/cuthbert/gaussian/test_kalman.py @@ -93,6 +93,39 @@ def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Arra return filter, smoother, model_inputs +def generate_simple_LTI_system(key: jax.Array, num_steps: int = 50): + x_dim, y_dim = 2, 1 + + m0 = jnp.zeros(x_dim) + chol_P0 = jnp.eye(x_dim) + + F = jnp.array([[0.4, 0.1], [0.1, 0.6]]) + c = jnp.zeros(x_dim) + chol_Q = 0.1**0.5 * jnp.eye(x_dim) + + H = jnp.array([[1.0, 0.0]]) + d = jnp.zeros(y_dim) + chol_R = 0.5 * jnp.eye(y_dim) + + xs = [] + ys = [] + + key, state_key = jax.random.split(key) + x = m0 + chol_P0 @ jax.random.normal(state_key, (x_dim,)) + xs.append(x) + + for _ in range(num_steps): + key, state_key, obs_key = jax.random.split(key, 3) + x = F @ x + chol_Q @ jax.random.normal(state_key, (x_dim,)) + y = H @ x + chol_R @ jax.random.normal(obs_key, (y_dim,)) + xs.append(x) + ys.append(y) + + ys = jnp.stack(ys) + + return m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys + + seeds = [1, 43, 99, 123, 456] x_dims = [3] y_dims = [1, 2] @@ -288,6 +321,55 @@ def construct_joint_chol_cov(chol_cov_t_plus_1, gain_t, chol_cov_given_next_t): ) +@pytest.mark.parametrize("seed", seeds) +@pytest.mark.parametrize("num_time_steps", num_time_steps) +@pytest.mark.parametrize("parallel", [False, True]) +@pytest.mark.parametrize("param", ["F", "c", "chol_Q", "H", "d", "chol_R"]) +def test_simple_LTI_grad_mll_non_nan(seed, num_time_steps, parallel, param): + key = jax.random.key(seed) + m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys = generate_simple_LTI_system( + key, num_steps=num_time_steps + ) + + def get_mll_from_param(param_value): + params = { + "m0": m0, + "chol_P0": chol_P0, + "F": F, + "c": c, + "chol_Q": chol_Q, + "H": H, + "d": d, + "chol_R": chol_R, + } + params[param] = param_value + + def get_init_params(model_inputs: int): + return params["m0"], params["chol_P0"] + + def get_dynamics_params(model_inputs: int): + return params["F"], params["c"], params["chol_Q"] + + def get_observation_params(model_inputs: int): + return params["H"], params["d"], params["chol_R"], ys[model_inputs - 1] + + filter_obj = kalman.build_filter( + get_init_params, get_dynamics_params, get_observation_params + ) + + states = filter(filter_obj, jnp.arange(len(ys) + 1), parallel=parallel) + return states.log_normalizing_constant[-1] + + grad_mll_at_param = jax.grad(get_mll_from_param)(eval(param)) + ( + chex.assert_tree_all_finite(grad_mll_at_param), + f"Gradient of MLL with respect to {param} is infinite/NaN", + ) + assert not jnp.all(grad_mll_at_param == 0.0), ( + f"Gradient of MLL with respect to {param} is zero" + ) + + # @pytest.mark.parametrize("seed,x_dim,y_dim,num_time_steps", common_params) # @pytest.mark.parametrize("parallel", [False, True]) # def test_sampler(seed, x_dim, y_dim, num_time_steps, parallel):