Skip to content

Commit f762fce

Browse files
[NFC] FileCheck tests check all overloads (#354)
* Clean up type hints * On failure, dump the failing IR to a file and provide a reproducer command * When multiple overloaded IR is available for a kernel and no signature has been passed, check all of the overloads.
1 parent d455b9b commit f762fce

File tree

2 files changed

+122
-60
lines changed

2 files changed

+122
-60
lines changed

conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
4+
def pytest_addoption(parser):
5+
parser.addoption(
6+
"--dump-failed-filechecks",
7+
action="store_true",
8+
help="Dump reproducers for FileCheck tests that fail.",
9+
)
10+
11+
12+
@pytest.fixture(scope="class")
13+
def initialize_from_pytest_config(request):
14+
"""
15+
Fixture to initialize the test case with pytest configuration options.
16+
"""
17+
request.cls._dump_failed_filechecks = request.config.getoption(
18+
"dump_failed_filechecks"
19+
)

numba_cuda/numba/cuda/testing.py

Lines changed: 103 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,37 @@
11
import os
22
import platform
33
import shutil
4-
4+
import pytest
5+
from datetime import datetime
56
from numba.core.utils import PYVERSION
67
from numba.cuda.cuda_paths import get_conda_ctk
78
from numba.cuda.cudadrv import driver, devices, libs
89
from numba.cuda.dispatcher import CUDADispatcher
910
from numba.core import config
1011
from numba.tests.support import TestCase
1112
from pathlib import Path
12-
from typing import Union
13+
14+
from typing import Iterable, Union
1315
from io import StringIO
1416
import unittest
1517

1618
if PYVERSION >= (3, 10):
17-
from filecheck.matcher import Matcher, Options
19+
from filecheck.matcher import Matcher
20+
from filecheck.options import Options
1821
from filecheck.parser import Parser, pattern_for_opts
1922
from filecheck.finput import FInput
2023

2124
numba_cuda_dir = Path(__file__).parent
2225
test_data_dir = numba_cuda_dir / "tests" / "data"
2326

2427

25-
class FileCheckTestCaseMixin:
28+
@pytest.mark.usefixtures("initialize_from_pytest_config")
29+
class CUDATestCase(TestCase):
2630
"""
27-
Mixin for tests that use FileCheck.
31+
For tests that use a CUDA device. Test methods in a CUDATestCase must not
32+
be run out of module order, because the ContextResettingTestCase may reset
33+
the context and destroy resources used by a normal CUDATestCase if any of
34+
its tests are run between tests from a CUDATestCase.
2835
2936
Methods assertFileCheckAsm and assertFileCheckLLVM will inspect a
3037
CUDADispatcher and assert that the compilation artifacts match the
@@ -34,56 +41,96 @@ class FileCheckTestCaseMixin:
3441
matches FileCheck checks, and is not specific to CUDADispatcher.
3542
"""
3643

44+
def setUp(self):
45+
self._low_occupancy_warnings = config.CUDA_LOW_OCCUPANCY_WARNINGS
46+
self._warn_on_implicit_copy = config.CUDA_WARN_ON_IMPLICIT_COPY
47+
48+
# Disable warnings about low gpu utilization in the test suite
49+
config.CUDA_LOW_OCCUPANCY_WARNINGS = 0
50+
# Disable warnings about host arrays in the test suite
51+
config.CUDA_WARN_ON_IMPLICIT_COPY = 0
52+
53+
def tearDown(self):
54+
config.CUDA_LOW_OCCUPANCY_WARNINGS = self._low_occupancy_warnings
55+
config.CUDA_WARN_ON_IMPLICIT_COPY = self._warn_on_implicit_copy
56+
57+
Signature = Union[tuple[type, ...], None]
58+
59+
def _getIRContents(
60+
self,
61+
ir_result: Union[dict[Signature, str], str],
62+
signature: Union[Signature, None] = None,
63+
) -> Iterable[str]:
64+
if isinstance(ir_result, str):
65+
assert signature is None, (
66+
"Cannot use signature because the kernel was only compiled for one signature"
67+
)
68+
return [ir_result]
69+
70+
if signature is None:
71+
return list(ir_result.values())
72+
73+
return [ir_result[signature]]
74+
3775
def assertFileCheckAsm(
3876
self,
3977
ir_producer: CUDADispatcher,
4078
signature: Union[tuple[type, ...], None] = None,
41-
check_prefixes: list[str] = ("ASM",),
42-
**extra_filecheck_options: dict[str, Union[str, int]],
79+
check_prefixes: tuple[str] = ("ASM",),
80+
**extra_filecheck_options,
4381
) -> None:
4482
"""
4583
Assert that the assembly output of the given CUDADispatcher matches
4684
the FileCheck checks given in the kernel's docstring.
4785
"""
48-
ir_content = ir_producer.inspect_asm()
49-
if signature:
50-
ir_content = ir_content[signature]
51-
check_patterns = ir_producer.__doc__
52-
self.assertFileCheckMatches(
53-
ir_content,
54-
check_patterns=check_patterns,
55-
check_prefixes=check_prefixes,
56-
**extra_filecheck_options,
86+
ir_contents = self._getIRContents(ir_producer.inspect_asm(), signature)
87+
assert ir_contents, "No assembly output found for the given signature."
88+
assert ir_producer.__doc__ is not None, (
89+
"Kernel docstring is required. To pass checks explicitly, use assertFileCheckMatches."
5790
)
91+
check_patterns = ir_producer.__doc__
92+
for ir_content in ir_contents:
93+
self.assertFileCheckMatches(
94+
ir_content,
95+
check_patterns=check_patterns,
96+
check_prefixes=check_prefixes,
97+
**extra_filecheck_options,
98+
)
5899

59100
def assertFileCheckLLVM(
60101
self,
61102
ir_producer: CUDADispatcher,
62103
signature: Union[tuple[type, ...], None] = None,
63-
check_prefixes: list[str] = ("LLVM",),
64-
**extra_filecheck_options: dict[str, Union[str, int]],
104+
check_prefixes: tuple[str] = ("LLVM",),
105+
**extra_filecheck_options,
65106
) -> None:
66107
"""
67108
Assert that the LLVM IR output of the given CUDADispatcher matches
68109
the FileCheck checks given in the kernel's docstring.
69110
"""
70-
ir_content = ir_producer.inspect_llvm()
71-
if signature:
72-
ir_content = ir_content[signature]
73-
check_patterns = ir_producer.__doc__
74-
self.assertFileCheckMatches(
75-
ir_content,
76-
check_patterns=check_patterns,
77-
check_prefixes=check_prefixes,
78-
**extra_filecheck_options,
111+
ir_contents = self._getIRContents(ir_producer.inspect_llvm(), signature)
112+
assert ir_contents, "No LLVM IR output found for the given signature."
113+
assert ir_producer.__doc__ is not None, (
114+
"Kernel docstring is required. To pass checks explicitly, use assertFileCheckMatches."
79115
)
116+
check_patterns = ir_producer.__doc__
117+
for ir_content in ir_contents:
118+
assert ir_content, (
119+
"LLVM IR content is empty for the given signature."
120+
)
121+
self.assertFileCheckMatches(
122+
ir_content,
123+
check_patterns=check_patterns,
124+
check_prefixes=check_prefixes,
125+
**extra_filecheck_options,
126+
)
80127

81128
def assertFileCheckMatches(
82129
self,
83130
ir_content: str,
84131
check_patterns: str,
85-
check_prefixes: list[str] = ("CHECK",),
86-
**extra_filecheck_options: dict[str, Union[str, int]],
132+
check_prefixes: tuple[str] = ("CHECK",),
133+
**extra_filecheck_options,
87134
) -> None:
88135
"""
89136
Assert that the given string matches the passed FileCheck checks.
@@ -98,7 +145,7 @@ def assertFileCheckMatches(
98145
self.skipTest("FileCheck requires Python 3.10 or later")
99146
opts = Options(
100147
match_filename="-",
101-
check_prefixes=check_prefixes,
148+
check_prefixes=list(check_prefixes),
102149
**extra_filecheck_options,
103150
)
104151
input_file = FInput(fname="-", content=ir_content)
@@ -107,39 +154,35 @@ def assertFileCheckMatches(
107154
matcher.stderr = StringIO()
108155
result = matcher.run()
109156
if result != 0:
157+
dump_instructions = ""
158+
if self._dump_failed_filechecks:
159+
dump_directory = Path(
160+
datetime.now().strftime("numba-ir-%Y_%m_%d_%H_%M_%S")
161+
)
162+
if not dump_directory.exists():
163+
dump_directory.mkdir(parents=True, exist_ok=True)
164+
base_path = self.id().replace(".", "_")
165+
ir_dump = dump_directory / Path(base_path).with_suffix(".ll")
166+
checks_dump = dump_directory / Path(base_path).with_suffix(
167+
".checks"
168+
)
169+
with (
170+
open(ir_dump, "w") as ir_file,
171+
open(checks_dump, "w") as checks_file,
172+
):
173+
_ = ir_file.write(ir_content + "\n")
174+
_ = checks_file.write(check_patterns)
175+
dump_instructions = f"Reproduce with:\n\nfilecheck --check-prefixes={','.join(check_prefixes)} {checks_dump} --input-file={ir_dump}"
176+
110177
self.fail(
111178
f"FileCheck failed:\n{matcher.stderr.getvalue()}\n\n"
112-
f"Check prefixes:\n{check_prefixes}\n\n"
113-
f"Check patterns:\n{check_patterns}\n"
114-
f"IR:\n{ir_content}\n\n"
179+
+ f"Check prefixes:\n{check_prefixes}\n\n"
180+
+ f"Check patterns:\n{check_patterns}\n"
181+
+ f"IR:\n{ir_content}\n\n"
182+
+ dump_instructions
115183
)
116184

117185

118-
class CUDATestCase(FileCheckTestCaseMixin, TestCase):
119-
"""
120-
For tests that use a CUDA device. Test methods in a CUDATestCase must not
121-
be run out of class order, because a ContextResettingTestCase may reset
122-
the context and destroy resources used by a normal CUDATestCase if any of
123-
its tests are run between tests from a CUDATestCase. Historically this was
124-
ensured with a SerialMixin for the Numba runtests-based test runner, but
125-
with pytest-xdist we must use `--dist loadscope` when running tests in
126-
parallel to ensure that tests from each test class are grouped together.
127-
"""
128-
129-
def setUp(self):
130-
self._low_occupancy_warnings = config.CUDA_LOW_OCCUPANCY_WARNINGS
131-
self._warn_on_implicit_copy = config.CUDA_WARN_ON_IMPLICIT_COPY
132-
133-
# Disable warnings about low gpu utilization in the test suite
134-
config.CUDA_LOW_OCCUPANCY_WARNINGS = 0
135-
# Disable warnings about host arrays in the test suite
136-
config.CUDA_WARN_ON_IMPLICIT_COPY = 0
137-
138-
def tearDown(self):
139-
config.CUDA_LOW_OCCUPANCY_WARNINGS = self._low_occupancy_warnings
140-
config.CUDA_WARN_ON_IMPLICIT_COPY = self._warn_on_implicit_copy
141-
142-
143186
class ContextResettingTestCase(CUDATestCase):
144187
"""
145188
For tests where the context needs to be reset after each test. Typically
@@ -231,8 +274,8 @@ def skip_if_mvc_enabled(reason):
231274
def skip_if_mvc_libraries_unavailable(fn):
232275
libs_available = False
233276
try:
234-
import cubinlinker # noqa: F401
235-
import ptxcompiler # noqa: F401
277+
import cubinlinker # noqa: F401 # type: ignore
278+
import ptxcompiler # noqa: F401 # type: ignore
236279

237280
libs_available = True
238281
except ImportError:

0 commit comments

Comments
 (0)