Fix gradient calculation for Kalman #213
Conversation
There were two issues:

1. Using `logdet(D_inv)` caused some failures in its JVP. I'm still not sure why; it was related to `solve(x, b)` for a matrix `x` that was actually perfectly fine. I have replaced this with a better calculation of the determinant anyway; see the derivation comment before it.
2. The QR decomposition failed to compute gradients when `A` had degenerate rank. This is expected: the gradient is not defined where `A` is degenerate, because the QR decomposition is non-unique there. However, what we use in probabilistic terms are gradients with respect to parameters of `R @ R.T`, which is unique and encodes the same information as `cov = chol @ chol.T`. So I fixed the `tria` JVP by making sure the gradient calculation was correct on the span of `R`, and otherwise invariant across decompositions giving the same `R @ R.T`.
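For issue 1, a log-determinant can be computed from a triangular factor rather than via `logdet(D_inv)` and a solve. A minimal sketch of that idea (`logdet_spd` is a hypothetical stand-in name, not necessarily the calculation the PR actually uses): for an SPD matrix `S` with Cholesky factor `L`, `logdet(S) = 2 * sum(log(diag(L)))`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical sketch: log-determinant of an SPD matrix from its
# Cholesky factor, avoiding logdet of an explicitly inverted matrix.
def logdet_spd(S):
    L = jnp.linalg.cholesky(S)  # S = L @ L.T, L lower triangular
    return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))

S = jnp.array([[4.0, 1.0], [1.0, 3.0]])
assert jnp.allclose(logdet_spd(S), jnp.linalg.slogdet(S)[1])
```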
|
@AdrienCorenflos isn't the following test supposed to pass (it doesn't)?

```python
def test_tria_jvp_preserves_gram_matrix_for_rank_deficient_input():
    A = jnp.array([[1.0, 0.0], [1.0, 0.0], [2.0, 3.0]])
    dA = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])

    def gram_via_tria(x):
        R = tria(x)
        return R @ R.T

    def gram_direct(x):
        return x @ x.T

    gram_jvp_from_tria = jax.jvp(gram_via_tria, (A,), (dA,))[1]
    gram_jvp_direct = jax.jvp(gram_direct, (A,), (dA,))[1]
    assert jnp.allclose(gram_jvp_from_tria, gram_jvp_direct)
```
|
Co-authored-by: Sahel Iqbal <sahel13miqbal@proton.me>
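For context, a minimal sketch of a `tria`-style factorization, assuming the definition commonly used in square-root Kalman filtering (the repo's actual `tria` may differ): the transposed R-factor of a QR decomposition of `A.T`, which is lower triangular and preserves the Gram matrix `A @ A.T`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical tria: lower-triangular factor with the same Gram matrix,
# i.e. tria(A) @ tria(A).T == A @ A.T, via QR of A.T.
def tria(A):
    r = jnp.linalg.qr(A.T, mode="r")  # A.T = Q @ r, r upper triangular
    return r.T                        # lower triangular, same Gram matrix

A = jnp.array([[1.0, 0.0], [1.0, 0.0], [2.0, 3.0]])
R = tria(A)
assert jnp.allclose(R @ R.T, A @ A.T)
```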
How come the tests are passing? |
Not sure tbh, still checking. Edit: this one is a new test that I added, btw |
|
I think that should pass though yes. What's the output? |
```
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <PjitFunction of <function allclose at 0x7f757c4d6480>>(
E               Array([[ 1.,  1.,  7.],
E                      [ 1.,  1.,  7.],
E                      [ 7.,  7., 36.]], dtype=float64),
E               Array([[ 0.,  1.,  3.],
E                      [ 1.,  2., 11.],
E                      [ 3., 11., 36.]], dtype=float64))
E        +  where <PjitFunction of <function allclose at 0x7f757c4d6480>> = jnp.allclose
```
|
|
Oh I see, I think your dA vectors are not in the span of A. Edit: you can't have a tangent vector that's not compatible with the input. That's just not possible. |
|
Ok, yeah, good point. Edit: where is this requirement for dA to be in the span of A coming from, though? Can't the input tangent be anything that has a valid shape? I can see that dR is constrained |
|
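A minimal check of this point: the JVP of the direct map `x ↦ x @ x.T` is `dA @ A.T + A @ dA.T`, which is defined for any tangent of matching shape, whatever the rank of `A`. Using the same `A` and `dA` as in the test above:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

A = jnp.array([[1.0, 0.0], [1.0, 0.0], [2.0, 3.0]])
dA = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])  # the tangent flagged as "not in the span of A"

# Forward-mode differentiation of gram(x) = x @ x.T gives the
# product rule dA @ A.T + A @ dA.T, with no span restriction on dA.
gram = lambda x: x @ x.T
_, jvp_val = jax.jvp(gram, (A,), (dA,))
assert jnp.allclose(jvp_val, dA @ A.T + A @ dA.T)
```

This matches the second array in the pytest output above, so the direct Gram JVP itself is well defined for this tangent.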
I'll have a look. I think you're right and I ignored a term in the kernel of R that I shouldn't have |
|
@Sahel13 I had forgotten the null-space part of the gradient because of a stupid mistake! I had written |
|
I put in the test you suggested, @Sahel13, and also increased precision for the Kalman gradient test: what I thought was a numerical issue in finite differences was actually a plain bug!
Co-authored-by: Sahel Iqbal <sahel13miqbal@proton.me>
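On checking gradients against finite differences: a sketch of the kind of comparison involved, with a stand-in smooth scalar function (not the repo's Kalman log-likelihood). In float64, central differences agree with autodiff to roughly `eps**2`, so a tight tolerance distinguishes discretization error from a genuine gradient bug.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Stand-in smooth scalar function; the real test would use the
# Kalman log-likelihood instead.
def f(x):
    return jnp.sum(jnp.sin(x) * x)

def fd_grad(f, x, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    def partial(i):
        e = jnp.zeros_like(x).at[i].set(eps)
        return (f(x + e) - f(x - e)) / (2 * eps)
    return jnp.array([partial(i) for i in range(x.size)])

x = jnp.array([0.3, 1.2, -0.7])
assert jnp.allclose(jax.grad(f)(x), fd_grad(f, x), atol=1e-8)
```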
Closes #211