Skip to content

Commit b28cc73

Browse files
authored
triage-tool: support multi-node triage (#1421)
- Add `--{passing,failing}-commits` option, which allows the commit-level search to be manually restricted in scope - This also allows single-container triage, as you can pass `--{passing,failing}-container` and `--{failing,passing}-commits` - With the Pyxis backend, re-use the same container instance within a single triage tool process - This reduces the amount of time spent in container creation - Support multi-node/multi-process triage with the Pyxis backend - This is implemented by annotating the various commands run inside the containers as `once` (run once, in one container instance), `once_per_container` (run once per container instance, i.e. once per node), and `default` (run without extra `srun` arguments -- the caller must make this do the right thing e.g. by passing `--ntasks-per-node` to `salloc`). - `once` example: getting the JAX commit from a container - `once_per_container` example: building JAX - `default` example: running the test case
1 parent 926f8ea commit b28cc73

File tree

6 files changed

+213
-83
lines changed

6 files changed

+213
-83
lines changed

.github/triage/jax_toolbox_triage/args.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@
66
import tempfile
77

88

9+
def parse_commit_argument(s):
    """
    Parse the value of --passing-commits / --failing-commits.

    The expected form is ``jax:<jax_commit>,xla:<xla_commit>`` (either order).

    Args:
        s: the raw command-line string.

    Returns:
        A dict with exactly the keys "jax" and "xla", mapping to the commit
        strings that were given.

    Raises:
        argparse.ArgumentTypeError: if an entry has no ``:`` separator, if a
            package is named more than once, or if the set of packages is not
            exactly {jax, xla}. This function is used as an argparse ``type=``
            callable, and argparse reports this exception as a usage error
            instead of a traceback. Unlike ``assert``, the checks also
            survive ``python -O``.
    """
    ret = {}
    for part in s.split(","):
        # Split on the first colon only; the remainder is kept verbatim.
        sw, sep, commit = part.partition(":")
        if not sep:
            raise argparse.ArgumentTypeError(
                f"expected an entry of the form package:commit, got {part!r}"
            )
        if sw in ret:
            raise argparse.ArgumentTypeError(
                f"package {sw!r} was specified more than once in {s!r}"
            )
        ret[sw] = commit
    if ret.keys() != {"jax", "xla"}:
        raise argparse.ArgumentTypeError(
            "expected commits for exactly jax and xla "
            f"(jax:jax_commit_hash,xla:xla_commit_hash), got {sorted(ret)}"
        )
    return ret
17+
18+
919
def parse_args(args=None):
1020
parser = argparse.ArgumentParser(
1121
description="""
@@ -70,10 +80,11 @@ def parse_args(args=None):
7080
"--failing-container",
7181
help="""
7282
Skip the container-level search and pass this container to the commit-level
73-
search. If this is passed, --passing-container must be too, but --container
74-
is not required. This can be used to apply the commit-level bisection
75-
search to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD
76-
series, although they must have a similar structure.""",
83+
search. If this is passed, --passing-container or --passing-commits must be
84+
too, but --container is not required. This can be used to apply the
85+
commit-level bisection search to containers not from the
86+
ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series, although they must have a
87+
similar structure.""",
7788
)
7889
container_search_args.add_argument(
7990
"--end-date",
@@ -88,10 +99,11 @@ def parse_args(args=None):
8899
"--passing-container",
89100
help="""
90101
Skip the container-level search and pass this container to the commit-level
91-
search. If this is passed, --failing-container must be too, but --container is
92-
not required. This can be used to apply the commit-level bisection search
93-
to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series,
94-
although they must have a similar structure.""",
102+
search. If this is passed, --failing-container or --failing-commits must be
103+
too, but --container is not required. This can be used to apply the
104+
commit-level bisection search to containers not from the
105+
ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series, although they must have a
106+
similar structure.""",
95107
)
96108
container_search_args.add_argument(
97109
"--start-date",
@@ -126,6 +138,24 @@ def parse_args(args=None):
126138
significantly speed up the commit-level search. By default, uses a temporary
127139
directory including the name of the current user.""",
128140
)
141+
commit_search_args.add_argument(
142+
"--failing-commits",
143+
help="""
144+
When combined with --passing-container, the commit-level triage will use
145+
that container and --failing-commits will specify the end of the commit
146+
range, rather than the commits being extracted from --failing-container.
147+
Expects an argument of form jax:jax_commit_hash,xla:xla_commit_hash.""",
148+
type=parse_commit_argument,
149+
)
150+
commit_search_args.add_argument(
151+
"--passing-commits",
152+
help="""
153+
When combined with --failing-container, the commit-level triage will use
154+
that container and --passing-commits will specify the start of the commit
155+
range, rather than the commits being extracted from --passing-container.
156+
Expects an argument of form jax:jax_commit_hash,xla:xla_commit_hash.""",
157+
type=parse_commit_argument,
158+
)
129159
parser.add_argument(
130160
"-v",
131161
"--container-mount",
@@ -145,29 +175,33 @@ def parse_args(args=None):
145175
)
146176
args = parser.parse_args(args=args)
147177
assert args.container_runtime in {"docker", "pyxis"}, args.container_runtime
148-
num_explicit_containers = (args.passing_container is not None) + (
149-
args.failing_container is not None
150-
)
151-
if num_explicit_containers == 1:
152-
raise Exception(
153-
"--passing-container and --failing-container must both be passed if either is"
178+
passing_commits_known = (args.passing_container is not None) or (
179+
args.passing_commits is not None
180+
)
181+
failing_commits_known = (args.failing_container is not None) or (
182+
args.failing_commits is not None
183+
)
184+
sets_of_known_commits = passing_commits_known + failing_commits_known
185+
if sets_of_known_commits == 2:
186+
# If the container-level search is being skipped, because a valid combination
187+
# of --{passing,failing}-{commits,container} is passed, then no container-level
188+
# search options should be passed.
189+
assert (
190+
args.container is None and args.start_date is None and args.end_date is None
191+
), (
192+
"No container-level search options should be passed if the passing/failing containers/commits have been passed explicitly."
154193
)
155-
if num_explicit_containers == 2:
156-
# Explicit mode, --container, --start-date and --end-date are all ignored
157-
if args.container:
158-
raise Exception(
159-
"--container must not be passed if --passing-container and --failing-container are"
160-
)
161-
if args.start_date:
162-
raise Exception(
163-
"--start-date must not be passed if --passing-container and --failing-container are"
164-
)
165-
if args.end_date:
166-
raise Exception(
167-
"--end-date must not be passed if --passing-container and --failing-container are"
168-
)
169-
elif num_explicit_containers == 0 and args.container is None:
194+
assert (
195+
args.passing_container is not None or args.failing_container is not None
196+
), ""
197+
elif sets_of_known_commits == 1:
170198
raise Exception(
171-
"--container must be passed if --passing-container and --failing-container are not"
199+
"If --passing-{commits OR container} is passed then --failing-{commits OR container} should be too"
200+
)
201+
else:
202+
# None of --{passing,failing}-{commits,container} were passed, make sure the
203+
# compulsory arguments for the container-level search were passed
204+
assert args.container is not None, (
205+
"--container must be passed for the container-level search"
172206
)
173207
return args

.github/triage/jax_toolbox_triage/docker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def __repr__(self):
5656
return f"Docker({self._url})"
5757

5858
def exec(
59-
self, command: typing.List[str], workdir=None
59+
self,
60+
command: typing.List[str],
61+
policy: typing.Literal["once", "once_per_container", "default"] = "default",
62+
workdir=None,
6063
) -> subprocess.CompletedProcess:
6164
"""
6265
Run a command inside a persistent container.

.github/triage/jax_toolbox_triage/logic.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -214,43 +214,36 @@ def commit_search(
214214
end_jax_commit = jax_commits[-1][0]
215215
end_xla_commit = xla_commits[-1][0]
216216
if skip_precondition_checks:
217-
logger.info("Skipping check that vanilla rebuild + test reproduces failure")
217+
logger.info("Skipping check that 'bad' commits reproduce failure")
218218
else:
219-
# Verify we can build successfully and that the test fails as expected. These
220-
# commits are the ones already checked out in the container, but specifying
221-
# them explicitly is good for the summary JSON.
222-
logger.info("Building in the range-ending container...")
219+
# Verify we can build successfully and that the test fails as expected.
220+
logger.info("Verifying test failure using 'bad' commits")
223221
range_end_result, stdout, stderr = build_and_test(
224222
jax_commit=end_jax_commit, xla_commit=end_xla_commit
225223
)
226224
if not range_end_result:
227-
logger.info("Verified test failure after vanilla rebuild")
225+
logger.info("Verified test failure using 'bad' commits")
228226
else:
229-
logger.fatal("Vanilla rebuild did not reproduce test failure")
227+
logger.fatal("Could not reproduce failure with 'bad' commits")
230228
logger.fatal(stdout)
231229
logger.fatal(stderr)
232-
raise Exception(
233-
"Could not reproduce failure after rebuild in 'bad' container"
234-
)
230+
raise Exception("Could not reproduce failure with 'bad' commits")
235231

236-
# Verify that we can build the commit at the start of the range and reproduce the
237-
# test success there in the end-of-range container.
238-
range_start_result, stdout, stderr = build_and_test(
239-
jax_commit=start_jax_commit, xla_commit=start_xla_commit
240-
)
241-
if range_start_result:
242-
logger.info(
243-
"Test passed after rebuilding commits from start container in end container"
244-
)
232+
if skip_precondition_checks:
233+
logger.info("Skipping check that 'good' commits reproduce success")
245234
else:
246-
logger.fatal(
247-
"Test failed after rebuilding commits from start container in end container"
248-
)
249-
logger.fatal(stdout)
250-
logger.fatal(stderr)
251-
raise Exception(
252-
"Could not reproduce success with 'good' commits in 'bad' container"
235+
# Verify that we can build successfully and that the test succeeds as expected.
236+
logger.info("Verifying test success using 'good' commits")
237+
range_start_result, stdout, stderr = build_and_test(
238+
jax_commit=start_jax_commit, xla_commit=start_xla_commit
253239
)
240+
if range_start_result:
241+
logger.info("Verified test passes using 'good' commits")
242+
else:
243+
logger.fatal("Could not reproduce success with 'good' commits")
244+
logger.fatal(stdout)
245+
logger.fatal(stderr)
246+
raise Exception("Could not reproduce success with 'good' commits")
254247

255248
# Finally, start bisecting. This is XLA-centric; JAX is moved too but is secondary.
256249
while len(xla_commits) > 2:
@@ -260,6 +253,9 @@ def commit_search(
260253
for jax_index, (jax_hash, jax_date) in enumerate(jax_commits):
261254
if jax_date >= xla_date:
262255
break
256+
logger.info(
257+
f"Chose from {len(xla_commits)} remaining XLA commits and {len(jax_commits)} remaining JAX commits"
258+
)
263259
bisect_result, _, _ = build_and_test(jax_commit=jax_hash, xla_commit=xla_hash)
264260
if bisect_result:
265261
# Test passed, continue searching in the second half

.github/triage/jax_toolbox_triage/main.py

Lines changed: 65 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def get_commit(
3232
results = []
3333
for suffix in ["", "-source"]:
3434
dirname = f"/opt/{repo}{suffix}"
35-
result = container.exec(["git", "rev-parse", "HEAD"], workdir=dirname)
35+
result = container.exec(
36+
["git", "rev-parse", "HEAD"], policy="once", workdir=dirname
37+
)
3638
results.append(result)
3739
if result.returncode == 0:
3840
commit = result.stdout.strip()
@@ -108,12 +110,7 @@ def check_container(date: datetime.date) -> TestResult:
108110
)
109111
return TestResult(result=test_pass, stdout=result.stdout, stderr=result.stderr)
110112

111-
if args.passing_container is not None:
112-
assert args.failing_container is not None
113-
# Skip the container-level search because explicit end points were given
114-
passing_url = args.passing_container
115-
failing_url = args.failing_container
116-
else:
113+
if args.passing_container is None and args.failing_container is None:
117114
# Search through the published containers, narrowing down to a pair of dates with
118115
# the property that the test passed on `range_start` and fails on `range_end`.
119116
range_start, range_end = container_search(
@@ -127,20 +124,44 @@ def check_container(date: datetime.date) -> TestResult:
127124
)
128125
passing_url = container_url(range_start)
129126
failing_url = container_url(range_end)
127+
else:
128+
# Skip the container-level search because at least one explicit end point was
129+
# given
130+
passing_url = args.passing_container
131+
failing_url = args.failing_container
132+
133+
jax_dir = "/opt/jax"
134+
xla_dir = "/opt/xla"
135+
if args.passing_commits is not None:
136+
start_jax_commit = args.passing_commits["jax"]
137+
start_xla_commit = args.passing_commits["xla"]
138+
else:
139+
assert passing_url is not None
140+
with Container(passing_url) as worker:
141+
start_jax_commit, jax_dir = get_commit(worker, "jax")
142+
start_xla_commit, xla_dir = get_commit(worker, "xla")
143+
144+
if args.failing_commits is not None:
145+
end_jax_commit = args.failing_commits["jax"]
146+
end_xla_commit = args.failing_commits["xla"]
147+
else:
148+
assert failing_url is not None
149+
with Container(failing_url) as worker:
150+
end_jax_commit, jax_dir = get_commit(worker, "jax")
151+
end_xla_commit, xla_dir = get_commit(worker, "xla")
152+
153+
bisection_url = failing_url or passing_url
154+
assert bisection_url is not None
155+
assert jax_dir is not None
156+
assert xla_dir is not None
130157

131158
# Container-level search is now complete. Triage proceeds inside the `range_end`
132159
# container. First, we check that rewinding JAX and XLA inside the `range_end`
133160
# container to the commits used in the `range_start` container passes, whereas
134161
# using the `range_end` commits reproduces the failure.
135162

136-
with Container(passing_url) as worker:
137-
start_jax_commit, _ = get_commit(worker, "jax")
138-
start_xla_commit, _ = get_commit(worker, "xla")
139-
140163
# Fire up the container that will be used for the fine search.
141-
with Container(failing_url) as worker:
142-
end_jax_commit, jax_dir = get_commit(worker, "jax")
143-
end_xla_commit, xla_dir = get_commit(worker, "xla")
164+
with Container(bisection_url) as worker:
144165
logger.info(
145166
(
146167
f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and "
@@ -150,6 +171,9 @@ def check_container(date: datetime.date) -> TestResult:
150171

151172
# Get the full lists of JAX/XLA commits and dates
152173
def commits(start, end, dir):
174+
worker.check_exec(
175+
["git", "fetch"], policy="once_per_container", workdir=dir
176+
)
153177
result = worker.check_exec(
154178
[
155179
"git",
@@ -159,6 +183,7 @@ def commits(start, end, dir):
159183
"--format=%H %cI",
160184
f"{start}^..{end}",
161185
],
186+
policy="once",
162187
workdir=dir,
163188
)
164189
data = []
@@ -190,21 +215,39 @@ def build_and_test(
190215
jaxlib, and run the test command. Throws on error when checking out or
191216
building, and returns the status of the test command.
192217
"""
193-
worker.check_exec(["git", "stash"], workdir=xla_dir)
194-
worker.check_exec(["git", "stash"], workdir=jax_dir)
195-
worker.check_exec(["git", "checkout", xla_commit], workdir=xla_dir)
196-
worker.check_exec(["git", "checkout", jax_commit], workdir=jax_dir)
197-
logger.info(f"Checking out XLA {xla_commit} JAX {jax_commit}")
218+
worker.check_exec(
219+
["git", "stash"], policy="once_per_container", workdir=xla_dir
220+
)
221+
worker.check_exec(
222+
["git", "stash"], policy="once_per_container", workdir=jax_dir
223+
)
224+
worker.check_exec(
225+
["git", "checkout", xla_commit],
226+
policy="once_per_container",
227+
workdir=xla_dir,
228+
)
229+
worker.check_exec(
230+
["git", "checkout", jax_commit],
231+
policy="once_per_container",
232+
workdir=jax_dir,
233+
)
234+
logger.info(f"Checking out XLA {xla_commit} JAX {jax_commit} in {worker}")
198235
# Build JAX
199236
before = time.monotonic()
200237
# Unfortunately the build system does not always seem to handle incremental
201238
# rebuilds correctly.
202-
worker.check_exec(["bazel", "clean", "--expunge"], workdir=jax_dir)
239+
worker.check_exec(
240+
["bazel", "clean", "--expunge"],
241+
policy="once_per_container",
242+
workdir=jax_dir,
243+
)
203244
build_jax = [
204245
"build-jax.sh",
205246
f"--bazel-cache={args.bazel_cache}",
206247
]
207-
build_result = worker.check_exec(build_jax, workdir=jax_dir)
248+
build_result = worker.check_exec(
249+
build_jax, policy="once_per_container", workdir=jax_dir
250+
)
208251
middle = time.monotonic()
209252
logger.info(f"Build completed in {middle - before:.1f}s")
210253
logger.debug(
@@ -217,7 +260,7 @@ def build_and_test(
217260
"commit",
218261
{
219262
"build_time": middle - before,
220-
"container": failing_url,
263+
"container": bisection_url,
221264
"jax": jax_commit,
222265
"result": test_result.returncode == 0,
223266
"test_time": test_time,

0 commit comments

Comments
 (0)