Skip to content

Numerical Instabilities for Kalman Filter Gradients #211

@DanWaxman

Description

@DanWaxman

I ran into some troubles with NaN gradients in the Kalman filter. Here's a simple test that can be appended to test_kalman.py and reproduces the issue:

def generate_simple_LTI_system(key: jax.Array, num_steps: int = 50):
    """Simulate a small 2-state / 1-observation LTI state-space model.

    Args:
        key: PRNG key used to drive all sampling.
        num_steps: Number of transition/observation steps to simulate.

    Returns:
        Tuple ``(m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys)`` containing the
        model parameters (square-root/Cholesky form) and the stacked
        observations ``ys`` of shape ``(num_steps, 1)``.
    """
    state_dim, obs_dim = 2, 1

    # Initial state: mean and Cholesky factor of the covariance.
    m0 = jnp.zeros(state_dim)
    chol_P0 = jnp.eye(state_dim)

    # Dynamics model: x' = F x + c + chol_Q @ noise.
    F = jnp.array([[0.4, 0.1], [0.1, 0.6]])
    c = jnp.zeros(state_dim)
    chol_Q = 0.1**0.5 * jnp.eye(state_dim)

    # Observation model: y = H x + d + chol_R @ noise.
    H = jnp.array([[1.0, 0.0]])
    d = jnp.zeros(obs_dim)
    chol_R = 0.5 * jnp.eye(obs_dim)

    xs, ys = [], []

    # Sample the initial state. NOTE: the key-splitting order below matches
    # the simulation loop so the generated trajectory is reproducible.
    key, state_key = jax.random.split(key)
    x = m0 + chol_P0 @ jax.random.normal(state_key, (state_dim,))
    xs.append(x)

    for _ in range(num_steps):
        key, state_key, obs_key = jax.random.split(key, 3)
        x = F @ x + chol_Q @ jax.random.normal(state_key, (state_dim,))
        y = H @ x + chol_R @ jax.random.normal(obs_key, (obs_dim,))
        xs.append(x)
        ys.append(y)

    ys = jnp.stack(ys)

    return m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys

@pytest.mark.parametrize("seed", seeds)
@pytest.mark.parametrize("num_time_steps", num_time_steps)
@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("param", ["F", "c", "chol_Q", "H", "d", "chol_R"])
def test_simple_LTI_grad_mll_non_nan(seed, num_time_steps, parallel, param):
    """Check that grad of the marginal log-likelihood w.r.t. each model
    parameter is finite and not identically zero."""
    key = jax.random.key(seed)
    m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys = generate_simple_LTI_system(
        key, num_steps=num_time_steps
    )

    # Baseline parameter values keyed by name; also used below to look up
    # the differentiation point, instead of the original eval(param).
    base_params = {
        "m0": m0,
        "chol_P0": chol_P0,
        "F": F,
        "c": c,
        "chol_Q": chol_Q,
        "H": H,
        "d": d,
        "chol_R": chol_R,
    }

    def get_mll_from_param(param_value):
        # Copy so each call differentiates a fresh parameter dict.
        params = dict(base_params)
        params[param] = param_value

        def get_init_params(model_inputs: int):
            return params["m0"], params["chol_P0"]

        def get_dynamics_params(model_inputs: int):
            return params["F"], params["c"], params["chol_Q"]

        def get_observation_params(model_inputs: int):
            return params["H"], params["d"], params["chol_R"], ys[model_inputs - 1]

        filter_obj = kalman.build_filter(
            get_init_params, get_dynamics_params, get_observation_params
        )

        states = filter(filter_obj, jnp.arange(len(ys) + 1), parallel=parallel)
        return states.log_normalizing_constant[-1]

    grad_mll_at_param = jax.grad(get_mll_from_param)(base_params[param])

    # chex raises on non-finite entries. The original wrapped this call and an
    # f-string in a tuple, which silently discarded the failure message.
    chex.assert_tree_all_finite(grad_mll_at_param)
    assert not jnp.all(grad_mll_at_param == 0.0), (
        f"Gradient of MLL with respect to {param} is zero"
    )

This test consistently results in NaN gradients when parallel = True and estimating gradients w.r.t. F, chol_Q, H, or chol_R. I've also run into issues on similar problems with parallel=False.

This also occurs if you try to compute gradients w.r.t. F in the car tracking example.

Debugging a bit, this seems to occur in backwards passes over the tria(...) calls in the Kalman filter_operator; adding jitter before these operations fixes it. In particular, I get valid gradient estimates with:

def _add_jitter(A: Array, jitter: float = 1e-8) -> Array:
    return A + jnp.eye(A.shape[-2], A.shape[-1]) * jitter

def filtering_operator(
    elem_i: FilterScanElement, elem_j: FilterScanElement
) -> FilterScanElement:
    """Binary associative operator for the square root Kalman filter.

    Args:
        elem_i: Filter scan element for the previous time step.
        elem_j: Filter scan element for the current time step.

    Returns:
        FilterScanElement: The output of the associative operator applied to the input elements.
    """
    # Unpack the two elements: transition matrix, offset, covariance
    # Cholesky, information vector, information Cholesky, log-normalizer.
    Ai, bi, Ui, etai, Zi, elli = elem_i
    Aj, bj, Uj, etaj, Zj, ellj = elem_j

    nx = Zj.shape[0]

    # Triangularize the fusion block matrix; the jitter keeps the backward
    # pass of tria from producing NaN gradients.
    blk = jnp.block([[Ui.T @ Zj, jnp.eye(nx)], [Zj, jnp.zeros_like(Ai)]])
    R = tria(_add_jitter(blk))
    R11 = R[:nx, :nx]
    R21 = R[nx : 2 * nx, :nx]
    R22 = R[nx : 2 * nx, nx:]

    gain = solve_triangular(R11, Ui.T, lower=True).T
    D_inv = jnp.eye(nx) - gain @ R21.T
    shifted_mean = D_inv @ (bi + Ui @ (Ui.T @ etaj))

    # Combined transition, offset, and (jittered) square-root factors.
    A_out = Aj @ D_inv @ Ai
    b_out = Aj @ shifted_mean + bj
    U_out = tria(_add_jitter(jnp.concatenate([Aj @ gain, Uj], axis=1)))
    eta_out = Ai.T @ (D_inv.T @ (etaj - Zj @ (Zj.T @ bi))) + etai
    Z_out = tria(_add_jitter(jnp.concatenate([Ai.T @ R22, Zi], axis=1)))

    # Accumulate the log-normalizing constant.
    mu = cho_solve((Ui, True), bi)
    quad = bi @ mu - (etaj + mu) @ shifted_mean
    ell_out = elli + ellj - 0.5 * quad + 0.5 * jnp.linalg.slogdet(D_inv)[1]

    return FilterScanElement(A_out, b_out, U_out, eta_out, Z_out, ell_out)

Note that adding jitter to only a subset of these tria operators doesn't seem to fix things. This happens for both 32- and 64-bit floating-point numbers.

Happy to open a PR with the added test + jitters, if that seems like the correct solution.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions