Skip to content

Commit f5c5552

Browse files
committed
Vendor in generators and removerefctpass for CUDA-specific changes
1 parent 8095d5e commit f5c5552

File tree

2 files changed

+510
-0
lines changed

2 files changed

+510
-0
lines changed
Lines changed: 387 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,387 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-2-Clause
3+
4+
"""
5+
Support for lowering generators.
6+
"""
7+
8+
import llvmlite.ir
9+
from llvmlite.ir import Constant, IRBuilder
10+
11+
from numba.cuda import types, config, cgutils
12+
from numba.cuda.core.funcdesc import FunctionDescriptor
13+
14+
15+
class GeneratorDescriptor(FunctionDescriptor):
    """
    Function descriptor for the ``next`` function of a generator.
    """

    __slots__ = ()

    @classmethod
    def from_generator_fndesc(cls, func_ir, fndesc, gentype, mangler):
        """
        Create the descriptor of the ``next`` function for the generator of
        type *gentype* returned by the function that *fndesc* describes.

        Most fields (module name, doc, typemap, calltypes, kws, ...) are
        copied from *fndesc*; the qualified and unique names get a ``.next``
        suffix, the single argument is the generator itself and the return
        type is the generator's yield type.

        The env_name is taken from *fndesc*, so every function emitted for
        the generator shares the same Env.

        *func_ir* is accepted but not used by this method.
        """
        assert isinstance(gentype, types.Generator)
        next_suffix = ".next"
        return cls(
            fndesc.native,
            fndesc.modname,
            fndesc.qualname + next_suffix,
            fndesc.unique_name + next_suffix,
            fndesc.doc,
            fndesc.typemap,
            # next() returns whatever the generator yields
            gentype.yield_type,
            fndesc.calltypes,
            # next() takes exactly one argument: the generator
            ["gen"],
            fndesc.kws,
            argtypes=(gentype,),
            mangler=mangler,
            inline=False,
            env_name=fndesc.env_name,
        )

    @property
    def llvm_finalizer_name(self):
        """
        LLVM-level name of the generator's finalizer function (relevant
        when <generator type>.has_finalizer is true).
        """
        prefix = "finalize_"
        return prefix + self.mangled_name
