Skip to content

Commit 92b9be2

Browse files
authored
Merge branch 'main' into rparolin/always_use_cuda_core
2 parents 6d886d6 + 7d0646d commit 92b9be2

File tree

7 files changed

+726
-32
lines changed

7 files changed

+726
-32
lines changed
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-2-Clause
3+
4+
"""
5+
This module implements code highlighting of numba-cuda function annotations.
6+
"""
7+
8+
from warnings import warn
9+
10+
warn(
11+
"The pretty_annotate functionality is experimental and might change API",
12+
FutureWarning,
13+
)
14+
15+
16+
def hllines(code, style):
17+
try:
18+
from pygments import highlight
19+
from pygments.lexers import PythonLexer
20+
from pygments.formatters import HtmlFormatter
21+
except ImportError:
22+
raise ImportError("please install the 'pygments' package")
23+
pylex = PythonLexer()
24+
"Given a code string, return a list of html-highlighted lines"
25+
hf = HtmlFormatter(noclasses=True, style=style, nowrap=True)
26+
res = highlight(code, pylex, hf)
27+
return res.splitlines()
28+
29+
30+
def htlines(code, style):
31+
try:
32+
from pygments import highlight
33+
from pygments.lexers import PythonLexer
34+
35+
# TerminalFormatter does not support themes, Terminal256 should,
36+
# but seem to not work.
37+
from pygments.formatters import TerminalFormatter
38+
except ImportError:
39+
raise ImportError("please install the 'pygments' package")
40+
pylex = PythonLexer()
41+
"Given a code string, return a list of ANSI-highlighted lines"
42+
hf = TerminalFormatter(style=style)
43+
res = highlight(code, pylex, hf)
44+
return res.splitlines()
45+
46+
47+
def get_ansi_template():
48+
try:
49+
from jinja2 import Template
50+
except ImportError:
51+
raise ImportError("please install the 'jinja2' package")
52+
return Template("""
53+
{%- for func_key in func_data.keys() -%}
54+
Function name: \x1b[34m{{func_data[func_key]['funcname']}}\x1b[39;49;00m
55+
{%- if func_data[func_key]['filename'] -%}
56+
{{'\n'}}In file: \x1b[34m{{func_data[func_key]['filename'] -}}\x1b[39;49;00m
57+
{%- endif -%}
58+
{{'\n'}}With signature: \x1b[34m{{func_key[1]}}\x1b[39;49;00m
59+
{{- "\n" -}}
60+
{%- for num, line, hl, hc in func_data[func_key]['pygments_lines'] -%}
61+
{{-'\n'}}{{ num}}: {{hc-}}
62+
{%- if func_data[func_key]['ir_lines'][num] -%}
63+
{%- for ir_line, ir_line_type in func_data[func_key]['ir_lines'][num] %}
64+
{{-'\n'}}--{{- ' '*func_data[func_key]['python_indent'][num]}}
65+
{{- ' '*(func_data[func_key]['ir_indent'][num][loop.index0]+4)
66+
}}{{ir_line }}\x1b[41m{{ir_line_type-}}\x1b[39;49;00m
67+
{%- endfor -%}
68+
{%- endif -%}
69+
{%- endfor -%}
70+
{%- endfor -%}
71+
""")
72+
73+
74+
def get_html_template():
75+
try:
76+
from jinja2 import Template
77+
except ImportError:
78+
raise ImportError("please install the 'jinja2' package")
79+
return Template("""
80+
<html>
81+
<head>
82+
<style>
83+
84+
.annotation_table {
85+
color: #000000;
86+
font-family: monospace;
87+
margin: 5px;
88+
width: 100%;
89+
}
90+
91+
/* override JupyterLab style */
92+
.annotation_table td {
93+
text-align: left;
94+
background-color: transparent;
95+
padding: 1px;
96+
}
97+
98+
.annotation_table tbody tr:nth-child(even) {
99+
background: white;
100+
}
101+
102+
.annotation_table code
103+
{
104+
background-color: transparent;
105+
white-space: normal;
106+
}
107+
108+
/* End override JupyterLab style */
109+
110+
tr:hover {
111+
background-color: rgba(92, 200, 249, 0.25);
112+
}
113+
114+
td.object_tag summary ,
115+
td.lifted_tag summary{
116+
font-weight: bold;
117+
display: list-item;
118+
}
119+
120+
span.lifted_tag {
121+
color: #00cc33;
122+
}
123+
124+
span.object_tag {
125+
color: #cc3300;
126+
}
127+
128+
129+
td.lifted_tag {
130+
background-color: #cdf7d8;
131+
}
132+
133+
td.object_tag {
134+
background-color: #fef5c8;
135+
}
136+
137+
code.ir_code {
138+
color: grey;
139+
font-style: italic;
140+
}
141+
142+
.metadata {
143+
border-bottom: medium solid black;
144+
display: inline-block;
145+
padding: 5px;
146+
width: 100%;
147+
}
148+
149+
.annotations {
150+
padding: 5px;
151+
}
152+
153+
.hidden {
154+
display: none;
155+
}
156+
157+
.buttons {
158+
padding: 10px;
159+
cursor: pointer;
160+
}
161+
</style>
162+
</head>
163+
164+
<body>
165+
{% for func_key in func_data.keys() %}
166+
<div class="metadata">
167+
Function name: {{func_data[func_key]['funcname']}}<br />
168+
{% if func_data[func_key]['filename'] %}
169+
in file: {{func_data[func_key]['filename']|escape}}<br />
170+
{% endif %}
171+
with signature: {{func_key[1]|e}}
172+
</div>
173+
<div class="annotations">
174+
<table class="annotation_table tex2jax_ignore">
175+
{%- for num, line, hl, hc in func_data[func_key]['pygments_lines'] -%}
176+
{%- if func_data[func_key]['ir_lines'][num] %}
177+
<tr><td style="text-align:left;" class="{{func_data[func_key]['python_tags'][num]}}">
178+
<details>
179+
<summary>
180+
<code>
181+
{{num}}:
182+
{{'&nbsp;'*func_data[func_key]['python_indent'][num]}}{{hl}}
183+
</code>
184+
</summary>
185+
<table class="annotation_table">
186+
<tbody>
187+
{%- for ir_line, ir_line_type in func_data[func_key]['ir_lines'][num] %}
188+
<tr class="ir_code">
189+
<td style="text-align: left;"><code>
190+
&nbsp;
191+
{{- '&nbsp;'*func_data[func_key]['python_indent'][num]}}
192+
{{ '&nbsp;'*func_data[func_key]['ir_indent'][num][loop.index0]}}{{ir_line|e -}}
193+
<span class="object_tag">{{ir_line_type}}</span>
194+
</code>
195+
</td>
196+
</tr>
197+
{%- endfor -%}
198+
</tbody>
199+
</table>
200+
</details>
201+
</td></tr>
202+
{% else -%}
203+
<tr><td style="text-align:left; padding-left: 22px;" class="{{func_data[func_key]['python_tags'][num]}}">
204+
<code>
205+
{{num}}:
206+
{{'&nbsp;'*func_data[func_key]['python_indent'][num]}}{{hl}}
207+
</code>
208+
</td></tr>
209+
{%- endif -%}
210+
{%- endfor -%}
211+
</table>
212+
</div>
213+
{% endfor %}
214+
</body>
215+
</html>
216+
""")
217+
218+
219+
def reform_code(annotation):
220+
"""
221+
Extract the code from the Numba-cuda annotation datastructure.
222+
223+
Pygments can only highlight full multi-line strings, the Numba-cuda
224+
annotation is list of single lines, with indentation removed.
225+
"""
226+
ident_dict = annotation["python_indent"]
227+
s = ""
228+
for n, l in annotation["python_lines"]:
229+
s = s + " " * ident_dict[n] + l + "\n"
230+
return s
231+
232+
233+
class Annotate:
234+
"""
235+
Construct syntax highlighted annotation for a given jitted function:
236+
237+
Example:
238+
239+
>>> from numba import cuda
240+
>>> import numpy as np
241+
>>> from numba.cuda.core.annotations.pretty_annotate import Annotate
242+
>>> @cuda.jit
243+
... def test(a):
244+
... tid = cuda.grid(1)
245+
... size = len(a)
246+
... if tid < size:
247+
... a[tid] = 1
248+
>>> test[(4), (16)](np.ones(100))
249+
>>> Annotate(test)
250+
251+
The last line will return an HTML and/or ANSI representation that will be
252+
displayed accordingly in Jupyter/IPython.
253+
254+
Function annotations persist across compilation for newly encountered
255+
type signatures and as a result annotations are shown for all signatures
256+
by default.
257+
258+
Annotations for a specific signature can be shown by using the
259+
``signature`` parameter. For the above jitted function:
260+
261+
>>> test.signatures
262+
[(Array(float64, 1, 'C', False, aligned=True),)]
263+
>>> Annotate(f, signature=f.signatures[0])
264+
# annotation for Array(float64, 1, 'C', False, aligned=True)
265+
"""
266+
267+
def __init__(self, function, signature=None, **kwargs):
268+
style = kwargs.get("style", "default")
269+
if not function.signatures:
270+
raise ValueError(
271+
"function need to be jitted for at least one signature"
272+
)
273+
ann = function.get_annotation_info(signature=signature)
274+
self.ann = ann
275+
276+
for k, v in ann.items():
277+
res = hllines(reform_code(v), style)
278+
rest = htlines(reform_code(v), style)
279+
v["pygments_lines"] = [
280+
(a, b, c, d)
281+
for (a, b), c, d in zip(v["python_lines"], res, rest)
282+
]
283+
284+
def _repr_html_(self):
285+
return get_html_template().render(func_data=self.ann)
286+
287+
def __repr__(self):
288+
return get_ansi_template().render(func_data=self.ann)

