Skip to content

Commit 8fa9182

Browse files
committed
Merge branch 'sbosisio/axlearn_improvements' of github.com:NVIDIA/JAX-Toolbox into sbosisio/axlearn_improvements
2 parents c47a318 + 4fc790d commit 8fa9182

File tree

78 files changed

+334
-2262
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

78 files changed

+334
-2262
lines changed

.github/container/Dockerfile.jax

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,12 @@ ADD build-jax.sh local_cuda_arch pytest-xdist.sh test-jax.sh /usr/local/bin/
9898
RUN mkdir -p /opt/pip-tools.d
9999

100100
## Editable installations of jax and jaxlib
101+
# Note that jax now is an independent wheel, extra [k8s] needs to be from build path also
101102
RUN <<"EOF" bash -ex
102103
for component in $(ls ${BUILD_PATH_JAXLIB}); do
103104
echo "-e file://${BUILD_PATH_JAXLIB}/${component}" >> /opt/pip-tools.d/requirements-jax.in;
104105
done
105-
echo "-e file://${SRC_PATH_JAX}[k8s]" >> /opt/pip-tools.d/requirements-jax.in
106+
echo "-e file://${BUILD_PATH_JAXLIB}/jax[k8s]" >> /opt/pip-tools.d/requirements-jax.in
106107
EOF
107108

108109
## Flax

.github/container/Dockerfile.mjx

Lines changed: 0 additions & 54 deletions
This file was deleted.

.github/container/build-jax.sh

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ pushd ${SRC_PATH_JAX}
288288
time python "${SRC_PATH_JAX}/build/build.py" build \
289289
--editable \
290290
--use_clang \
291-
--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt \
291+
--use_new_wheel_build_rule \
292+
--wheels=jax,jaxlib,jax-cuda-plugin,jax-cuda-pjrt \
292293
--cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \
293294
--bazel_options=--linkopt=-fuse-ld=lld \
294295
--local_xla_path=$SRC_PATH_XLA \
@@ -298,12 +299,13 @@ popd
298299

299300
# Make sure that JAX depends on the local jaxlib installation
300301
# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels
301-
line="jaxlib @ file://${BUILD_PATH_JAXLIB}/jaxlib"
302+
line="jax @ file://${BUILD_PATH_JAXLIB}/jax"
302303
if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then
303304
pushd "${SRC_PATH_JAX}"
304305
echo "${line}" >> build/requirements.in
305-
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax-cuda-pjrt" >> build/requirements.in
306-
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax-cuda-plugin" >> build/requirements.in
306+
echo "jaxlib @ file://${BUILD_PATH_JAXLIB}/jaxlib" >> build/requirements.in
307+
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_cuda${TF_CUDA_MAJOR_VERSION}_pjrt" >> build/requirements.in
308+
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_cuda${TF_CUDA_MAJOR_VERSION}_plugin" >> build/requirements.in
307309
PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))')
308310
bazel run --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
309311
popd
@@ -318,13 +320,13 @@ else
318320
fi
319321

320322
# install jax and jaxlib
321-
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax-cuda-pjrt -e ${BUILD_PATH_JAXLIB}/jax-cuda-plugin -e "${SRC_PATH_JAX}"
323+
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_cuda${TF_CUDA_MAJOR_VERSION}_pjrt -e ${BUILD_PATH_JAXLIB}/jax_cuda${TF_CUDA_MAJOR_VERSION}_plugin -e ${BUILD_PATH_JAXLIB}/jax
322324

323325
## after installation (example)
324-
# jax 0.4.36.dev20241125+f828f2d7d /opt/jax
325-
# jax-cuda12-pjrt 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-pjrt
326-
# jax-cuda12-plugin 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-plugin
327-
# jaxlib 0.4.36.dev20241125 /opt/jaxlibs/jaxlib
326+
# jax 0.5.4.dev20250325 /opt/jaxlibs/jax
327+
# jax-cuda12-pjrt 0.5.4.dev20250325 /opt/jaxlibs/jax_cuda12_pjrt
328+
# jax-cuda12-plugin 0.5.4.dev20250325 /opt/jaxlibs/jax_cuda12_plugin
329+
# jaxlib 0.5.4.dev20250325 /opt/jaxlibs/jaxlib
328330
pip list | grep jax
329331

330332
# Ensure directories are readable by all for non-root users