62+
63+
64+
class BaseGeneratorLower(object):
    """
    Shared base for the classes that lower generators.

    The generator structure manipulated here has three fields:
    #0 the resume index, #1 the packed arguments, #2 the saved state.

    This class calls ``get_generator_type``, ``box_generator_struct`` and
    ``lower_finalize_func_body`` but does not define them; subclasses are
    expected to provide them.
    """

    def __init__(self, lower):
        self.lower = lower
        self.context = lower.context
        self.fndesc = lower.fndesc
        self.library = lower.library
        self.func_ir = lower.func_ir
        self.geninfo = lower.generator_info

        # resume index -> LLVM basic block where execution continues
        self.resume_blocks = {}

        # Computed after the attributes above are set, since the
        # subclass hook may read them.
        self.gentype = self.get_generator_type()
        self.gendesc = GeneratorDescriptor.from_generator_fndesc(
            lower.func_ir, self.fndesc, self.gentype, self.context.mangler
        )
        # Used to pack the non-omitted arguments into a structure
        self.arg_packer = self.context.get_data_packer(self.fndesc.argtypes)

    @property
    def call_conv(self):
        """The calling convention of the wrapped lowering object."""
        return self.lower.call_conv

    def get_args_ptr(self, builder, genptr):
        """Return a pointer to field #1 (the arguments) of *genptr*."""
        args_ptr = cgutils.gep_inbounds(builder, genptr, 0, 1)
        return args_ptr

    def get_resume_index_ptr(self, builder, genptr):
        """Return a pointer to field #0 (the resume index) of *genptr*."""
        index_ptr = cgutils.gep_inbounds(
            builder, genptr, 0, 0, name="gen.resume_index"
        )
        return index_ptr

    def get_state_ptr(self, builder, genptr):
        """Return a pointer to field #2 (the saved state) of *genptr*."""
        state_ptr = cgutils.gep_inbounds(
            builder, genptr, 0, 2, name="gen.state"
        )
        return state_ptr

    def lower_init_func(self, lower):
        """
        Lower the initialization function of the generator, i.e. the
        function that fills in the generator structure and returns it.
        """
        context = self.context
        lower.setup_function(self.fndesc)
        builder = lower.builder

        # Register the generator with the target context so that other
        # Numba-compiled functions are able to call it.
        lower.context.insert_generator(
            self.gentype, self.gendesc, [self.library]
        )

        # Set up the argument values
        lower.extract_function_arguments()

        lower.pre_lower()

        # The return value is the generator structure itself
        struct_ty = context.get_return_type(self.gentype)
        # Field #0: initial resume index (0 means "start of generator")
        initial_index = context.get_constant(types.int32, 0)
        # Field #2: the states, zero-initialized
        zeroed_states = Constant(struct_ty.elements[2], None)

        lower.debug_print("# low_init_func incref")
        # Every NRT argument gets an incref before it is stored in the
        # generator structure
        if context.enable_nrt:
            for argty, argval in zip(self.fndesc.argtypes, lower.fnargs):
                context.nrt.incref(builder, argty, argval)

        # Field #1: the arguments, with omitted ones filtered out
        packed_args = self.arg_packer.as_data(builder, lower.fnargs)

        gen_struct = cgutils.make_anonymous_struct(
            builder, [initial_index, packed_args, zeroed_states], struct_ty
        )
        boxed = self.box_generator_struct(lower, gen_struct)

        lower.debug_print("# low_init_func before return")
        self.call_conv.return_value(builder, boxed)
        lower.post_lower()

    def lower_next_func(self, lower):
        """
        Lower the next() function of the generator: it receives the
        generator structure by reference and returns the next yielded
        value.
        """
        gendesc = self.gendesc
        lower.setup_function(gendesc)
        message = "# lower_next_func: {0}".format(gendesc.unique_name)
        lower.debug_print(message)
        assert gendesc.argtypes[0] == self.gentype
        builder = lower.builder
        fn = lower.function

        # The only argument is the generator struct; unpack the original
        # function arguments from it
        [genptr] = self.call_conv.get_arguments(fn)
        self.arg_packer.load_into(
            builder, self.get_args_ptr(builder, genptr), lower.fnargs
        )

        self.resume_index_ptr = self.get_resume_index_ptr(builder, genptr)
        self.gen_state_ptr = self.get_state_ptr(builder, genptr)

        prologue_block = fn.append_basic_block("generator_prologue")

        # Lower the Python code of the generator
        entry_tail = lower.lower_function_body()

        # Block reached when the resume index matches no resume point:
        # signal StopIteration
        stop_iteration_block = fn.append_basic_block("stop_iteration")
        builder.position_at_end(stop_iteration_block)
        self.call_conv.return_stop_iteration(builder)

        # The prologue dispatches to the resume blocks
        builder.position_at_end(prologue_block)
        # The first next() call resumes at the first Python block
        self.resume_blocks[0] = lower.blkmap[lower.firstblk]

        # Switch on the stored resume index; default is StopIteration
        current_index = builder.load(self.resume_index_ptr)
        dispatch = builder.switch(current_index, stop_iteration_block)
        for resume_index, target_block in self.resume_blocks.items():
            dispatch.add_case(resume_index, target_block)

        # Terminate the entry block by jumping to the prologue
        builder.position_at_end(entry_tail)
        builder.branch(prologue_block)

    def lower_finalize_func(self, lower):
        """
        Lower the finalizer of the generator: declare (or fetch) the
        function, open its entry block and delegate the body to
        ``lower_finalize_func_body``.
        """
        void_ty = llvmlite.ir.VoidType()
        arg_ty = self.context.get_value_type(self.gentype)
        fnty = llvmlite.ir.FunctionType(void_ty, [arg_ty])
        finalizer = cgutils.get_or_insert_function(
            lower.module, fnty, self.gendesc.llvm_finalizer_name
        )
        builder = IRBuilder(finalizer.append_basic_block("entry"))

        genptr = builder.bitcast(
            finalizer.args[0], self.context.get_value_type(self.gentype)
        )
        self.lower_finalize_func_body(builder, genptr)

    def return_from_generator(self, lower):
        """
        Mark the generator as exhausted (resume index -1) and emit a
        StopIteration at the end of the generator.
        """
        index_ptr = self.resume_index_ptr
        exhausted = Constant(index_ptr.type.pointee, -1)
        lower.builder.store(exhausted, index_ptr)
        self.call_conv.return_stop_iteration(lower.builder)

    def create_resumption_block(self, lower, index):
        """
        Append the basic block for resume point *index*, position the
        builder at its end and record it in ``resume_blocks``.
        """
        resume_block = lower.function.append_basic_block(
            "generator_resume%d" % (index,)
        )
        lower.builder.position_at_end(resume_block)
        self.resume_blocks[index] = resume_block

    def debug_print(self, builder, msg):
        """Emit a debug print of *msg*, only when config.DEBUG_JIT is set."""
        if not config.DEBUG_JIT:
            return
        self.context.debug_print(builder, "DEBUGJIT: {0}".format(msg))
