Skip to content

Commit 5b20692

Browse files
Stebossolupton
andauthored
test tree_util changes (#1396)
This PR is for testing my changes to AXLearn, to avoid having this error: ```python [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/trainer.py", line 1162, in compile_train_step [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] lowered_train_step = jit_train_step.lower(trainer_state, input_batch) [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/trainer.py", line 1204, in _train_step [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] fwd_bwd_outputs, learner_output_collection = F( [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/module.py", line 1096, in functional [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] method_outputs, output_collection = fn(*args, **kwargs) [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/module.py", line 1032, in __call__ [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] raise_for_cycles(dict(context=self.context, args=args, kwargs=kwargs)) [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/utils.py", line 1930, in raise_for_cycles [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] cycles = find_cycles(tree) [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/utils.py", line 1924, in find_cycles [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] return _find_cycles(tree, key_path=[], seen=[]) [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/utils.py", line 1914, in _find_cycles [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] items = pytree_children(tree) [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] File "/opt/axlearn/axlearn/common/utils.py", line 1872, in pytree_children [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] registry_with_keypaths = jax._src.tree_util._registry_with_keypaths [pod/axlearn-fuji-3b-14466154556-kgkm5/axlearn-fuji-model] AttributeError: module 'jax._src.tree_util' has no attribute '_registry_with_keypaths' ``` with recent JAX version. The changes are on [my AXLearn repo](https://github.com/Steboss/axlearn/blob/sbosisio/tree_util/axlearn/common/utils.py#L1862) I'll flag this error to AXLearn team --------- Co-authored-by: Olli Lupton <[email protected]>
1 parent b28cc73 commit 5b20692

File tree

4 files changed

+10
-3
lines changed

4 files changed

+10
-3
lines changed

.github/container/Dockerfile.axlearn

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# syntax=docker/dockerfile:1-labs
22
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
3-
ARG URLREF_AXLEARN=https://github.com/Steboss/axlearn.git#main
3+
ARG URLREF_AXLEARN=https://github.com/apple/axlearn.git
44
ARG SRC_PATH_AXLEARN=/opt/axlearn
55

66
###############################################################################

.github/container/Dockerfile.jax

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ RUN build-jax.sh \
5353
--src-path-jax ${SRC_PATH_JAX} \
5454
--src-path-xla ${SRC_PATH_XLA} \
5555
--sm all \
56-
--xla-arm64-patch /opt/xla-arm64-neon.patch \
56+
--xla-arm64-patch /opt/xla-arm64-neon.patch \
5757
--clean
5858

5959
## Transformer engine: check out source and build wheel

.github/container/manifest.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,7 @@ pathwaysutils:
101101
tracking_ref: main
102102
latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9
103103
mode: pip-vcs
104+
axlearn:
105+
url: https://github.com/Steboss/axlearn.git
106+
tracking_ref: sbosisio/working_branch
107+
mode: git-clone

.github/workflows/_ci.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ jobs:
134134
CONTAINER_NAME: axlearn
135135
DOCKERFILE: .github/container/Dockerfile.axlearn
136136
RUNNER_SIZE: large
137+
EXTRA_BUILD_ARGS: |
138+
URLREF_AXLEARN=${{ fromJson(inputs.SOURCE_URLREFS).AXLEARN }}
137139
secrets: inherit
138140

139141
collect-docker-tags:
@@ -438,6 +440,7 @@ jobs:
438440
# ARTIFACTS: |
439441
# test-equinox.log
440442
# secrets: inherit
443+
441444
test-te-h100:
442445
needs: build-jax
443446
if: inputs.ARCHITECTURE == 'amd64'
@@ -478,7 +481,7 @@ jobs:
478481
# merge the log files
479482
cat \
480483
/log/pytest-report-L0-unittest.jsonl
481-
/log/pytest-report-L0-distributed-unittest.jsonl
484+
/log/pytest-report-L0-distributed-unittest.jsonl
482485
> /log/pytest-report.jsonl
483486
EOF
484487
STATISTICS_SCRIPT: |

0 commit comments

Comments
 (0)