Skip to content

Commit 0b95844

Browse files
committed
Adapted transformations to basic block changes
1 parent d28a326 commit 0b95844

File tree

1 file changed

+12
-41
lines changed

1 file changed

+12
-41
lines changed

numba_rvsdg/core/transformations.py

Lines changed: 12 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ def loop_restructure_helper(scfg: SCFG, loop: Set[str]):
5858
and len(exiting_blocks) == 1
5959
and backedge_blocks[0] == next(iter(exiting_blocks))
6060
):
61-
scfg.add_block(
62-
scfg.graph.pop(backedge_blocks[0]).declare_backedge(loop_head)
63-
)
61+
scfg.graph[backedge_blocks[0]].declare_backedge(loop_head)
6462
return
6563

6664
# The synthetic exiting latch and synthetic exit need to be created
@@ -147,7 +145,6 @@ def reverse_lookup(d, value):
147145
synth_assign_block = SyntheticAssignment(
148146
name=synth_assign,
149147
_jump_targets=(synth_exiting_latch,),
150-
backedges=(),
151148
variable_assignment=variable_assignment,
152149
)
153150
# Insert the assignment to the scfg
@@ -177,20 +174,17 @@ def reverse_lookup(d, value):
177174
# that point to the headers, no need to add a backedge,
178175
# since it will be contained in the SyntheticExitingLatch
179176
# later on.
180-
block = scfg.graph.pop(name)
177+
block = scfg.graph[name]
181178
jts = list(block.jump_targets)
182179
for h in headers:
183180
if h in jts:
184181
jts.remove(h)
185-
scfg.add_block(
186-
block.replace_jump_targets(jump_targets=tuple(jts))
187-
)
182+
block.change_jump_targets(jump_targets=tuple(jts))
188183
# Setup the assignment block and initialize it with the
189184
# correct jump_targets and variable assignment.
190185
synth_assign_block = SyntheticAssignment(
191186
name=synth_assign,
192187
_jump_targets=(synth_exiting_latch,),
193-
backedges=(),
194188
variable_assignment=variable_assignment,
195189
)
196190
# Add the new block to the SCFG
@@ -199,11 +193,7 @@ def reverse_lookup(d, value):
199193
new_jt[new_jt.index(jt)] = synth_assign
200194
# finally, replace the jump_targets for this block with the new
201195
# ones
202-
scfg.add_block(
203-
scfg.graph.pop(name).replace_jump_targets(
204-
jump_targets=tuple(new_jt)
205-
)
206-
)
196+
scfg.graph[name].change_jump_targets(jump_targets=tuple(new_jt))
207197
# Add any new blocks to the loop.
208198
loop.update(new_blocks)
209199

@@ -214,10 +204,10 @@ def reverse_lookup(d, value):
214204
synth_exit if needs_synth_exit else next(iter(exit_blocks)),
215205
loop_head,
216206
),
217-
backedges=(loop_head,),
218207
variable=backedge_variable,
219208
branch_value_table=backedge_value_table,
220209
)
210+
synth_exiting_latch_block.declare_backedge(loop_head)
221211
loop.add(synth_exiting_latch)
222212
scfg.add_block(synth_exiting_latch_block)
223213
# If an exit is to be created, we do so too, but only add it to the scfg,
@@ -226,7 +216,6 @@ def reverse_lookup(d, value):
226216
synth_exit_block = SyntheticExitBranch(
227217
name=synth_exit,
228218
_jump_targets=tuple(exit_blocks),
229-
backedges=(),
230219
variable=exit_variable,
231220
branch_value_table=exit_value_table,
232221
)
@@ -329,29 +318,18 @@ def update_exiting(
329318
):
330319
# Recursively updates the exiting blocks of a regionblock
331320
region_exiting = region_block.exiting
332-
region_exiting_block: BasicBlock = region_block.subregion.graph.pop(
321+
region_exiting_block: BasicBlock = region_block.subregion.graph[
333322
region_exiting
334-
)
323+
]
335324
jt = list(region_exiting_block._jump_targets)
336325
for idx, s in enumerate(jt):
337326
if s is new_region_header:
338327
jt[idx] = new_region_name
339-
region_exiting_block = region_exiting_block.replace_jump_targets(
340-
jump_targets=tuple(jt)
341-
)
342-
be = list(region_exiting_block.backedges)
343-
for idx, s in enumerate(be):
344-
if s is new_region_header:
345-
be[idx] = new_region_name
346-
region_exiting_block = region_exiting_block.replace_backedges(
347-
backedges=tuple(be)
348-
)
328+
region_exiting_block.change_jump_targets(jump_targets=tuple(jt))
349329
if isinstance(region_exiting_block, RegionBlock):
350330
region_exiting_block = update_exiting(
351331
region_exiting_block, new_region_header, new_region_name
352332
)
353-
region_block.subregion.add_block(region_exiting_block)
354-
return region_block
355333

356334

357335
def extract_region(
@@ -381,27 +359,20 @@ def extract_region(
381359
# the SCFG represents should not be the meta region.
382360
assert scfg.region.kind != "meta"
383361
continue
384-
entry = scfg.graph.pop(name)
362+
entry = scfg.graph[name]
385363
jt = list(entry._jump_targets)
386364
for idx, s in enumerate(jt):
387365
if s is region_header:
388366
jt[idx] = region_name
389-
entry = entry.replace_jump_targets(jump_targets=tuple(jt))
390-
be = list(entry.backedges)
391-
for idx, s in enumerate(be):
392-
if s is region_header:
393-
be[idx] = region_name
394-
entry = entry.replace_backedges(backedges=tuple(be))
367+
entry.change_jump_targets(jump_targets=tuple(jt))
395368
# If the entry itself is a region, update it's
396369
# exiting blocks too, recursively
397370
if isinstance(entry, RegionBlock):
398-
entry = update_exiting(entry, region_header, region_name)
399-
scfg.add_block(entry)
371+
update_exiting(entry, region_header, region_name)
400372

401373
region = RegionBlock(
402374
name=region_name,
403375
_jump_targets=scfg[region_exiting].jump_targets,
404-
backedges=(),
405376
kind=region_kind,
406377
header=region_header,
407378
subregion=head_subgraph,
@@ -426,7 +397,7 @@ def extract_region(
426397
# update the parent region
427398
for k, v in region.subregion.graph.items():
428399
if isinstance(v, RegionBlock):
429-
object.__setattr__(v, "parent_region", region)
400+
v.replace_parent(region)
430401

431402

432403
def restructure_branch(parent_region: RegionBlock):

0 commit comments

Comments
 (0)