Skip to content

Commit 92243a8

Browse files
authored
triage tool: manage test case output files (#1487)
The main feature added here is that a special directory, `/triage-tool-output`, is mounted into the container when the test case is executed. The host-side versions of this directory are tracked by the tool, so it can identify those corresponding to the last-known-good and first-known-bad invocations of the test case. This means that the user can immediately examine "what changed", without having to manually re-run those two configurations.

Other minor changes:
- Add a `Container` interface with `PyxisContainer` and `DockerContainer` implementations, for neater type annotations.
- Add a result cache, which allows the tool to avoid some redundant re-executions.
- Write the same log information to `info.log` as to the console.
- Only call `git fetch` if needed; if commit-level triage is using a "new"/failing container, all relevant git history should already be available.
1 parent 563231e commit 92243a8

File tree

12 files changed

+390
-164
lines changed

12 files changed

+390
-164
lines changed

.github/triage/jax_toolbox_triage/args.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,28 @@
44
import os
55
import pathlib
66
import tempfile
7+
import typing
78

89
# Software we know may exist in the containers that we might be able to triage
910
# We know how to recompile JAX/XLA, so it's OK that they include C++ code
1011
# TransformerEngine is intentionally excluded because build-te.sh is not plumbed yet.
1112
# Flax and MaxText are pure Python, so it's OK we don't have a way of compiling them,
1213
# but they are not always installed in containers we want to triage.
13-
compulsory_software = {"xla", "jax"}
14-
optional_software = {"flax", "maxtext"}
14+
# Note this is not a `set` for the sake of run-to-run determinism.
15+
compulsory_software = ["xla", "jax"]
16+
optional_software = ["flax", "maxtext"]
1517

1618

17-
def parse_commit_argument(s):
18-
ret = {}
19+
def parse_commit_argument(s: str) -> typing.Dict[str, str]:
20+
ret: typing.Dict[str, str] = {}
1921
for part in s.split(","):
2022
sw, commit = part.split(":", 1)
2123
assert sw not in ret, ret
2224
ret[sw] = commit
2325
return ret
2426

2527

26-
def parse_args(args=None):
28+
def parse_args(args=None) -> argparse.Namespace:
2729
parser = argparse.ArgumentParser(
2830
description="""
2931
Triage failures in JAX/XLA-related tests. The expectation is that the given
@@ -203,10 +205,9 @@ def parse_args(args=None):
203205
args.passing_container is not None or args.failing_container is not None
204206
), "At least one of --passing-container and --failing-container must be passed."
205207
for prefix in ["passing", "failing"]:
206-
assert (
207-
getattr(args, f"{prefix}_container") is not None
208-
or getattr(args, f"{prefix}_commits").keys() >= compulsory_software
209-
), (
208+
assert getattr(args, f"{prefix}_container") is not None or getattr(
209+
args, f"{prefix}_commits"
210+
).keys() >= set(compulsory_software), (
210211
f"--{prefix}-commits must specify all of {compulsory_software} if "
211212
f"--{prefix}-container is not specified"
212213
)
.github/triage/jax_toolbox_triage/container.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from abc import ABC, abstractmethod
2+
import logging
3+
import subprocess
4+
import typing
5+
6+
7+
class Container(ABC):
    """
    Abstract interface for a container that commands can be executed in.

    Implementations supply the context-manager protocol (launch on enter, shut
    down on exit), `exec` to run a command, `exists` to check availability and
    `__repr__`. `check_exec` is implemented here on top of `exec`.
    """

    def __init__(self, *, logger: logging.Logger):
        """
        Arguments:
            logger: instance that `check_exec` reports failed commands to
        """
        self._logger = logger

    @abstractmethod
    def __enter__(self) -> "Container":
        """
        Launch the container instance
        """
        pass

    @abstractmethod
    def __exit__(self, *exc_info) -> None:
        """
        Shut down the container instance
        """
        pass

    @abstractmethod
    def __repr__(self) -> str:
        """
        Describe the container; every implementation must provide this.
        """
        pass

    @abstractmethod
    def exec(
        self,
        command: typing.List[str],
        policy: typing.Literal["once", "once_per_container", "default"] = "default",
        stderr: typing.Literal["interleaved", "separate"] = "interleaved",
        workdir=None,
    ) -> subprocess.CompletedProcess:
        """
        Run a command inside a persistent container.
        """
        pass

    def check_exec(
        self, cmd: typing.List[str], **kwargs
    ) -> subprocess.CompletedProcess:
        """
        Run a command with `exec` and raise if it exits with a non-zero code.

        Arguments:
            cmd: command to execute, as a list of arguments
            **kwargs: forwarded unchanged to `exec`

        Returns the completed process if its return code was zero. Otherwise
        the command, its return code and its captured output are logged and
        subprocess.CalledProcessError is raised.
        """
        result = self.exec(cmd, **kwargs)
        if result.returncode != 0:
            self._logger.fatal(
                f"{' '.join(cmd)} exited with return code {result.returncode}"
            )
            self._logger.fatal(result.stdout)
            # `result.stderr` is None when stderr was not captured on its own
            # (presumably the "interleaved" mode of `exec` -- confirm against
            # the implementations); skip it then rather than logging "None".
            if result.stderr is not None:
                self._logger.fatal(result.stderr)
            result.check_returncode()
        return result

    @abstractmethod
    def exists(self) -> bool:
        """
        Check if the container exists.
        """
        pass

.github/triage/jax_toolbox_triage/docker.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,19 @@
33
import subprocess
44
import typing
55

6+
from .container import Container
67
from .utils import run_and_log
78

89

9-
class DockerContainer:
10+
class DockerContainer(Container):
1011
def __init__(
1112
self,
1213
url: str,
1314
*,
1415
logger: logging.Logger,
1516
mounts: typing.List[typing.Tuple[pathlib.Path, pathlib.Path]],
1617
):
17-
self._logger = logger
18+
super().__init__(logger=logger)
1819
self._mount_args = []
1920
for src, dst in mounts:
2021
self._mount_args += ["-v", f"{src}:{dst}"]
@@ -74,19 +75,6 @@ def exec(
7475
stderr=stderr,
7576
)
7677

77-
def check_exec(
78-
self, cmd: typing.List[str], **kwargs
79-
) -> subprocess.CompletedProcess:
80-
result = self.exec(cmd, **kwargs)
81-
if result.returncode != 0:
82-
self._logger.fatal(
83-
f"{' '.join(cmd)} exited with return code {result.returncode}"
84-
)
85-
self._logger.fatal(result.stdout)
86-
self._logger.fatal(result.stderr)
87-
result.check_returncode()
88-
return result
89-
9078
def exists(self) -> bool:
9179
"""
9280
Check if the given container exists.

.github/triage/jax_toolbox_triage/logic.py

Lines changed: 93 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
import itertools
55
import logging
6+
import pathlib
67
import typing
78

89
from .utils import console_log_level
@@ -15,6 +16,7 @@ class TestResult:
1516
"""
1617

1718
__test__ = False # stop pytest gathering this
19+
host_output_directory: pathlib.Path
1820
result: bool
1921
stdouterr: str
2022

@@ -80,7 +82,7 @@ def container_search(
8082
logger: logging.Logger,
8183
skip_precondition_checks: bool,
8284
threshold_days: int,
83-
):
85+
) -> typing.Tuple[datetime.date, datetime.date]:
8486
adjust = functools.partial(
8587
adjust_date, logger=logger, container_exists=container_exists
8688
)
@@ -197,36 +199,29 @@ def __call__(self, *, commits: typing.Dict[str, str]) -> TestResult:
197199
...
198200

199201

200-
def _first(xs):
202+
T = typing.TypeVar("T")
203+
U = typing.TypeVar("U")
204+
FlatCommitDict = typing.Tuple[typing.Tuple[str, str], ...]
205+
206+
207+
def _first(xs: typing.Iterable[T]) -> T:
201208
return next(iter(xs))
202209

203210

204-
def _not_first(d):
211+
def _not_first(d: typing.Dict[T, U]) -> typing.Iterable[typing.Tuple[T, U]]:
205212
return itertools.islice(d.items(), 1, None)
206213

207214

208-
def commit_search(
215+
def _commit_search(
209216
*,
210217
commits: typing.OrderedDict[
211218
str, typing.Sequence[typing.Tuple[str, datetime.datetime]]
212219
],
213220
build_and_test: BuildAndTest,
214221
logger: logging.Logger,
215222
skip_precondition_checks: bool,
216-
):
217-
"""
218-
Bisect a failure back to a single commit.
219-
220-
Arguments:
221-
commits: *ordered* dictionary of commit sequences for different software
222-
packages, e.g. commits["jax"][0] is (hash, date) of the passing JAX
223-
commit. The ordering of packages has implications for precisely how
224-
the triage proceeds.
225-
build_and_test: callable that tests if a given vector of commits passes
226-
logger: instance to log output to
227-
skip_precondition_checks: if True, some tests that should pass/fail by
228-
construction are skipped
229-
"""
223+
result_cache: typing.Dict[FlatCommitDict, TestResult],
224+
) -> typing.Tuple[typing.Dict[str, str], TestResult, typing.Optional[TestResult]]:
230225
assert all(len(commit_list) for commit_list in commits.values()), (
231226
"Not enough commits: need at least one commit for each package",
232227
commits,
@@ -236,6 +231,11 @@ def commit_search(
236231
commits,
237232
)
238233

234+
def _cache_key(
235+
commits: typing.Dict[str, str],
236+
) -> FlatCommitDict:
237+
return tuple(sorted(commits.items()))
238+
239239
if skip_precondition_checks:
240240
logger.info("Skipping check that 'good' commits reproduce success")
241241
else:
@@ -245,6 +245,8 @@ def commit_search(
245245
package: commit_list[0][0] for package, commit_list in commits.items()
246246
}
247247
check_pass = build_and_test(commits=passing_commits)
248+
assert _cache_key(passing_commits) not in result_cache
249+
result_cache[_cache_key(passing_commits)] = check_pass
248250
if check_pass.result:
249251
logger.info("Verified test passes using 'good' commits")
250252
else:
@@ -264,6 +266,8 @@ def commit_search(
264266
# below is actionable without checking the debug logfile.
265267
with console_log_level(logger, logging.DEBUG):
266268
check_fail = build_and_test(commits=failing_commits)
269+
assert _cache_key(failing_commits) not in result_cache
270+
result_cache[_cache_key(failing_commits)] = check_fail
267271
if not check_fail.result:
268272
logger.info(
269273
"Verified test failure using 'bad' commits. IMPORTANT: you should check "
@@ -303,9 +307,11 @@ def commit_search(
303307
bisect_commits[secondary] = commit
304308
log_msg += f", {len(commit_list)} remaining {secondary} commits"
305309
logger.info(log_msg)
306-
bisect_result = build_and_test(commits=bisect_commits).result
310+
bisect_result = build_and_test(commits=bisect_commits)
311+
assert _cache_key(bisect_commits) not in result_cache
312+
result_cache[_cache_key(bisect_commits)] = bisect_result
307313

308-
if bisect_result:
314+
if bisect_result.result:
309315
# Test passed, continue searching in the second half
310316
for package, index in indices.items():
311317
commits[package] = commits[package][index:]
@@ -336,7 +342,11 @@ def commit_search(
336342
f"Two {primary} commits remain, checking if {commits[primary][-1][0]} is the "
337343
"culprit"
338344
)
339-
blame = build_and_test(commits=blame_commits)
345+
# It's possible that this combination has already been tested at this point
346+
blame = result_cache.get(_cache_key(blame_commits))
347+
if blame is None:
348+
blame = build_and_test(commits=blame_commits)
349+
result_cache[_cache_key(blame_commits)] = blame
340350
if blame.result:
341351
# Test passed with {pX, sZ, tZ, ...} but was known to fail with
342352
# {pZ, sZ, tZ, ...}. Therefore pZ is the culprit commit.
@@ -349,18 +359,78 @@ def commit_search(
349359
f"{primary}_bad": bad_commit,
350360
f"{primary}_good": good_commit,
351361
}
362+
first_known_bad = {primary: bad_commit}
352363
for secondary, secondary_commit in _not_first(blame_commits):
364+
first_known_bad[secondary] = secondary_commit
353365
ret[f"{secondary}_ref"] = secondary_commit
354-
return ret
366+
# `blame` represents the last-known-good test result, first-known-bad was seen
367+
# earlier, or possibly not at all e.g. if `skip_precondition_checks` is True
368+
# and first-known-bad was the end of the search range.
369+
first_known_bad_result = result_cache.get(_cache_key(first_known_bad))
370+
if first_known_bad_result is None:
371+
if skip_precondition_checks:
372+
logger.info(
373+
"Did not find a cached result for the first-known-bad "
374+
f"configuration {first_known_bad}, this is probably due to "
375+
"--skip-precondition-checks having been passed."
376+
)
377+
else:
378+
logger.error(
379+
"Did not find a cached result for the first-known-bad "
380+
f"configuration {first_known_bad}, this is unexpected!"
381+
)
382+
return ret, blame, first_known_bad_result
355383
else:
356384
# Test failed with both {pX, sZ, tZ, ...} and {pZ, sZ, tZ, ...}, so
357385
# we can fix the primary package to pX and recurse with the old
358386
# secondary package (s) as the new primary, and the old primary
359387
# package (p) moved to the end.
360388
commits[primary] = [commits.pop(primary)[0]]
361-
return commit_search(
389+
return _commit_search(
362390
build_and_test=build_and_test,
363391
commits=commits,
364392
logger=logger,
365393
skip_precondition_checks=True,
394+
result_cache=result_cache,
366395
)
396+
397+
398+
def commit_search(
    *,
    commits: typing.OrderedDict[
        str, typing.Sequence[typing.Tuple[str, datetime.datetime]]
    ],
    build_and_test: BuildAndTest,
    logger: logging.Logger,
    skip_precondition_checks: bool,
) -> typing.Tuple[
    typing.Dict[str, str],
    TestResult,
    typing.Optional[TestResult],
]:
    """
    Bisect a failure back to a single commit.

    This is the public entry point: it starts the search implemented by
    `_commit_search` with an empty result cache.

    Arguments:
        commits: *ordered* dictionary of commit sequences for different software
            packages, e.g. commits["jax"][0] is (hash, date) of the passing JAX
            commit. The ordering of packages has implications for precisely how
            the triage proceeds.
        build_and_test: callable that tests if a given vector of commits passes
        logger: instance to log output to
        skip_precondition_checks: if True, some tests that should pass/fail by
            construction are skipped

    Returns a 3-tuple of (summary_dict, last_known_good, first_known_bad),
    where the last element can be None if skip_precondition_checks=True. The
    last two elements' .result fields will always be, respectively, True and
    False, but the other fields can be used to obtain stdout+stderr and
    output files from those test invocations.
    """
    # Each top-level search gets its own cache, so results are never shared
    # between separate calls.
    fresh_cache: typing.Dict[FlatCommitDict, TestResult] = {}
    return _commit_search(
        build_and_test=build_and_test,
        commits=commits,
        logger=logger,
        result_cache=fresh_cache,
        skip_precondition_checks=skip_precondition_checks,
    )

0 commit comments

Comments
 (0)