Skip to content

Commit ee46a55

Browse files
authored
nsys-jax: XLA_FLAGS and nsys-patching fixes (#1337)
- Repeated entries in `XLA_FLAGS` will no longer cause assertion errors. Only the last flag will be passed to the child application. - `nsys-jax-patch-nsys` now uses [ugly] logic that works even for `cuda-nsight-systems` packages where `nsys` in `$PATH` is a shim - CI tests are added for this case
1 parent a5c50be commit ee46a55

File tree

3 files changed

+44
-15
lines changed

3 files changed

+44
-15
lines changed

.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import tempfile
1919
import time
2020
import traceback
21+
from typing import Optional
2122
import zipfile
2223

2324
from .utils import execute_analysis_script, shuffle_analysis_arg
@@ -259,21 +260,22 @@ def override_nsys_default(arg, value):
259260
if "JAX_ENABLE_COMPILATION_CACHE" not in env:
260261
env["JAX_ENABLE_COMPILATION_CACHE"] = "false"
261262

263+
def format_flag(tup):
264+
n, v = tup
265+
return f"--{n}" if v is None else f"--{n}={v}"
266+
262267
# Get the existing XLA_FLAGS and parse them into a dictionary.
263-
xla_flag_list = shlex.split(env.get("XLA_FLAGS", ""))
264-
xla_flags = {}
265-
for flag in xla_flag_list:
268+
xla_flags: dict[str, Optional[str]] = {}
269+
for flag in shlex.split(env.get("XLA_FLAGS", "")):
266270
assert flag.startswith("--")
267271
bits = flag[2:].split("=", maxsplit=1)
268272
name, value = bits[0], bits[1] if len(bits) > 1 else None
269-
assert name not in xla_flags
273+
if name in xla_flags:
274+
print(
275+
f"WARNING: {format_flag((name, xla_flags[name]))} being overriden by {flag}"
276+
)
270277
xla_flags[name] = value
271278

272-
def as_list(flags):
273-
return [f"--{n}" if v is None else f"--{n}={v}" for n, v in flags.items()]
274-
275-
assert xla_flag_list == as_list(xla_flags)
276-
277279
def as_bool(s):
278280
"""String -> bool conversion following XLA's semantics."""
279281
if s.lower() == "true" or s == "1":
@@ -298,7 +300,7 @@ def as_bool(s):
298300

299301
# Serialise the modified XLA flags. shlex.join is tempting, but doesn't seem to
300302
# get the right result for --xla_dump_hlo_pass_re=.*, as it adds extra quotes.
301-
env["XLA_FLAGS"] = " ".join(as_list(xla_flags))
303+
env["XLA_FLAGS"] = " ".join(map(format_flag, xla_flags.items()))
302304

303305
# Run the application in nsys
304306
# TODO: consider being more fault-tolerant?

.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import re
32
import shutil
43
import subprocess
@@ -1372,11 +1371,20 @@ def main():
13721371
patch_content = None
13731372
if patch_content is not None:
13741373
print(f"Patching Nsight Systems version {m.group(1)}")
1375-
# e.g. /opt/nvidia/nsight-systems-cli/2024.7.1/target-linux-x64
1376-
tdir = os.path.dirname(os.path.realpath(nsys))
1374+
nsys_recipe_help = subprocess.check_output(
1375+
[nsys, "recipe", "--help"], text=True
1376+
)
1377+
m = re.search(
1378+
r"List of required Python packages: '(.*?)/nsys_recipe/requirements/common.txt'",
1379+
nsys_recipe_help,
1380+
)
1381+
assert m is not None, (
1382+
f"Could not determine target directory from: {nsys_recipe_help}"
1383+
)
1384+
# e.g. /opt/nvidia/nsight-systems-cli/2024.7.1/target-linux-x64/python/packages
13771385
subprocess.run(
13781386
[shutil.which("git"), "apply"],
1379-
cwd=os.path.join(tdir, "python", "packages"),
1387+
cwd=m.group(1),
13801388
input=patch_content,
13811389
check=True,
13821390
text=True,

.github/workflows/nsys-jax.yaml

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: nsys-jax pure-Python CI
1+
name: nsys-jax non-GPU CI
22

33
concurrency:
44
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@@ -274,3 +274,22 @@ jobs:
274274
if [[ $format_status != 0 || $check_status != 0 ]]; then
275275
exit 1
276276
fi
277+
installation:
278+
strategy:
279+
matrix:
280+
include:
281+
- container: "nvidia/cuda:12.6.3-base-ubuntu24.04"
282+
nsys_package: "cuda-nsight-systems-12-6"
283+
- container: "nvidia/cuda:12.8.0-base-ubuntu24.04"
284+
nsys_package: "cuda-nsight-systems-12-8"
285+
runs-on: ubuntu-latest
286+
container: "${{ matrix.container }}"
287+
steps:
288+
- name: Install ${{ matrix.nsys_package }}
289+
run: |
290+
apt-get update
291+
apt-get install -y git python3-pip ${{ matrix.nsys_package }}
292+
- name: Install nsys-jax
293+
run: pip install --break-system-packages git+https://github.com/NVIDIA/JAX-Toolbox.git@${{ github.head_ref || github.sha }}#subdirectory=.github/container/nsys_jax
294+
- name: Run nsys-jax-patch-nsys
295+
run: nsys-jax-patch-nsys

0 commit comments

Comments
 (0)