230+
231+
232+
class GeneratorLower(BaseGeneratorLower):
    """
    Lowering support for nopython generators.
    """

    def get_generator_type(self):
        """The generator type is the declared return type of the function."""
        gentype = self.fndesc.restype
        return gentype

    def box_generator_struct(self, lower, gen_struct):
        """Return *gen_struct* unchanged; no boxing is performed here."""
        return gen_struct

    def lower_finalize_func_body(self, builder, genptr):
        """
        Lower the body of the generator's finalizer: when NRT is enabled,
        decref every argument stored in the generator structure, then
        return void.
        """
        self.debug_print(builder, "# generator: finalize")
        if self.context.enable_nrt:
            # The arguments are always decref'ed
            packed_args_ptr = self.get_args_ptr(builder, genptr)
            loaded_args = self.arg_packer.load(builder, packed_args_ptr)
            for argty, argval in loaded_args:
                self.context.nrt.decref(builder, argty, argval)

        self.debug_print(builder, "# generator: finalize end")
        builder.ret_void()
258+
259+
260+
class PyGeneratorLower(BaseGeneratorLower):
    """
    Lowering support for object mode generators.
    """

    def get_generator_type(self):
        """
        Build the actual generator type; the return type of the generator
        function is merely "pyobject", so it cannot be used directly.
        """
        pyobj = types.pyobject
        n_args = self.func_ir.arg_count
        n_states = len(self.geninfo.state_vars)
        return types.Generator(
            gen_func=self.func_ir.func_id.func,
            yield_type=pyobj,
            arg_types=(pyobj,) * n_args,
            state_types=(pyobj,) * n_states,
            has_finalizer=True,
        )

    def box_generator_struct(self, lower, gen_struct):
        """
        Turn the raw *gen_struct* into a Python object.
        """
        struct_ptr = cgutils.alloca_once_value(lower.builder, gen_struct)
        boxed = lower.pyapi.from_native_generator(
            struct_ptr, self.gentype, lower.envarg
        )
        return boxed

    def init_generator_state(self, lower):
        """
        Store NULL into every generator state variable so that cleanup
        does not perform spurious decref's.
        """
        state_ptr = self.gen_state_ptr
        null_state = Constant(state_ptr.type.pointee, None)
        lower.builder.store(null_state, state_ptr)

    def lower_finalize_func_body(self, builder, genptr):
        """
        Lower the body of the generator's finalizer: decref every state
        variable when the generator was suspended at a yield point, then
        return void.
        """
        pyapi = self.context.get_python_api(builder)
        index_ptr = self.get_resume_index_ptr(builder, genptr)
        resume_index = builder.load(index_ptr)
        # resume_index == 0: next() was never called
        # resume_index == -1: the generator terminated cleanly
        # (function arguments are saved in state variables, so they need
        # no separate cleanup step)
        zero = Constant(resume_index.type, 0)
        need_cleanup = builder.icmp_signed(">", resume_index, zero)

        with cgutils.if_unlikely(builder, need_cleanup):
            # Decref each state slot (some may hold NULL)
            states_ptr = self.get_state_ptr(builder, genptr)
            for slot_index, slot_ty in enumerate(self.gentype.state_types):
                slot_ptr = cgutils.gep_inbounds(
                    builder, states_ptr, 0, slot_index
                )
                obj = self.context.unpack_value(builder, slot_ty, slot_ptr)
                pyapi.decref(obj)

        builder.ret_void()
