-
Notifications
You must be signed in to change notification settings - Fork 5
Description
I ran into some trouble with NaN gradients in the Kalman filter. Here's a simple test that can be appended to test_kalman.py and reproduces the issue:
def generate_simple_LTI_system(key: jax.Array, num_steps: int = 50):
    """Simulate a small 2-state / 1-observation linear time-invariant system.

    Args:
        key: PRNG key used to draw the initial state and all noise terms.
        num_steps: Number of observation steps to simulate.

    Returns:
        Tuple ``(m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys)`` holding the
        model parameters (covariances as Cholesky factors) and the simulated
        observations ``ys`` of shape ``(num_steps, 1)``. The latent states
        themselves are not returned.
    """
    x_dim, y_dim = 2, 1
    m0 = jnp.zeros(x_dim)
    chol_P0 = jnp.eye(x_dim)
    F = jnp.array([[0.4, 0.1], [0.1, 0.6]])
    c = jnp.zeros(x_dim)
    chol_Q = 0.1**0.5 * jnp.eye(x_dim)
    H = jnp.array([[1.0, 0.0]])
    d = jnp.zeros(y_dim)
    chol_R = 0.5 * jnp.eye(y_dim)

    ys = []
    key, state_key = jax.random.split(key)
    # Draw x0 ~ N(m0, P0) via the Cholesky factor of P0.
    x = m0 + chol_P0 @ jax.random.normal(state_key, (x_dim,))
    for _ in range(num_steps):
        key, state_key, obs_key = jax.random.split(key, 3)
        x = F @ x + chol_Q @ jax.random.normal(state_key, (x_dim,))
        y = H @ x + chol_R @ jax.random.normal(obs_key, (y_dim,))
        ys.append(y)
    ys = jnp.stack(ys)
    return m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys
@pytest.mark.parametrize("seed", seeds)
@pytest.mark.parametrize("num_time_steps", num_time_steps)
@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("param", ["F", "c", "chol_Q", "H", "d", "chol_R"])
def test_simple_LTI_grad_mll_non_nan(seed, num_time_steps, parallel, param):
key = jax.random.key(seed)
m0, chol_P0, F, c, chol_Q, H, d, chol_R, ys = generate_simple_LTI_system(
key, num_steps=num_time_steps
)
def get_mll_from_param(param_value):
params = {
"m0": m0,
"chol_P0": chol_P0,
"F": F,
"c": c,
"chol_Q": chol_Q,
"H": H,
"d": d,
"chol_R": chol_R,
}
params[param] = param_value
def get_init_params(model_inputs: int):
return params["m0"], params["chol_P0"]
def get_dynamics_params(model_inputs: int):
return params["F"], params["c"], params["chol_Q"]
def get_observation_params(model_inputs: int):
return params["H"], params["d"], params["chol_R"], ys[model_inputs - 1]
filter_obj = kalman.build_filter(
get_init_params, get_dynamics_params, get_observation_params
)
states = filter(filter_obj, jnp.arange(len(ys) + 1), parallel=parallel)
return states.log_normalizing_constant[-1]
grad_mll_at_param = jax.grad(get_mll_from_param)(eval(param))
(
chex.assert_tree_all_finite(grad_mll_at_param),
f"Gradient of MLL with respect to {param} is infinite/NaN",
)
assert not jnp.all(grad_mll_at_param == 0.0), (
f"Gradient of MLL with respect to {param} is zero"
)This tests consistently results in NaN gradients when parallel = True and estimating gradients w.r.t. F, chol_Q, H, or chol_R. I've also run into issues on similar problems with parallel=False.
This also occurs if you try to compute gradients w.r.t. F in the car tracking example.
Debugging a bit, this seems to occur in backward passes over the tria(...) calls in the Kalman filtering_operator. Adding jitter before these operations helps: in particular, I get valid gradient estimates with
def _add_jitter(A: Array, jitter: float = 1e-8) -> Array:
    """Return ``A`` with ``jitter`` added along its main diagonal.

    Works for rectangular inputs too: the identity is built from the trailing
    two dimensions of ``A``, so only the main diagonal is perturbed.
    """
    eye = jnp.eye(A.shape[-2], A.shape[-1])
    return A + jitter * eye
def filtering_operator(
    elem_i: FilterScanElement, elem_j: FilterScanElement
) -> FilterScanElement:
    """Binary associative operator for the square root Kalman filter.

    Args:
        elem_i: Filter scan element for the previous time step.
        elem_j: Filter scan element for the current time step.

    Returns:
        FilterScanElement: The output of the associative operator applied to the input elements.
    """
    # Each element carries (A, b, U, eta, Z, ell): a linear map A, offset b,
    # two square-root factors U and Z, a vector eta, and a log-normalizer ell.
    A1, b1, U1, eta1, Z1, ell1 = elem_i
    A2, b2, U2, eta2, Z2, ell2 = elem_j
    nx = Z2.shape[0]

    # Block matrix whose triangularization yields the Xi11/Xi21/Xi22 factors.
    Xi = jnp.block([[U1.T @ Z2, jnp.eye(nx)], [Z2, jnp.zeros_like(A1)]])
    # Jitter on the diagonal before tria(...) to keep the backward pass finite.
    tria_xi = tria(_add_jitter(Xi))
    Xi11 = tria_xi[:nx, :nx]
    Xi21 = tria_xi[nx : nx + nx, :nx]
    Xi22 = tria_xi[nx : nx + nx, nx:]

    # tmp_1 = U1 @ Xi11^{-T} (Xi11 is lower triangular after tria).
    tmp_1 = solve_triangular(Xi11, U1.T, lower=True).T
    # NOTE(review): D_inv and the combination formulas below presumably follow
    # the parallel square-root filtering literature -- confirm against the
    # reference derivation before changing anything here.
    D_inv = jnp.eye(nx) - tmp_1 @ Xi21.T
    tmp_2 = D_inv @ (b1 + U1 @ (U1.T @ eta2))

    # Combined linear map and offset.
    A = A2 @ D_inv @ A1
    b = A2 @ tmp_2 + b2

    # Combined U factor via triangularization of the stacked factors.
    _int = jnp.concatenate([A2 @ tmp_1, U2], axis=1)
    U = tria(_add_jitter(_int))

    # Combined eta vector and its Z factor.
    eta = A1.T @ (D_inv.T @ (eta2 - Z2 @ (Z2.T @ b1))) + eta1
    _int = jnp.concatenate([A1.T @ Xi22, Z1], axis=1)
    Z = tria(_add_jitter(_int))

    # Log-normalizing-constant update; cho_solve treats U1 as lower triangular.
    mu = cho_solve((U1, True), b1)
    t1 = b1 @ mu - (eta2 + mu) @ tmp_2
    ell = ell1 + ell2 - 0.5 * t1 + 0.5 * jnp.linalg.slogdet(D_inv)[1]
    return FilterScanElement(A, b, U, eta, Z, ell)

I think adding jitters on any subset of these tria operators doesn't fix things. This happens for 32- and 64-bit floating points numbers.
Happy to open a PR with the added test + jitters, if that seems like the correct solution.