.github/container/manifest.yaml

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ seqio:
7171
tracking_ref: main
7272
latest_verified_commit: 11706e4a1e01a81ea6b3e02c5ad147028d5b94bb
7373
mode: pip-vcs
74+
google-jetstream:
75+
url: https://github.com/AI-Hypercomputer/JetStream.git
76+
tracking_ref: main
77+
latest_verified_commit: b8b9cb2ea4668da2c5012fc4c7ba958424d82ac9
78+
mode: pip-vcs
7479
maxtext:
7580
url: https://github.com/google/maxtext.git
7681
tracking_ref: main
@@ -86,21 +91,6 @@ haliax:
8691
tracking_ref: main
8792
latest_verified_commit: 2a696a0c971901ff93afdaa965959d8e3b982ba9
8893
mode: git-clone
89-
mujoco:
90-
url: https://github.com/google-deepmind/mujoco.git
91-
tracking_ref: main
92-
latest_verified_commit: e95159b4f6d48d114b16a8dc13ad26b3e44bc3e2
93-
mode: git-clone
94-
mujoco-mpc:
95-
url: https://github.com/google-deepmind/mujoco_mpc.git
96-
tracking_ref: main
97-
latest_verified_commit: 4700f4a13be18398f5aaf6a33ed42e531967e3ae
98-
mode: git-clone
99-
language-to-reward-2023:
100-
url: https://github.com/google-deepmind/language_to_reward_2023.git
101-
tracking_ref: main
102-
latest_verified_commit: abb8e5125e4ecd0da378490b73448c05a694def5
103-
mode: git-clone
10494
mlperf-logging:
10595
url: https://github.com/mlcommons/logging.git
10696
tracking_ref: master