324+
325+
326+
class LowerYield(object):
    """
    Helper that lowers one specific yield point.
    """

    def __init__(self, lower, yield_point, live_vars):
        genlower = lower.genlower
        self.lower = lower
        self.context = lower.context
        self.builder = lower.builder
        self.genlower = genlower
        self.gentype = genlower.gentype

        self.gen_state_ptr = genlower.gen_state_ptr
        self.resume_index_ptr = genlower.resume_index_ptr
        self.yp = yield_point
        self.inst = yield_point.inst
        self.live_vars = live_vars
        # Position of each live variable within the generator's state vars
        state_vars = lower.generator_info.state_vars
        self.live_var_indices = [state_vars.index(var) for var in live_vars]

    def lower_yield_suspend(self):
        """
        Emit the suspend half of the yield: save every live variable into
        its state slot, then record this yield's resume index.
        """
        lower = self.lower
        builder = self.builder
        context = self.context

        lower.debug_print("# generator suspend")
        # Store the live variables into the generator state
        for slot_index, varname in zip(self.live_var_indices, self.live_vars):
            slot_ptr = cgutils.gep_inbounds(
                builder, self.gen_state_ptr, 0, slot_index
            )
            slot_ty = self.gentype.state_types[slot_index]
            # When the yield sits in a loop, the state may include a
            # predicate var branching back to the loop head. Such a var is
            # live, yet with sequential lowering it has not been alloca'd
            # at this point, so make sure it is allocated here.
            lower._alloca_var(varname, lower.typeof(varname))
            value = lower.loadvar(varname)
            # The value now also lives in the state: incref it
            if context.enable_nrt:
                context.nrt.incref(builder, slot_ty, value)

            context.pack_value(builder, slot_ty, value, slot_ptr)

        # Record where next() has to resume
        index_ptr = self.resume_index_ptr
        builder.store(
            Constant(index_ptr.type.pointee, self.inst.index), index_ptr
        )
        lower.debug_print("# generator suspend end")

    def lower_yield_resume(self):
        """
        Emit the resume half of the yield: open the resumption block and
        reload every live variable from its state slot.
        """
        lower = self.lower
        builder = self.builder
        context = self.context

        # Open the block that next() jumps to for this yield point
        self.genlower.create_resumption_block(lower, self.inst.index)
        lower.debug_print("# generator resume")
        # Restore the live variables from the generator state
        for slot_index, varname in zip(self.live_var_indices, self.live_vars):
            slot_ptr = cgutils.gep_inbounds(
                builder, self.gen_state_ptr, 0, slot_index
            )
            slot_ty = self.gentype.state_types[slot_index]
            value = context.unpack_value(builder, slot_ty, slot_ptr)
            lower.storevar(value, varname)
            # storevar above performs an extra incref; undo it
            if context.enable_nrt:
                context.nrt.decref(builder, slot_ty, value)
        lower.debug_print("# generator resume end")

0 commit comments

Comments
 (0)