
@Steboss Steboss commented Apr 16, 2025

This PR tests my changes to AXLearn, which avoid the following error:

[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]   File "/opt/axlearn/axlearn/common/trainer.py", line 1162, in compile_train_step
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]     lowered_train_step = jit_train_step.lower(trainer_state, input_batch)
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]   File "/opt/axlearn/axlearn/common/trainer.py", line 1204, in _train_step
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]     fwd_bwd_outputs, learner_output_collection = F(
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]   File "/opt/axlearn/axlearn/common/module.py", line 1096, in functional
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]     method_outputs, output_collection = fn(*args, **kwargs)
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]   File "/opt/axlearn/axlearn/common/module.py", line 1032, in __call__
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]     raise_for_cycles(dict(context=self.context, args=args, kwargs=kwargs))
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]   File "/opt/axlearn/axlearn/common/utils.py", line 1930, in raise_for_cycles
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]     cycles = find_cycles(tree)
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]   File "/opt/axlearn/axlearn/common/utils.py", line 1924, in find_cycles
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]     return _find_cycles(tree, key_path=[], seen=[])
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]   File "/opt/axlearn/axlearn/common/utils.py", line 1914, in _find_cycles
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]     items = pytree_children(tree)
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]   File "/opt/axlearn/axlearn/common/utils.py", line 1872, in pytree_children
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model]     registry_with_keypaths = jax._src.tree_util._registry_with_keypaths
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] AttributeError: module 'jax._src.tree_util' has no attribute '_registry_with_keypaths'

This happens with recent JAX versions.
The changes are in my AXLearn repo.

I'll flag this error to the AXLearn team.
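
For context, the traceback ends in an `AttributeError` because the code reads a private attribute (`_registry_with_keypaths`) that is no longer present on the module. Below is a minimal sketch of a defensive lookup with a fallback. It is not the actual AXLearn change: `lookup_private`, `children_sketch`, and the two stand-in module objects are hypothetical names, and the registry shape (container type mapped to a callable that yields key/child pairs) is an assumption made only for illustration.

```python
import types


def lookup_private(module, name):
    """Return module.<name>, or None if the attribute no longer exists."""
    return getattr(module, name, None)


def children_sketch(tree, private_module):
    """List (key, child) pairs for one level of `tree`."""
    # Assumption: the registry maps a container type to a callable that
    # yields (key, child) pairs. This shape is illustrative only.
    registry = lookup_private(private_module, "_registry_with_keypaths")
    if registry is not None and type(tree) in registry:
        return list(registry[type(tree)](tree))
    # Fallback path when the private registry is missing.
    if isinstance(tree, dict):
        return list(tree.items())
    if isinstance(tree, (list, tuple)):
        return list(enumerate(tree))
    return []  # anything else is treated as a leaf


# Stand-ins for an older and a newer module layout (not real JAX modules).
old_layout = types.SimpleNamespace(
    _registry_with_keypaths={dict: lambda d: sorted(d.items())}
)
new_layout = types.SimpleNamespace()

print(children_sketch({"b": 2, "a": 1}, old_layout))
print(children_sketch({"b": 2, "a": 1}, new_layout))
```

The point of the sketch is only that `getattr` with a default turns a hard failure on a removed private name into a branch that can fall back to another code path.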

@Steboss Steboss requested a review from olupton April 17, 2025 11:10

Steboss commented Apr 28, 2025

As soon as this PR is merged, I can remove the JAX and XLA pins. This new XLA PR fixes the following error:

ValueError: INTERNAL: Expected command buffer to be in state create but it was in state update: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

and my AXLearn fork fixes the JAX deprecation. Here is the PR I opened on AXLearn with the new changes.

@Steboss Steboss requested a review from olupton May 6, 2025 17:37
olupton previously approved these changes May 7, 2025
@Steboss Steboss merged commit 5b20692 into main May 8, 2025
89 of 91 checks passed
@Steboss Steboss deleted the sbosisio/axlearn-test-tree branch May 8, 2025 09:28