11import os
22import platform
33import shutil
4-
4+ import pytest
5+ from datetime import datetime
56from numba .core .utils import PYVERSION
67from numba .cuda .cuda_paths import get_conda_ctk
78from numba .cuda .cudadrv import driver , devices , libs
89from numba .cuda .dispatcher import CUDADispatcher
910from numba .core import config
1011from numba .tests .support import TestCase
1112from pathlib import Path
12- from typing import Union
13+
14+ from typing import Iterable , Union
1315from io import StringIO
1416import unittest
1517
1618if PYVERSION >= (3 , 10 ):
17- from filecheck .matcher import Matcher , Options
19+ from filecheck .matcher import Matcher
20+ from filecheck .options import Options
1821 from filecheck .parser import Parser , pattern_for_opts
1922 from filecheck .finput import FInput
2023
2124numba_cuda_dir = Path (__file__ ).parent
2225test_data_dir = numba_cuda_dir / "tests" / "data"
2326
2427
25- class FileCheckTestCaseMixin :
28+ @pytest .mark .usefixtures ("initialize_from_pytest_config" )
29+ class CUDATestCase (TestCase ):
2630 """
27- Mixin for tests that use FileCheck.
31+ For tests that use a CUDA device. Test methods in a CUDATestCase must not
32+ be run out of module order, because the ContextResettingTestCase may reset
33+ the context and destroy resources used by a normal CUDATestCase if any of
34+ its tests are run between tests from a CUDATestCase.
2835
2936 Methods assertFileCheckAsm and assertFileCheckLLVM will inspect a
3037 CUDADispatcher and assert that the compilation artifacts match the
@@ -34,56 +41,96 @@ class FileCheckTestCaseMixin:
3441 matches FileCheck checks, and is not specific to CUDADispatcher.
3542 """
3643
44+ def setUp (self ):
45+ self ._low_occupancy_warnings = config .CUDA_LOW_OCCUPANCY_WARNINGS
46+ self ._warn_on_implicit_copy = config .CUDA_WARN_ON_IMPLICIT_COPY
47+
48+ # Disable warnings about low gpu utilization in the test suite
49+ config .CUDA_LOW_OCCUPANCY_WARNINGS = 0
50+ # Disable warnings about host arrays in the test suite
51+ config .CUDA_WARN_ON_IMPLICIT_COPY = 0
52+
53+ def tearDown (self ):
54+ config .CUDA_LOW_OCCUPANCY_WARNINGS = self ._low_occupancy_warnings
55+ config .CUDA_WARN_ON_IMPLICIT_COPY = self ._warn_on_implicit_copy
56+
57+ Signature = Union [tuple [type , ...], None ]
58+
59+ def _getIRContents (
60+ self ,
61+ ir_result : Union [dict [Signature , str ], str ],
62+ signature : Union [Signature , None ] = None ,
63+ ) -> Iterable [str ]:
64+ if isinstance (ir_result , str ):
65+ assert signature is None , (
66+ "Cannot use signature because the kernel was only compiled for one signature"
67+ )
68+ return [ir_result ]
69+
70+ if signature is None :
71+ return list (ir_result .values ())
72+
73+ return [ir_result [signature ]]
74+
3775 def assertFileCheckAsm (
3876 self ,
3977 ir_producer : CUDADispatcher ,
4078 signature : Union [tuple [type , ...], None ] = None ,
41- check_prefixes : list [str ] = ("ASM" ,),
42- ** extra_filecheck_options : dict [ str , Union [ str , int ]] ,
79+ check_prefixes : tuple [str ] = ("ASM" ,),
80+ ** extra_filecheck_options ,
4381 ) -> None :
4482 """
4583 Assert that the assembly output of the given CUDADispatcher matches
4684 the FileCheck checks given in the kernel's docstring.
4785 """
48- ir_content = ir_producer .inspect_asm ()
49- if signature :
50- ir_content = ir_content [signature ]
51- check_patterns = ir_producer .__doc__
52- self .assertFileCheckMatches (
53- ir_content ,
54- check_patterns = check_patterns ,
55- check_prefixes = check_prefixes ,
56- ** extra_filecheck_options ,
86+ ir_contents = self ._getIRContents (ir_producer .inspect_asm (), signature )
87+ assert ir_contents , "No assembly output found for the given signature."
88+ assert ir_producer .__doc__ is not None , (
89+ "Kernel docstring is required. To pass checks explicitly, use assertFileCheckMatches."
5790 )
91+ check_patterns = ir_producer .__doc__
92+ for ir_content in ir_contents :
93+ self .assertFileCheckMatches (
94+ ir_content ,
95+ check_patterns = check_patterns ,
96+ check_prefixes = check_prefixes ,
97+ ** extra_filecheck_options ,
98+ )
5899
59100 def assertFileCheckLLVM (
60101 self ,
61102 ir_producer : CUDADispatcher ,
62103 signature : Union [tuple [type , ...], None ] = None ,
63- check_prefixes : list [str ] = ("LLVM" ,),
64- ** extra_filecheck_options : dict [ str , Union [ str , int ]] ,
104+ check_prefixes : tuple [str ] = ("LLVM" ,),
105+ ** extra_filecheck_options ,
65106 ) -> None :
66107 """
67108 Assert that the LLVM IR output of the given CUDADispatcher matches
68109 the FileCheck checks given in the kernel's docstring.
69110 """
70- ir_content = ir_producer .inspect_llvm ()
71- if signature :
72- ir_content = ir_content [signature ]
73- check_patterns = ir_producer .__doc__
74- self .assertFileCheckMatches (
75- ir_content ,
76- check_patterns = check_patterns ,
77- check_prefixes = check_prefixes ,
78- ** extra_filecheck_options ,
111+ ir_contents = self ._getIRContents (ir_producer .inspect_llvm (), signature )
112+ assert ir_contents , "No LLVM IR output found for the given signature."
113+ assert ir_producer .__doc__ is not None , (
114+ "Kernel docstring is required. To pass checks explicitly, use assertFileCheckMatches."
79115 )
116+ check_patterns = ir_producer .__doc__
117+ for ir_content in ir_contents :
118+ assert ir_content , (
119+ "LLVM IR content is empty for the given signature."
120+ )
121+ self .assertFileCheckMatches (
122+ ir_content ,
123+ check_patterns = check_patterns ,
124+ check_prefixes = check_prefixes ,
125+ ** extra_filecheck_options ,
126+ )
80127
81128 def assertFileCheckMatches (
82129 self ,
83130 ir_content : str ,
84131 check_patterns : str ,
85- check_prefixes : list [str ] = ("CHECK" ,),
86- ** extra_filecheck_options : dict [ str , Union [ str , int ]] ,
132+ check_prefixes : tuple [str ] = ("CHECK" ,),
133+ ** extra_filecheck_options ,
87134 ) -> None :
88135 """
89136 Assert that the given string matches the passed FileCheck checks.
@@ -98,7 +145,7 @@ def assertFileCheckMatches(
98145 self .skipTest ("FileCheck requires Python 3.10 or later" )
99146 opts = Options (
100147 match_filename = "-" ,
101- check_prefixes = check_prefixes ,
148+ check_prefixes = list ( check_prefixes ) ,
102149 ** extra_filecheck_options ,
103150 )
104151 input_file = FInput (fname = "-" , content = ir_content )
@@ -107,39 +154,35 @@ def assertFileCheckMatches(
107154 matcher .stderr = StringIO ()
108155 result = matcher .run ()
109156 if result != 0 :
157+ dump_instructions = ""
158+ if self ._dump_failed_filechecks :
159+ dump_directory = Path (
160+ datetime .now ().strftime ("numba-ir-%Y_%m_%d_%H_%M_%S" )
161+ )
162+ if not dump_directory .exists ():
163+ dump_directory .mkdir (parents = True , exist_ok = True )
164+ base_path = self .id ().replace ("." , "_" )
165+ ir_dump = dump_directory / Path (base_path ).with_suffix (".ll" )
166+ checks_dump = dump_directory / Path (base_path ).with_suffix (
167+ ".checks"
168+ )
169+ with (
170+ open (ir_dump , "w" ) as ir_file ,
171+ open (checks_dump , "w" ) as checks_file ,
172+ ):
173+ _ = ir_file .write (ir_content + "\n " )
174+ _ = checks_file .write (check_patterns )
175+ dump_instructions = f"Reproduce with:\n \n filecheck --check-prefixes={ ',' .join (check_prefixes )} { checks_dump } --input-file={ ir_dump } "
176+
110177 self .fail (
111178 f"FileCheck failed:\n { matcher .stderr .getvalue ()} \n \n "
112- f"Check prefixes:\n { check_prefixes } \n \n "
113- f"Check patterns:\n { check_patterns } \n "
114- f"IR:\n { ir_content } \n \n "
179+ + f"Check prefixes:\n { check_prefixes } \n \n "
180+ + f"Check patterns:\n { check_patterns } \n "
181+ + f"IR:\n { ir_content } \n \n "
182+ + dump_instructions
115183 )
116184
117185
118- class CUDATestCase (FileCheckTestCaseMixin , TestCase ):
119- """
120- For tests that use a CUDA device. Test methods in a CUDATestCase must not
121- be run out of class order, because a ContextResettingTestCase may reset
122- the context and destroy resources used by a normal CUDATestCase if any of
123- its tests are run between tests from a CUDATestCase. Historically this was
124- ensured with a SerialMixin for the Numba runtests-based test runner, but
125- with pytest-xdist we must use `--dist loadscope` when running tests in
126- parallel to ensure that tests from each test class are grouped together.
127- """
128-
129- def setUp (self ):
130- self ._low_occupancy_warnings = config .CUDA_LOW_OCCUPANCY_WARNINGS
131- self ._warn_on_implicit_copy = config .CUDA_WARN_ON_IMPLICIT_COPY
132-
133- # Disable warnings about low gpu utilization in the test suite
134- config .CUDA_LOW_OCCUPANCY_WARNINGS = 0
135- # Disable warnings about host arrays in the test suite
136- config .CUDA_WARN_ON_IMPLICIT_COPY = 0
137-
138- def tearDown (self ):
139- config .CUDA_LOW_OCCUPANCY_WARNINGS = self ._low_occupancy_warnings
140- config .CUDA_WARN_ON_IMPLICIT_COPY = self ._warn_on_implicit_copy
141-
142-
143186class ContextResettingTestCase (CUDATestCase ):
144187 """
145188 For tests where the context needs to be reset after each test. Typically
@@ -231,8 +274,8 @@ def skip_if_mvc_enabled(reason):
231274def skip_if_mvc_libraries_unavailable (fn ):
232275 libs_available = False
233276 try :
234- import cubinlinker # noqa: F401
235- import ptxcompiler # noqa: F401
277+ import cubinlinker # noqa: F401 # type: ignore
278+ import ptxcompiler # noqa: F401 # type: ignore
236279
237280 libs_available = True
238281 except ImportError :
0 commit comments