numba_cuda/numba/cuda/core/annotations/type_annotations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import textwrap
1212
from io import StringIO
1313

14-
import numba.core.dispatcher
1514
from numba.core import ir
1615

1716

@@ -83,6 +82,8 @@ def __init__(
8382
self.lifted_from = lifted_from
8483

8584
def prepare_annotations(self):
85+
from numba.cuda.dispatcher import LiftedLoop
86+
8687
# Prepare annotations
8788
groupedinst = defaultdict(list)
8889
found_lifted_loop = False
@@ -103,7 +104,7 @@ def prepare_annotations(self):
103104
):
104105
atype = self.calltypes[inst.value]
105106
elif isinstance(inst.value, ir.Const) and isinstance(
106-
inst.value.value, numba.core.dispatcher.LiftedLoop
107+
inst.value.value, LiftedLoop
107108
):
108109
atype = "XXX Lifted Loop XXX"
109110
found_lifted_loop = True
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-2-Clause
3+
4+
import logging
5+
import warnings
6+
7+
from importlib import metadata as importlib_metadata
8+
9+
10+
_already_initialized = False
11+
logger = logging.getLogger(__name__)
12+
13+
14+
def init_all():
15+
"""Execute all `numba_cuda_extensions` entry points with the name `init`
16+
17+
If extensions have already been initialized, this function does nothing.
18+
"""
19+
try:
20+
from numba.core import entrypoints
21+
22+
entrypoints.init_all()
23+
except ImportError:
24+
pass
25+
26+
global _already_initialized
27+
if _already_initialized:
28+
return
29+
30+
# Must put this here to avoid extensions re-triggering initialization
31+
_already_initialized = True
32+
33+
def load_ep(entry_point):
34+
"""Loads a given entry point. Warns and logs on failure."""
35+
logger.debug("Loading extension: %s", entry_point)
36+
try:
37+
func = entry_point.load()
38+
func()
39+
except Exception as e:
40+
msg = (
41+
f"Numba extension module '{entry_point.module}' "
42+
f"failed to load due to '{type(e).__name__}({str(e)})'."
43+
)
44+
warnings.warn(msg, stacklevel=3)
45+
logger.debug("Extension loading failed for: %s", entry_point)
46+
47+
eps = importlib_metadata.entry_points()
48+
# Split, Python 3.10+ and importlib_metadata 3.6+ have the "selectable"
49+
# interface, versions prior to that do not. See "compatibility note" in:
50+
# https://docs.python.org/3.10/library/importlib.metadata.html#entry-points
51+
if hasattr(eps, "select"):
52+
for entry_point in eps.select(
53+
group="numba_cuda_extensions", name="init"
54+
):
55+
load_ep(entry_point)
56+
else:
57+
for entry_point in eps.get("numba_cuda_extensions", ()):
58+
if entry_point.name == "init":
59+
load_ep(entry_point)

numba_cuda/numba/cuda/core/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _loop_lift_modify_blocks(
193193
Modify the block inplace to call to the lifted-loop.
194194
Returns a dictionary of blocks of the lifted-loop.
195195
"""
196-
from numba.core.dispatcher import LiftedLoop
196+
from numba.cuda.dispatcher import LiftedLoop
197197

198198
# Copy loop blocks
199199
loop = loopinfo.loop
@@ -402,7 +402,7 @@ def with_lifting(func_ir, typingctx, targetctx, flags, locals):
402402
from numba.cuda.core import postproc
403403

404404
def dispatcher_factory(func_ir, objectmode=False, **kwargs):
405-
from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith
405+
from numba.cuda.dispatcher import LiftedWith, ObjModeLiftedWith
406406

407407
myflags = flags.copy()
408408
if objectmode:

0 commit comments

Comments
 (0)