diff --git a/cuthbertlib/kalman/filtering.py b/cuthbertlib/kalman/filtering.py index 4631239..8c125d1 100644 --- a/cuthbertlib/kalman/filtering.py +++ b/cuthbertlib/kalman/filtering.py @@ -208,6 +208,42 @@ def filtering_operator( mu = cho_solve((U1, True), b1) t1 = b1 @ mu - (eta2 + mu) @ tmp_2 - ell = ell1 + ell2 - 0.5 * t1 + 0.5 * jnp.linalg.slogdet(D_inv)[1] + + # Derivation for O(nx) log-determinant computation of D_inv: + # This is a long comment but I wanted to include the full derivation for clarity and future reference. + # The key idea is to express D_inv in terms of the blocks of Xi and then apply Sylvester's determinant theorem + # to compute its determinant efficiently. + # + # 1. Expand the blocks of Xi @ Xi.T: + # (Xi @ Xi.T)[1,1] = I + U1.T @ Z2 @ Z2.T @ U1 + # (Xi @ Xi.T)[2,1] = Z2 @ Z2.T @ U1 + # + # 2. Equate to the corresponding blocks of L @ L.T: + # Xi11 @ Xi11.T = I + U1.T @ Z2 @ Z2.T @ U1 + # Xi21 @ Xi11.T = Z2 @ Z2.T @ U1 + # + # 3. Expand D_inv using tmp_1 = Xi11^{-1} @ U1.T: + # D_inv = I - tmp_1.T @ Xi21.T + # = I - U1 @ Xi11^{-T} @ Xi21.T + # + # 4. Apply Sylvester's determinant theorem: + # det(D_inv) = det(I - Xi11^{-T} @ Xi21.T @ U1) + # + # 5. Multiply interior by Xi11^{-1} @ Xi11 and substitute block identities: + # Let P = U1.T @ Z2 @ Z2.T @ U1 + # det(D_inv) = det(I - (Xi11 @ Xi11.T)^{-1} @ (Xi21 @ Xi11.T).T @ U1) + # = det(I - (I + P)^{-1} @ P) + # = det((I + P)^{-1}) + # = 1 / det(Xi11 @ Xi11.T) + # = det(Xi11)^{-2} + # + # 6. Simplify the log-determinant term in the log-likelihood: + # 0.5 * log(det(D_inv)) = -log(|det(Xi11)|) + # + # Since Xi11 is lower triangular, the log-determinant is the sum of the logs of its diagonal. 
+ # Replace `0.5 * jnp.linalg.slogdet(D_inv)[1]` with: + # -jnp.sum(jnp.log(jnp.abs(jnp.diag(Xi11)))) + # ell = ell1 + ell2 - 0.5 * t1 + 0.5 * jnp.linalg.slogdet(D_inv)[1] + ell = ell1 + ell2 - 0.5 * t1 - jnp.sum(jnp.log(jnp.abs(jnp.diag(Xi11)))) return FilterScanElement(A, b, U, eta, Z, ell) diff --git a/cuthbertlib/linalg/tria.py b/cuthbertlib/linalg/tria.py index 97645d4..7fefe6e 100644 --- a/cuthbertlib/linalg/tria.py +++ b/cuthbertlib/linalg/tria.py @@ -1,21 +1,97 @@ """Implements triangularization operator a matrix via QR decomposition.""" import jax +import jax.numpy as jnp from cuthbertlib.types import Array +def _adj(x: Array) -> Array: + """Conjugate transpose for batched matrices.""" + return jnp.swapaxes(x.conj(), -1, -2) + + +@jax.custom_jvp def tria(A: Array) -> Array: - r"""A triangularization operator using QR decomposition. + """A triangularization operator using QR decomposition. Args: A: The matrix to triangularize. Returns: - A lower triangular matrix $R$ such that $R R^\top = A A^\top$. + A lower triangular matrix R such that R @ R.T = A @ A.T. - Reference: - [Arasaratnam and Haykin (2008)](https://ieeexplore.ieee.org/document/4524036): Square-Root Quadrature Kalman Filtering + References: + Paper: Arasaratnam and Haykin (2008): Square-Root Quadrature Kalman Filtering + https://ieeexplore.ieee.org/document/4524036 """ - _, R = jax.scipy.linalg.qr(A.T, mode="economic") - return R.T + _, R_qr = jnp.linalg.qr(_adj(A), mode="reduced") + return _adj(R_qr) + + +@tria.defjvp +def _tria_jvp(primals, tangents): + # Derivation of the exact analytical JVP for the lower-triangularization operator: + # + # 1. The operation computes a lower triangular R such that: + # R @ R.T = A @ A.T + # + # 2. Taking the differential of both sides yields the Gramian differential identity: + # dR @ R.T + R @ dR.T = dA @ A.T + A @ dA.T + # + # 3. 
For rank-deficient A, the exact lower-triangular tangent dR decomposes into + # column-space and null-space components: dR = dR_col + dR_null. + # + # 4. To find dR_col, multiply the identity from the left by R^\dagger (pseudoinverse) + # and from the right by R^{\dagger T}: + # R^\dagger @ dR_col + dR_col.T @ R^{\dagger T} = R^\dagger @ dA @ A.T @ R^{\dagger T} + R^\dagger @ A @ dA.T @ R^{\dagger T} + # + # 5. Let Q be the active subspace orthogonal factor such that A = R @ Q.T. + # Substituting A.T @ R^{\dagger T} = Q and R^\dagger @ A = Q.T: + # R^\dagger @ dR_col + dR_col.T @ R^{\dagger T} = R^\dagger @ dA @ Q + Q.T @ dA.T @ R^{\dagger T} + # + # 6. Define K = R^\dagger @ dA @ Q. The right hand side becomes K + K.T: + # R^\dagger @ dR_col + (R^\dagger @ dR_col).T = K + K.T + # + # 7. Define dM_col = R^\dagger @ dR_col. Solving for the lower-triangular dM_col: + # dM_col = tril(K + K.T) - diag(K) + # + # 8. Recover the column-space differential dR_col by left-multiplying by R: + # dR_col = R @ dM_col + # + # 9. To satisfy the orthogonal cross-terms for arbitrary perturbations outside + # the column space of A, add the null-space component projected via (I - R @ R^\dagger): + # dR_null = (I - R @ R^\dagger) @ dA @ Q + # + # 10. The complete, exact JVP is the sum of both components: + # dR = dR_col + dR_null + + (A,) = primals + (dA,) = tangents + + A_T = jnp.swapaxes(A, -1, -2) + Q, R_qr = jnp.linalg.qr(A_T, mode="reduced") + + R = jnp.swapaxes(R_qr, -1, -2) + + # Q has shape (..., M, N). A is (..., N, M). 
+ # A^T = Q R_qr => A = R_qr^T Q^T = R Q^T + + R_pinv = jnp.linalg.pinv(R) + + # K = R^\dagger dA Q + # R can be degenerate so we use the pseudoinverse to ensure the JVP is well-defined everywhere. + K = R_pinv @ dA @ Q + K_T = jnp.swapaxes(K, -1, -2) + + # Solve dM + dM^T = K + K^T for the lower-triangular dM + I = jnp.eye(K.shape[-1], dtype=K.dtype) + dM = jnp.tril(K + K_T) - K * I + + # Compute the null-space part + dR_null = (jnp.eye(R.shape[-2], dtype=R.dtype) - R @ R_pinv) @ dA @ Q + + # Apply to get the tangent at R + dR = R @ dM + dR_null + + return R, dR diff --git a/tests/cuthbert/gaussian/test_kalman.py b/tests/cuthbert/gaussian/test_kalman.py index 93c3235..8f0b347 100644 --- a/tests/cuthbert/gaussian/test_kalman.py +++ b/tests/cuthbert/gaussian/test_kalman.py @@ -93,13 +93,16 @@ def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Arra return filter, smoother, model_inputs -seeds = [1, 43, 99, 123, 456] +seeds = [1, 43, 99] x_dims = [3] y_dims = [1, 2] num_time_steps = [1, 25] common_params = list(itertools.product(seeds, x_dims, y_dims, num_time_steps)) +# Use fewer time steps for gradient test because finite difference will not work well for long sequences +gradient_params = list(itertools.product(seeds, x_dims, y_dims, [1, 2])) + @pytest.mark.parametrize("seed,x_dim,y_dim,num_time_steps", common_params) def test_offline_filter(seed, x_dim, y_dim, num_time_steps): @@ -150,6 +153,32 @@ +@pytest.mark.parametrize("seed,x_dim,y_dim,num_time_steps", gradient_params) +def test_check_gradient(seed, x_dim, y_dim, num_time_steps): + m0, chol_P0, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys = generate_lgssm( + seed, x_dim, y_dim, num_time_steps + ) + + @jax.jit + def get_loglikelihood_from_params( + m0_, chol_P0_, Fs_, cs_, chol_Qs_, Hs_, ds_, chol_Rs_ + ): + kalman_filter, _, model_inputs = load_kalman_inference( + m0_, chol_P0_, Fs_, cs_, chol_Qs_, Hs_, ds_, chol_Rs_, ys + ) + 
states = filter(kalman_filter, model_inputs) + return states.log_normalizing_constant[-1] + + chex.assert_numerical_grads( + get_loglikelihood_from_params, + (m0, chol_P0, Fs, cs, chol_Qs, Hs, ds, chol_Rs), + order=1, + rtol=0.01, + atol=0.01, + eps=1e-5, + ) + + @pytest.mark.parametrize("seed", [1, 43, 99, 123, 456]) @pytest.mark.parametrize("x_dim", [1, 10]) @pytest.mark.parametrize("y_dim", [1, 5]) diff --git a/tests/cuthbertlib/kalman/test_filtering.py b/tests/cuthbertlib/kalman/test_filtering.py index 18ab99d..f26c2fa 100644 --- a/tests/cuthbertlib/kalman/test_filtering.py +++ b/tests/cuthbertlib/kalman/test_filtering.py @@ -56,7 +56,6 @@ def test_predict(seed, x_dim): pred_cov = pred_chol_cov @ pred_chol_cov.T des_m, des_cov = std_predict(m0, P0, F, c, Q) - chex.assert_trees_all_close((pred_m, pred_cov), (des_m, des_cov), rtol=1e-10) @@ -78,6 +77,10 @@ def test_update(seed, x_dim, y_dim): (m, chol_P), ell = update(m0, chol_P0, H, d, chol_R, y) P = chol_P @ chol_P.T + chex.assert_numerical_grads( + lambda *args: update(*args)[-1], (m0, chol_P0, H, d, chol_R, y), order=1 + ) + des_m, des_P, des_ell = std_update(m0, P0, H, d, R, y) chex.assert_trees_all_close((m, P), (des_m, des_P), rtol=1e-10) diff --git a/tests/cuthbertlib/linalg/test_tria.py b/tests/cuthbertlib/linalg/test_tria.py index 16f68b2..7748d4b 100644 --- a/tests/cuthbertlib/linalg/test_tria.py +++ b/tests/cuthbertlib/linalg/test_tria.py @@ -26,3 +26,20 @@ def test_tria(seed, shape): # Check that R @ R.T = A @ A.T assert jnp.allclose(R @ R.T, A @ A.T) + + +def test_tria_jvp_preserves_gram_matrix_for_rank_deficient_input(): + A = jnp.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]) + dA = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]) + + def gram_via_tria(x): + R = tria(x) + return R @ R.T + + def gram_direct(x): + return x @ x.T + + gram_jvp_from_tria = jax.jvp(gram_via_tria, (A,), (dA,))[1] + gram_jvp_direct = jax.jvp(gram_direct, (A,), (dA,))[1] + + assert jnp.allclose(gram_jvp_from_tria, 
gram_jvp_direct)