Fix gradient calculation for Kalman. #213

Merged
AdrienCorenflos merged 8 commits into main from fix-gradient-kalman on Mar 11, 2026
Conversation

@AdrienCorenflos
Contributor

There were two issues:

  1. Using logdet(D_inv) caused failures in its JVP. I'm still not sure why; it was related to solve(x, b) for a matrix x that was actually perfectly fine. I have replaced it with a better calculation of the determinant anyway; see the derivation comment before it.
  2. The QR decomposition failed to compute gradients when A had degenerate rank. This is expected: the gradient is not defined where A is degenerate, because the QR decomposition is then non-unique. However, what we use in probabilistic terms are gradients with respect to parameters of R @ R.T, which is unique and encodes the same information as cov = chol @ chol.T. So I fixed the tria JVP by making sure that the gradient calculation was correct on the span of R, and otherwise invariant across decompositions giving the same R @ R.T.

Closes #211
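
For readers following the first fix, the determinant identity it relies on can be sketched like this (a toy SPD matrix; the actual covariance and variable names in the code will differ):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# For an SPD matrix D with Cholesky factor L (D = L @ L.T),
# log det(D) = 2 * sum(log(diag(L))).
# Computing log-determinants this way avoids differentiating
# through logdet(D_inv) and the solve() call that triggered
# the JVP failures.
D = jnp.array([[4.0, 1.0], [1.0, 3.0]])
L = jnp.linalg.cholesky(D)
logdet_chol = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))

# Reference value from slogdet agrees.
_, logdet_ref = jnp.linalg.slogdet(D)
assert jnp.allclose(logdet_chol, logdet_ref)
```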

@Sahel13
Collaborator

Sahel13 commented Mar 9, 2026

@AdrienCorenflos isn't the following test supposed to pass (it doesn't)?

import jax
import jax.numpy as jnp
# tria comes from the library under test

def test_tria_jvp_preserves_gram_matrix_for_rank_deficient_input():
    A = jnp.array([[1.0, 0.0], [1.0, 0.0], [2.0, 3.0]])
    dA = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])

    def gram_via_tria(x):
        R = tria(x)
        return R @ R.T

    def gram_direct(x):
        return x @ x.T

    gram_jvp_from_tria = jax.jvp(gram_via_tria, (A,), (dA,))[1]
    gram_jvp_direct = jax.jvp(gram_direct, (A,), (dA,))[1]

    assert jnp.allclose(gram_jvp_from_tria, gram_jvp_direct)

Co-authored-by: Sahel Iqbal <sahel13miqbal@proton.me>
@AdrienCorenflos
Contributor Author

AdrienCorenflos commented Mar 9, 2026

@AdrienCorenflos isn't the following test supposed to pass (it doesn't)?

How come the tests are passing?

@Sahel13
Collaborator

Sahel13 commented Mar 9, 2026

How come the tests are passing?

Not sure tbh, still checking.

Edit: this one is a new test that I added, btw.

@AdrienCorenflos
Contributor Author

I think that should pass though yes. What's the output?

@Sahel13
Collaborator

Sahel13 commented Mar 9, 2026

I think that should pass though yes. What's the output?

E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <PjitFunction of <function allclose at 0x7f757c4d6480>>(Array([[ 1.,  1.,  7.],\n       [ 1.,  1.,  7.],\n       [ 7.,  7., 36.]], dtype=float64), Array([[ 0.,  1.,  3.],\n       [ 1.,  2., 11.],\n       [ 3., 11., 36.]], dtype=float64))
E        +    where <PjitFunction of <function allclose at 0x7f757c4d6480>> = jnp.allclose

@AdrienCorenflos
Contributor Author

AdrienCorenflos commented Mar 9, 2026

Oh I see, I think your dA vectors are not in the span of A.

Edit: you can't have a tangent vector that's not compatible with the input. That's just not possible.

@Sahel13
Collaborator

Sahel13 commented Mar 9, 2026

Ok, yeah, good point.

Edit: where is this requirement for dA to be in the span of A coming from, though? Can't the input tangent be anything that has a valid shape? I can see that dR is constrained.

@AdrienCorenflos
Contributor Author

I'll have a look. I think you're right and I ignored a term in the kernel of R that I shouldn't have.

@AdrienCorenflos
Contributor Author

@Sahel13 I had forgotten the null-space part of the gradient because of a stupid mistake!

I had written $R R^{\dagger} = I$ when it's actually only true if R is invertible...
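
A minimal illustration of the pitfall (using jnp.linalg.pinv on toy matrices, not the actual tria code): $R R^{\dagger}$ is only an orthogonal projector onto the column span of $R$, and it equals the identity only when $R$ has full rank.

```python
import jax.numpy as jnp

# Full-rank R: R @ pinv(R) recovers the identity.
R_full = jnp.array([[2.0, 0.0], [1.0, 3.0]])
proj_full = R_full @ jnp.linalg.pinv(R_full)
assert jnp.allclose(proj_full, jnp.eye(2), atol=1e-5)

# Rank-deficient R: R @ pinv(R) is the orthogonal projector onto
# span(R), not the identity, so a derivation assuming
# R @ pinv(R) == I silently drops the null-space contribution.
R_def = jnp.array([[1.0, 0.0], [2.0, 0.0]])
proj_def = R_def @ jnp.linalg.pinv(R_def)
assert not jnp.allclose(proj_def, jnp.eye(2))
```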

@AdrienCorenflos
Contributor Author

I put in the test you suggested @Sahel13 and also increased the precision of the Kalman gradient test: what I thought was a numerical issue in the finite differences was actually a plain bug!
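
The kind of check being tightened can be sketched with a toy function standing in for the Kalman gradient (the function, step size, and tolerance here are illustrative, not the repository's actual test):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(x):
    # Toy scalar objective standing in for the Kalman log-likelihood.
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.array([0.3, 1.2, -0.7])
v = jnp.array([1.0, -2.0, 0.5])

# Autodiff JVP versus a central finite difference along v.
_, jvp_ad = jax.jvp(f, (x,), (v,))
eps = 1e-6
jvp_fd = (f(x + eps * v) - f(x - eps * v)) / (2 * eps)

# With float64 and a tight tolerance, a disagreement here points
# to a real autodiff bug rather than finite-difference noise.
assert jnp.allclose(jvp_ad, jvp_fd, atol=1e-7)
```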

Co-authored-by: Sahel Iqbal <sahel13miqbal@proton.me>
@AdrienCorenflos AdrienCorenflos merged commit 736c509 into main Mar 11, 2026
2 checks passed
@AdrienCorenflos AdrienCorenflos deleted the fix-gradient-kalman branch March 11, 2026 14:37
Successfully merging this pull request may close these issues.

Numerical Instabilities for Kalman Filter Gradients