Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion cuthbertlib/kalman/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,42 @@ def filtering_operator(

mu = cho_solve((U1, True), b1)
t1 = b1 @ mu - (eta2 + mu) @ tmp_2
ell = ell1 + ell2 - 0.5 * t1 + 0.5 * jnp.linalg.slogdet(D_inv)[1]

# Derivation for O(nx) log-determinant computation of D_inv:
# This is a long comment but I wanted to include the full derivation for clarity and future reference.
# The key idea is to express D_inv in terms of the blocks of Xi and then apply Sylvester's determinant theorem
# to compute its determinant efficiently.
#
# 1. Expand the blocks of Xi @ Xi.T:
# (Xi @ Xi.T)[1,1] = I + U1.T @ Z2 @ Z2.T @ U1
# (Xi @ Xi.T)[2,1] = Z2 @ Z2.T @ U1
#
# 2. Equate to the corresponding blocks of L @ L.T:
# Xi11 @ Xi11.T = I + U1.T @ Z2 @ Z2.T @ U1
# Xi21 @ Xi11.T = Z2 @ Z2.T @ U1
#
# 3. Expand D_inv using tmp_1 = Xi11^{-1} @ U1.T:
# D_inv = I - tmp_1.T @ Xi21.T
# = I - U1 @ Xi11^{-T} @ Xi21.T
#
# 4. Apply Sylvester's determinant theorem:
# det(D_inv) = det(I - Xi11^{-T} @ Xi21.T @ U1)
#
# 5. Multiply interior by Xi11^{-1} @ Xi11 and substitute block identities:
# Let P = U1.T @ Z2 @ Z2.T @ U1
# det(D_inv) = det(I - (Xi11 @ Xi11.T)^{-1} @ (Xi21 @ Xi11.T).T @ U1)
# = det(I - (I + P)^{-1} @ P)
# = det((I + P)^{-1})
# = 1 / det(Xi11 @ Xi11.T)
# = det(Xi11)^{-2}
#
# 6. Simplify the log-determinant term in the log-likelihood:
# 0.5 * log(det(D_inv)) = -log(|det(Xi11)|)
#
# Since Xi11 is lower triangular, the log-determinant is the sum of the logs of its diagonal.
# Replace `0.5 * jnp.linalg.slogdet(D_inv)[1]` with:
# -jnp.sum(jnp.log(jnp.abs(jnp.diag(Xi11))))
# ell = ell1 + ell2 - 0.5 * t1 + 0.5 * jnp.linalg.slogdet(D_inv)[1]
ell = ell1 + ell2 - 0.5 * t1 - jnp.sum(jnp.log(jnp.abs(jnp.diag(Xi11))))

return FilterScanElement(A, b, U, eta, Z, ell)
88 changes: 82 additions & 6 deletions cuthbertlib/linalg/tria.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,97 @@
"""Implements triangularization operator a matrix via QR decomposition."""

import jax
import jax.numpy as jnp

from cuthbertlib.types import Array


def _adj(x: Array) -> Array:
"""Conjugate transpose for batched matrices."""
return jnp.swapaxes(x.conj(), -1, -2)


@jax.custom_jvp
def tria(A: Array) -> Array:
r"""A triangularization operator using QR decomposition.
"""A triangularization operator using QR decomposition.

Args:
A: The matrix to triangularize.

Returns:
A lower triangular matrix $R$ such that $R R^\top = A A^\top$.
A lower triangular matrix R such that R @ R.T = A @ A.T.

Reference:
[Arasaratnam and Haykin (2008)](https://ieeexplore.ieee.org/document/4524036): Square-Root Quadrature Kalman Filtering
References:
Paper: Arasaratnam and Haykin (2008): Square-Root Quadrature Kalman Filtering
https://ieeexplore.ieee.org/document/4524036
"""
_, R = jax.scipy.linalg.qr(A.T, mode="economic")
return R.T
_, R_qr = jnp.linalg.qr(_adj(A), mode="reduced")
return _adj(R_qr)


@tria.defjvp
def _tria_jvp(primals, tangents):
    """Exact analytical JVP for ``tria``.

    Returns the primal output R (lower triangular with R @ R.T = A @ A.T) and
    its tangent dR.  Uses the pseudoinverse of R so the JVP is well-defined
    even when the primal A is rank deficient.
    """
    # Derivation of the exact analytical JVP for the lower-triangularization operator:
    #
    # 1. The operation computes a lower triangular R such that:
    #        R @ R.T = A @ A.T
    #
    # 2. Taking the differential of both sides yields the Gramian differential identity:
    #        dR @ R.T + R @ dR.T = dA @ A.T + A @ dA.T
    #
    # 3. For rank-deficient A, the exact lower-triangular tangent dR decomposes into
    #    column-space and null-space components: dR = dR_col + dR_null.
    #
    # 4. To find dR_col, multiply the identity from the left by R^\dagger (pseudoinverse)
    #    and from the right by R^{\dagger T}:
    #        R^\dagger @ dR_col + dR_col.T @ R^{\dagger T} = R^\dagger @ dA @ A.T @ R^{\dagger T} + R^\dagger @ A @ dA.T @ R^{\dagger T}
    #
    # 5. Let Q be the active subspace orthogonal factor such that A = R @ Q.T.
    #    Substituting A.T @ R^{\dagger T} = Q and R^\dagger @ A = Q.T:
    #        R^\dagger @ dR_col + dR_col.T @ R^{\dagger T} = R^\dagger @ dA @ Q + Q.T @ dA.T @ R^{\dagger T}
    #
    # 6. Define K = R^\dagger @ dA @ Q. The right hand side becomes K + K.T:
    #        R^\dagger @ dR_col + (R^\dagger @ dR_col).T = K + K.T
    #
    # 7. Define dM_col = R^\dagger @ dR_col. Solving for the lower-triangular dM_col:
    #        dM_col = tril(K + K.T) - diag(K)
    #
    # 8. Recover the column-space differential dR_col by left-multiplying by R:
    #        dR_col = R @ dM_col
    #
    # 9. To satisfy the orthogonal cross-terms for arbitrary perturbations outside
    #    the column space of A, add the null-space component projected via (I - R @ R^\dagger):
    #        dR_null = (I - R @ R^\dagger) @ dA @ Q
    #
    # 10. The complete, exact JVP is the sum of both components:
    #        dR = dR_col + dR_null

    (A,) = primals
    (dA,) = tangents

    A_T = jnp.swapaxes(A, -1, -2)
    Q, R_qr = jnp.linalg.qr(A_T, mode="reduced")

    # R_qr is upper triangular; its transpose is the lower-triangular primal output.
    R = jnp.swapaxes(R_qr, -1, -2)

    # Q has shape (..., M, N). A is (..., N, M).
    # A^T = Q R_qr => A = R_qr^T Q^T = R Q^T

    R_pinv = jnp.linalg.pinv(R)

    # K = R^{-1} dA Q
    # R can be degenerate so we use the pseudoinverse to ensure the JVP is
    # well-defined everywhere.
    K = R_pinv @ dA @ Q
    K_T = jnp.swapaxes(K, -1, -2)

    # Solve for lower triangular perturbation dM + dM^T = K + K^T.
    # tril(K + K^T) double-counts the diagonal, so subtract diag(K) (via K * I).
    I = jnp.eye(K.shape[-1], dtype=K.dtype)
    dM = jnp.tril(K + K_T) - K * I

    # Compute the null-space part: perturbations outside the column space of A,
    # projected by the orthogonal complement of range(R).
    dR_null = (jnp.eye(R.shape[-2], dtype=R.dtype) - R @ R_pinv) @ dA @ Q

    # Apply to get the tangent at R (column-space part plus null-space part).
    dR = R @ dM + dR_null

    return R, dR
31 changes: 30 additions & 1 deletion tests/cuthbert/gaussian/test_kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,16 @@ def get_observation_params(model_inputs: int) -> tuple[Array, Array, Array, Arra
return filter, smoother, model_inputs


# Parameter grids for the test matrix below.
seeds = [1, 43, 99]
x_dims = [3]
y_dims = [1, 2]
num_time_steps = [1, 25]

common_params = list(itertools.product(seeds, x_dims, y_dims, num_time_steps))

# Use fewer time steps for the gradient test because finite differences
# will not work well for long sequences.
gradient_params = list(itertools.product(seeds, x_dims, y_dims, [1, 2]))

@pytest.mark.parametrize("seed,x_dim,y_dim,num_time_steps", common_params)
def test_offline_filter(seed, x_dim, y_dim, num_time_steps):
Expand Down Expand Up @@ -150,6 +153,32 @@ def test_offline_filter(seed, x_dim, y_dim, num_time_steps):
)


@pytest.mark.parametrize("seed,x_dim,y_dim,num_time_steps", gradient_params)
def test_check_gradient(seed, x_dim, y_dim, num_time_steps):
    """Check the analytic gradient of the filter log-likelihood against finite differences."""
    m0, chol_P0, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys = generate_lgssm(
        seed, x_dim, y_dim, num_time_steps
    )

    @jax.jit
    def get_loglikelihood_from_params(
        m0_, chol_P0_, Fs_, cs_, chol_Qs_, Hs_, ds_, chol_Rs_
    ):
        # Rebuild the inference problem from the differentiated parameters;
        # the observations `ys` are closed over (not differentiated).
        kalman_filter, _, model_inputs = load_kalman_inference(
            m0_, chol_P0_, Fs_, cs_, chol_Qs_, Hs_, ds_, chol_Rs_, ys
        )
        states = filter(kalman_filter, model_inputs)
        # Last element: accumulated log-normalizing constant of the full sequence.
        return states.log_normalizing_constant[-1]

    # First-order finite-difference check; tolerances are loose because of
    # the eps=1e-5 central-difference truncation error.
    chex.assert_numerical_grads(
        get_loglikelihood_from_params,
        (m0, chol_P0, Fs, cs, chol_Qs, Hs, ds, chol_Rs),
        order=1,
        rtol=0.01,
        atol=0.01,
        eps=1e-5,
    )


@pytest.mark.parametrize("seed", [1, 43, 99, 123, 456])
@pytest.mark.parametrize("x_dim", [1, 10])
@pytest.mark.parametrize("y_dim", [1, 5])
Expand Down
5 changes: 4 additions & 1 deletion tests/cuthbertlib/kalman/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def test_predict(seed, x_dim):
pred_cov = pred_chol_cov @ pred_chol_cov.T

des_m, des_cov = std_predict(m0, P0, F, c, Q)

chex.assert_trees_all_close((pred_m, pred_cov), (des_m, des_cov), rtol=1e-10)


Expand All @@ -78,6 +77,10 @@ def test_update(seed, x_dim, y_dim):
(m, chol_P), ell = update(m0, chol_P0, H, d, chol_R, y)
P = chol_P @ chol_P.T

chex.assert_numerical_grads(
lambda *args: update(*args)[-1], (m0, chol_P0, H, d, chol_R, y), order=1
)

des_m, des_P, des_ell = std_update(m0, P0, H, d, R, y)

chex.assert_trees_all_close((m, P), (des_m, des_P), rtol=1e-10)
Expand Down
17 changes: 17 additions & 0 deletions tests/cuthbertlib/linalg/test_tria.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,20 @@ def test_tria(seed, shape):

# Check that R @ R.T = A @ A.T
assert jnp.allclose(R @ R.T, A @ A.T)


def test_tria_jvp_preserves_gram_matrix_for_rank_deficient_input():
    """The JVP of tria must match the Gram-matrix differential even when A is rank deficient."""
    # Rank-1 primal (only the first column is nonzero) with a tangent that
    # perturbs the null-space directions as well.
    primal = jnp.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
    tangent = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])

    def gram_of_tria(mat):
        low = tria(mat)
        return low @ low.T

    # Differentiate x -> tria(x) @ tria(x).T and x -> x @ x.T along the same
    # tangent; by the defining property tria(x) @ tria(x).T == x @ x.T they
    # must agree.
    _, jvp_via_tria = jax.jvp(gram_of_tria, (primal,), (tangent,))
    _, jvp_expected = jax.jvp(lambda mat: mat @ mat.T, (primal,), (tangent,))

    assert jnp.allclose(jvp_via_tria, jvp_expected)
Loading