Skip to content

Conversation

@Ashutosh0x
Copy link

Resolves #422. This PR modernizes the _check_sharding function to use the leaf.sharding attribute directly for jax.Array objects, bypassing jax.typeof(leaf) which could be unreliable or inconsistent in newer JAX versions (0.4.35+ and the upcoming 0.8.x). It also refines sharding-related error messages to be more accurate while maintaining backward compatibility with existing test regexes.

@hbq1
Copy link
Member

hbq1 commented Jan 6, 2026

Hi @Ashutosh0x thanks for the PR! Happy to accept it if you could address my comment + fix the style warnings from the CI.

@Ashutosh0x Ashutosh0x force-pushed the fix/jax-0.8.2-compat branch from 66b2b6a to ba1e5fc Compare January 6, 2026 14:11
@Ashutosh0x
Copy link
Author

Ashutosh0x commented Jan 6, 2026

Thanks @hbq1! I've addressed the feedback:

  • Added the else branch in _check_sharding to restore backward compatibility using jax.typeof(x).sharding.
  • Fixed the style warnings by cleaning up the redundant parentheses in assert_tree_is_on_device.

Ready for another look!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

jax 0.8.2 incpatibility

2 participants