Skip to content

Conversation

@davecwright3
Copy link
Collaborator

@davecwright3 davecwright3 commented May 20, 2025

There were a number of places in discovery.matrix using hard coded NumPy and Scipy instead of the configured jnp and js, respectively. Some of these calls were in the outer functions of the make_* methods, so that's not a big deal. However, a number of the calls, especially in the *_varN classes/methods, were calling hard coded NumPy and Scipy in inner methods. There were also a few places where *.linalg.*_solve() *.linalg.*_factor() were called instead of the matrix_solve and matrix_factor globals, so they were hard coded to use cholesky methods.

I'm not sure if this was intentional, so this is a draft PR for now. @meyers-academic can you take a look at this and let me know what you think? I think you've looked at NumPy vs JAX in the _varN methods before?

@davecwright3
Copy link
Collaborator Author

Note: I've left the Uind method as pure NumPy because of the in-place edits that function makes. I didn't want to use .at[ind].set(val) because that does not work with the NumPy backend.

@davecwright3 davecwright3 changed the title Fix: Use JAX Scipy and NumPy in discovery.matrix fix: Use JAX Scipy and NumPy in discovery.matrix May 20, 2025
@meyers-academic
Copy link
Collaborator

I think many (most) of these are things that should stay the way they are. @vallis should confirm, but I think that in general it was written to use numpy unless in an inner function. I think many of these matrix_factor issues are similar -- there are places where one explicitly wants the numpy or normal scipy (e.g. inv() that returns a static matrix) vs. places where jax.scipy or matrix_factor are used (e.g. inner functions).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants