Commit 5b20692
test tree_util changes (#1396)
This PR is for testing my changes to AXLearn, to avoid having this
error:
```python
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/trainer.py", line 1162, in compile_train_step
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] lowered_train_step = jit_train_step.lower(trainer_state, input_batch)
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/trainer.py", line 1204, in _train_step
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] fwd_bwd_outputs, learner_output_collection = F(
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/module.py", line 1096, in functional
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] method_outputs, output_collection = fn(*args, **kwargs)
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/module.py", line 1032, in __call__
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] raise_for_cycles(dict(context=self.context, args=args, kwargs=kwargs))
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/utils.py", line 1930, in raise_for_cycles
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] cycles = find_cycles(tree)
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/utils.py", line 1924, in find_cycles
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] return _find_cycles(tree, key_path=[], seen=[])
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/utils.py", line 1914, in _find_cycles
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] items = pytree_children(tree)
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/utils.py", line 1872, in pytree_children
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] registry_with_keypaths = jax._src.tree_util._registry_with_keypaths
[pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] AttributeError: module 'jax._src.tree_util' has no attribute '_registry_with_keypaths'
```
with recent JAX version.
The changes are on [my AXLearn
repo](https://github.com/Steboss/axlearn/blob/sbosisio/tree_util/axlearn/common/utils.py#L1862)
I'll flag this error to AXLearn team
---------
Co-authored-by: Olli Lupton <[email protected]>1 parent b28cc73 commit 5b20692
File tree
4 files changed
+10
-3
lines changed- .github
- container
- workflows
4 files changed
+10
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
| 3 | + | |
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
53 | 53 | | |
54 | 54 | | |
55 | 55 | | |
56 | | - | |
| 56 | + | |
57 | 57 | | |
58 | 58 | | |
59 | 59 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
101 | 101 | | |
102 | 102 | | |
103 | 103 | | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
134 | 134 | | |
135 | 135 | | |
136 | 136 | | |
| 137 | + | |
| 138 | + | |
137 | 139 | | |
138 | 140 | | |
139 | 141 | | |
| |||
438 | 440 | | |
439 | 441 | | |
440 | 442 | | |
| 443 | + | |
441 | 444 | | |
442 | 445 | | |
443 | 446 | | |
| |||
478 | 481 | | |
479 | 482 | | |
480 | 483 | | |
481 | | - | |
| 484 | + | |
482 | 485 | | |
483 | 486 | | |
484 | 487 | | |
| |||
0 commit comments