.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
" xla_module_metadata,\n",
2222
")\n",
2323
"import matplotlib.pyplot as plt\n",
24-
"import numpy as np"
24+
"import numpy as np\n",
25+
"import pathlib"
2526
]
2627
},
2728
{
@@ -33,6 +34,7 @@
3334
"source": [
3435
"# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n",
3536
"# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n",
37+
"prefix = pathlib.Path(\".\") # modify this and comment out the next line\n",
3638
"prefix = default_data_prefix()"
3739
]
3840
},
@@ -128,15 +130,14 @@
128130
"id": "7727d800-13d3-4505-89e8-80a5fed63512",
129131
"metadata": {},
130132
"source": [
131-
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n",
132-
"The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n",
133-
"Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n",
133+
"Here the index has five levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n",
134+
"The two new levels, `Name` and `ThunkExecution`, show that a given row is the `ThunkExecution`-th execution within the `ProgramExecution`-th execution of XLA module `ProgramId` of thunk `Name`.\n",
135+
"The `ThunkExecution` value is needed because a given thunk can be executed multiple times within the same module.\n",
136+
"The `Name` of a thunk can be used, along with a `ProgramId`, to look up XLA metadata.\n",
134137
"\n",
135138
"The columns are as follows:\n",
136-
"- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n",
137139
"- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n",
138140
"- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n",
139-
"- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n",
140141
"\n",
141142
"The third data frame does not show any GPU execution, but is rather a host-side trace:"
142143
]
@@ -178,7 +179,7 @@
178179
"id": "2e82c357-4e9d-48e4-b758-fa5357b2c8bd",
179180
"metadata": {},
180181
"source": [
181-
"The index structure, and many of the columns, are equivalent to `thunk_df`. Additional columns are:\n",
182+
"The index structure, and many of the columns, are equivalent to the `.thunk` data frame. Additional columns are:\n",
182183
"\n",
183184
"- `MessageSize`: the message size of the collective in bytes; this aims to follow the same conventions as the NCCL tests\n",
184185
"- `Collective`: the type of collective communication\n",
@@ -524,7 +525,9 @@
524525
" # program, there may be different sub-groupings that are participating in smaller\n",
525526
" # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n",
526527
" # sub-groupings and group them, but we currently lack the relevant information.\n",
527-
" collective_df = df.groupby([\"ProgramId\", \"ProgramExecution\", \"ThunkIndex\"])\n",
528+
" collective_df = df.groupby(\n",
529+
" [\"ProgramId\", \"ProgramExecution\", \"Name\", \"ThunkExecution\"]\n",
530+
" )\n",
528531
" # Take the fastest device kernel as a proxy for the actual bandwidth of the\n",
529532
" # collective.\n",
530533
" bandwidth_df = collective_df.agg(\n",
@@ -534,7 +537,6 @@
534537
" \"ProjStartMs\": \"min\",\n",
535538
" \"ProjDurFullMs\": \"min\",\n",
536539
" \"ProjEndMs\": \"max\",\n",
537-
" \"Name\": \"count\",\n",
538540
" }\n",
539541
" )\n",
540542
" axs[0].plot(\n",
@@ -582,9 +584,9 @@
582584
"\n",
583585
"# Calculate statistics over different devices and different executions of each thunk, including multiple executions of the same thunk within the same module\n",
584586
"compute_durations = steady_state.thunk.loc[\n",
585-
" ~steady_state.thunk[\"Communication\"], (\"Name\", \"ProjDurMs\")\n",
587+
" ~steady_state.thunk[\"Communication\"], \"ProjDurMs\"\n",
586588
"].groupby([\"ProgramId\", \"Name\"])\n",
587-
"compute_duration_stats = compute_durations[\"ProjDurMs\"].agg((\"mean\", \"std\"))\n",
589+
"compute_duration_stats = compute_durations.agg((\"mean\", \"std\"))\n",
588590
"compute_duration_means = compute_duration_stats[\"mean\"]\n",
589591
"compute_duration_rel_stds = compute_duration_stats[\"std\"] / compute_duration_means\n",
590592
"\n",
@@ -634,8 +636,7 @@
634636
"\n",
635637
"def durations_ms(idx):\n",
636638
" program_id, thunk_name = idx\n",
637-
" tmp = steady_state.thunk.loc[program_id, (\"Name\", \"ProjDurMs\")]\n",
638-
" return tmp.loc[tmp[\"Name\"] == thunk_name, \"ProjDurMs\"]\n",
639+
" return steady_state.thunk.loc[(program_id, slice(None), thunk_name), \"ProjDurMs\"]\n",
639640
"\n",
640641
"\n",
641642
"detailed_index = high_variance_means[high_variance_means > mean_threshold].index\n",
@@ -666,6 +667,7 @@
666667
" squeeze=False,\n",
667668
" tight_layout=True,\n",
668669
" )\n",
670+
" # Compute (non-comm) kernel timings\n",
669671
" time_df = steady_state.thunk.loc[\n",
670672
" ~steady_state.thunk[\"Communication\"], (\"ProjStartMs\", \"ProjDurMs\")\n",
671673
" ]\n",
@@ -688,14 +690,17 @@
688690
" ):\n",
689691
" # Mean over devices to get a single [thunk0_start, thunk0_end, thunk1_start, ...]\n",
690692
" # array for this execution of this module\n",
691-
" mean_times = interleave(exec_df.groupby(\"ThunkIndex\").agg(\"mean\"))\n",
693+
" mean_times = interleave(\n",
694+
" exec_df.groupby([\"Name\", \"ThunkExecution\"], sort=False).agg(\"mean\")\n",
695+
" )\n",
692696
" # x axis of the plot will be the average over executions of the module\n",
693697
" x_values.append(mean_times - mean_times[0])\n",
694698
" for device, device_values in exec_df.groupby(\"Device\"):\n",
695699
" # [thunk0_start, thunk0_end, ...] array for one device within one module exec\n",
696700
" # with the average over devices subtracted\n",
697701
" y_values[device].append(interleave(device_values) - mean_times)\n",
698702
" mean_start_time_ms = np.mean(x_values, axis=0)\n",
703+
" # all_values: (num_devices, num_module_executions, thunks_per_module)\n",
699704
" all_values = np.array(list(y_values.values()))\n",
700705
" ax.plot(\n",
701706
" mean_start_time_ms,\n",
@@ -728,18 +733,17 @@
728733
" exec_df[\"ProjEndMs\"]\n",
729734
" - steady_state.module.loc[(program_id, module_execution), \"ProjStartMs\"]\n",
730735
" )\n",
731-
" tmp = exec_df.groupby(\"ThunkIndex\").agg(\n",
736+
" tmp = exec_df.groupby([\"Name\", \"ThunkExecution\"]).agg(\n",
732737
" {\n",
733-
" \"Name\": \"first\",\n",
734738
" \"Collective\": \"first\",\n",
735739
" \"CollectiveSize\": \"first\",\n",
736740
" \"EndInModuleMs\": \"mean\",\n",
737741
" }\n",
738742
" )\n",
739743
" for coll_size, values in tmp.groupby(\"CollectiveSize\"):\n",
740744
" comm_x_values[coll_size].append(values[\"EndInModuleMs\"])\n",
741-
" (_, xmax), (ymin, ymax) = ax.get_xlim(), ax.get_ylim()\n",
742-
" ax.set_xlim(0, xmax)\n",
745+
" ymin, ymax = ax.get_ylim()\n",
746+
" ax.set_xlim(mean_start_time_ms[0], mean_start_time_ms[-1])\n",
743747
" ax.set_ylim(ymin, ymax)\n",
744748
" largest_collective = max(comm_x_values.keys())\n",
745749
" for n_color, (coll_size, values) in enumerate(comm_x_values.items()):\n",
@@ -748,10 +752,10 @@
748752
" collective_times,\n",
749753
" ymin,\n",
750754
" # Draw taller vertical lines for collectives involving more devices\n",
751-
" ymin * (1 - coll_size / largest_collective),\n",
755+
" ymin * (1 - 0.75 * coll_size / largest_collective),\n",
752756
" color=f\"C{n_color}\",\n",
753757
" label=f\"{coll_size}-device collective\",\n",
754-
" linestyle=\"--\",\n",
758+
" linestyle=\"-\",\n",
755759
" )\n",
756760
"\n",
757761
" ax.set_title(\n",
@@ -836,7 +840,9 @@
836840
"outputs": [],
837841
"source": [
838842
"num_traces = {\n",
839-
" module_id: xla_module_metadata(module_id, policy=\"all\").unique_result(\n",
843+
" module_id: xla_module_metadata(\n",
844+
" module_id, policy=\"all\", prefix=prefix\n",
845+
" ).unique_result(\n",
840846
" lambda hlo_module: len(\n",
841847
" hlo_module.proto().buffer_assignment.heap_simulator_traces\n",
842848
" )\n",
@@ -855,7 +861,7 @@
855861
" squeeze=False,\n",
856862
")\n",
857863
"for n_module, module_id in enumerate(module_ids_with_traces):\n",
858-
" protos = xla_module_metadata(module_id, policy=\"all\")\n",
864+
" protos = xla_module_metadata(module_id, policy=\"all\", prefix=prefix)\n",
859865
" sizes_by_logical_id = protos.unique_result(\n",
860866
" lambda proto: {\n",
861867
" buffer.id: buffer.size\n",

.github/container/nsys_jax/nsys_jax/analyses/communication.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def process_communication_data(steady_state):
3838
collective_types.add(collective)
3939
# This grouped data frame will have a row for each device that is participating
4040
# in this instance of the collective.
41-
devices = df.groupby(["ProgramId", "ProgramExecution", "ThunkIndex"])
41+
devices = df.groupby(
42+
["ProgramId", "ProgramExecution", "Name", "ThunkExecution"]
43+
)
4244
# Take the fastest device bandwidth. Rationale: the slower devices appear
4345
# slower because they spend some time waiting for the last device, and then all
4446
# devices complete the collective at the same time. The fastest device is
@@ -134,8 +136,7 @@ def process_hidden_ms_to_total_ms(steady_state):
134136
for collective, df in grouped_data:
135137
collective_types.add(collective)
136138
total_ms = df["ProjDurMs"] + df["ProjDurHiddenMs"]
137-
mean_dur_hidden_ms_to_total_ms = (df["ProjDurHiddenMs"] / total_ms).mean()
138-
summary_data[collective] = mean_dur_hidden_ms_to_total_ms
139+
summary_data[collective] = df["ProjDurHiddenMs"].sum() / total_ms.sum()
139140

140141
return collective_types, summary_data
141142

@@ -253,8 +254,7 @@ def main():
253254
# Load the profiler data; the compilation part is needed for the warmup heuristics
254255
all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
255256
# Align timestamps
256-
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
257-
print(f"Alignment metadata: {alignment_metadata}")
257+
all_data, _ = align_profiler_data_timestamps(all_data)
258258
# Partition the profile data into initialisation and steady-state running
259259
_, steady_state = apply_warmup_heuristics(all_data)
260260

0 commit comments

Comments (0)