Skip to content

Commit a1f1316

Browse files
committed
Refactored core datastructures to be compatible with the changes
1 parent 6d35210 commit a1f1316

File tree

5 files changed

+70
-125
lines changed

5 files changed

+70
-125
lines changed

numba_rvsdg/core/datastructures/byte_flow.py

Lines changed: 19 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import dis
2-
from copy import deepcopy
32
from dataclasses import dataclass
43

54
from numba_rvsdg.core.datastructures.scfg import SCFG
@@ -63,87 +62,48 @@ def from_bytecode(code) -> "ByteFlow":
6362
def _join_returns(self):
6463
"""Joins the return blocks within the corresponding SCFG.
6564
66-
This method creates a deep copy of the SCFG and performs
67-
operation to join return blocks within the control flow.
68-
It returns a new ByteFlow object with the updated SCFG.
69-
70-
Returns
71-
-------
72-
byteflow: ByteFlow
73-
The new ByteFlow object with updated SCFG.
65+
This method performs operation to join return blocks within
66+
the control flow.
7467
"""
75-
scfg = deepcopy(self.scfg)
76-
scfg.join_returns()
77-
return ByteFlow(bc=self.bc, scfg=scfg)
68+
self.scfg.join_returns()
7869

7970
def _restructure_loop(self):
8071
"""Restructures the loops within the corresponding SCFG.
8172
82-
Creates a deep copy of the SCFG and performs the operation to
83-
restructure loop constructs within the control flow using
84-
the algorithm LOOP RESTRUCTURING from section 4.1 of Bahmann2015.
73+
Performs the operation to restructure loop constructs within
74+
the control flow using the algorithm LOOP RESTRUCTURING from
75+
section 4.1 of Bahmann2015.
8576
It applies the restructuring operation to both the main SCFG
86-
and any subregions within it. It returns a new ByteFlow object
87-
with the updated SCFG.
88-
89-
Returns
90-
-------
91-
byteflow: ByteFlow
92-
The new ByteFlow object with updated SCFG.
77+
and any subregions within it.
9378
"""
94-
scfg = deepcopy(self.scfg)
95-
restructure_loop(scfg.region)
96-
for region in _iter_subregions(scfg):
79+
restructure_loop(self.scfg.region)
80+
for region in _iter_subregions(self.scfg):
9781
restructure_loop(region)
98-
return ByteFlow(bc=self.bc, scfg=scfg)
9982

10083
def _restructure_branch(self):
10184
"""Restructures the branches within the corresponding SCFG.
10285
103-
Creates a deep copy of the SCFG and performs the operation to
104-
restructure branch constructs within the control flow. It applies
105-
the restructuring operation to both the main SCFG and any
106-
subregions within it. It returns a new ByteFlow object with
107-
the updated SCFG.
108-
109-
Returns
110-
-------
111-
byteflow: ByteFlow
112-
The new ByteFlow object with updated SCFG.
86+
This method applies restructuring branch operation to both
87+
the main SCFG and any subregions within it.
11388
"""
114-
scfg = deepcopy(self.scfg)
115-
restructure_branch(scfg.region)
116-
for region in _iter_subregions(scfg):
89+
restructure_branch(self.scfg.region)
90+
for region in _iter_subregions(self.scfg):
11791
restructure_branch(region)
118-
return ByteFlow(bc=self.bc, scfg=scfg)
11992

12093
def restructure(self):
12194
"""Applies join_returns, restructure_loop and restructure_branch
12295
in the respective order on the SCFG.
12396
124-
Creates a deep copy of the SCFG and applies a series of
125-
restructuring operations to it. The operations include
126-
joining return blocks, restructuring loop constructs, and
127-
restructuring branch constructs. It returns a new ByteFlow
128-
object with the updated SCFG.
129-
130-
Returns
131-
-------
132-
byteflow: ByteFlow
133-
The new ByteFlow object with updated SCFG.
97+
Applies a series of restructuring operations to given SCFG.
98+
The operations include joining return blocks, restructuring
99+
loop constructs, and restructuring branch constructs.
134100
"""
135-
scfg = deepcopy(self.scfg)
136101
# close
137-
scfg.join_returns()
102+
self._join_returns()
138103
# handle loop
139-
restructure_loop(scfg.region)
140-
for region in _iter_subregions(scfg):
141-
restructure_loop(region)
104+
self._restructure_loop()
142105
# handle branch
143-
restructure_branch(scfg.region)
144-
for region in _iter_subregions(scfg):
145-
restructure_branch(region)
146-
return ByteFlow(bc=self.bc, scfg=scfg)
106+
self._restructure_branch()
147107

148108

149109
def _iter_subregions(scfg: "SCFG"):

numba_rvsdg/core/datastructures/flow_info.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,15 @@ def build_basicblocks(self: "FlowInfo", end_offset=None) -> "SCFG":
128128
term_offset = _prev_inst_offset(end)
129129
if term_offset not in self.jump_insts:
130130
# implicit jump
131-
targets = (names[end],)
131+
targets = [names[end]]
132132
else:
133-
targets = tuple(names[o] for o in self.jump_insts[term_offset])
133+
targets = [names[o] for o in self.jump_insts[term_offset]]
134134
block = PythonBytecodeBlock(
135135
name=name,
136136
begin=begin,
137137
end=end,
138138
_jump_targets=targets,
139-
backedges=(),
139+
backedges=[],
140140
)
141141
scfg.add_block(block)
142142
return scfg

numba_rvsdg/core/datastructures/scfg.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -506,14 +506,14 @@ def insert_block(
506506
# TODO: needs a diagram and documentaion
507507
# initialize new block
508508
new_block = block_type(
509-
name=new_name, _jump_targets=successors, backedges=set()
509+
name=new_name, _jump_targets=list(successors), backedges=[]
510510
)
511511
# add block to self
512512
self.add_block(new_block)
513513
# Replace any arcs from any of predecessors to any of successors with
514514
# an arc through the inserted block instead.
515515
for name in predecessors:
516-
block = self.graph.pop(name)
516+
block = self.graph[name]
517517
jt = list(block.jump_targets)
518518
if successors:
519519
for s in successors:
@@ -524,7 +524,7 @@ def insert_block(
524524
jt.pop(jt.index(s))
525525
else:
526526
jt.append(new_name)
527-
self.add_block(block.replace_jump_targets(jump_targets=tuple(jt)))
527+
block.replace_jump_targets(jump_targets=jt)
528528

529529
def insert_SyntheticExit(
530530
self,
@@ -618,8 +618,8 @@ def insert_block_and_control_blocks(
618618
variable_assignment[branch_variable] = branch_variable_value
619619
synth_assign_block = SyntheticAssignment(
620620
name=synth_assign,
621-
_jump_targets=(new_name,),
622-
backedges=(),
621+
_jump_targets=[new_name],
622+
backedges=[],
623623
variable_assignment=variable_assignment,
624624
)
625625
# add block
@@ -631,16 +631,12 @@ def insert_block_and_control_blocks(
631631
# replace previous successor with synth_assign
632632
jt[jt.index(s)] = synth_assign
633633
# finally, replace the jump_targets
634-
self.add_block(
635-
self.graph.pop(name).replace_jump_targets(
636-
jump_targets=tuple(jt)
637-
)
638-
)
634+
self.graph[name].replace_jump_targets(jump_targets=jt)
639635
# initialize new block, which will hold the branching table
640636
new_block = SyntheticHead(
641637
name=new_name,
642-
_jump_targets=tuple(successors),
643-
backedges=set(),
638+
_jump_targets=list(successors),
639+
backedges=[],
644640
variable=branch_variable,
645641
branch_value_table=branch_value_table,
646642
)
@@ -1084,18 +1080,22 @@ def reverse_lookup(value: type):
10841080
q = set()
10851081
# Order of elements doesn't matter since they're going to
10861082
# be sorted at the end.
1087-
q.update(scfg.graph.items())
1083+
graph_dict = {}
1084+
q.update(scfg.graph.keys())
1085+
graph_dict.update(scfg.graph)
10881086

10891087
while q:
1090-
key, value = q.pop()
1088+
key = q.pop()
1089+
value = scfg.graph[key]
10911090
if key in seen:
10921091
continue
10931092
seen.add(key)
10941093

10951094
block_type = reverse_lookup(type(value))
10961095
blocks[key] = {"type": block_type}
10971096
if isinstance(value, RegionBlock):
1098-
q.update(value.subregion.graph.items())
1097+
q.update(value.subregion.graph.keys())
1098+
graph_dict.update(value.subregion.graph)
10991099
blocks[key]["kind"] = value.kind
11001100
blocks[key]["contains"] = sorted(
11011101
[idx.name for idx in value.subregion.graph.values()]
@@ -1177,14 +1177,14 @@ def extract_block_info(
11771177
List of backedges of the requested block.
11781178
"""
11791179
block_info = blocks[current_name].copy()
1180-
block_edges = tuple(block_ref_dict[idx] for idx in edges[current_name])
1180+
block_edges = [block_ref_dict[idx] for idx in edges[current_name]]
11811181

11821182
if backedges.get(current_name):
1183-
block_backedges = tuple(
1183+
block_backedges = [
11841184
block_ref_dict[idx] for idx in backedges[current_name]
1185-
)
1185+
]
11861186
else:
1187-
block_backedges = ()
1187+
block_backedges = []
11881188

11891189
block_type = block_info.pop("type")
11901190

numba_rvsdg/core/transformations.py

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ def loop_restructure_helper(scfg: SCFG, loop: Set[str]):
5858
and len(exiting_blocks) == 1
5959
and backedge_blocks[0] == next(iter(exiting_blocks))
6060
):
61-
scfg.add_block(
62-
scfg.graph.pop(backedge_blocks[0]).declare_backedge(loop_head)
63-
)
61+
scfg.graph[backedge_blocks[0]].declare_backedge(loop_head)
6462
return
6563

6664
# The synthetic exiting latch and synthetic exit need to be created
@@ -146,8 +144,8 @@ def reverse_lookup(d, value):
146144
# Create the actual control variable block
147145
synth_assign_block = SyntheticAssignment(
148146
name=synth_assign,
149-
_jump_targets=(synth_exiting_latch,),
150-
backedges=(),
147+
_jump_targets=[synth_exiting_latch],
148+
backedges=[],
151149
variable_assignment=variable_assignment,
152150
)
153151
# Insert the assignment to the scfg
@@ -177,20 +175,18 @@ def reverse_lookup(d, value):
177175
# that point to the headers, no need to add a backedge,
178176
# since it will be contained in the SyntheticExitingLatch
179177
# later on.
180-
block = scfg.graph.pop(name)
178+
block = scfg.graph[name]
181179
jts = list(block.jump_targets)
182180
for h in headers:
183181
if h in jts:
184182
jts.remove(h)
185-
scfg.add_block(
186-
block.replace_jump_targets(jump_targets=tuple(jts))
187-
)
183+
block.replace_jump_targets(jump_targets=jts)
188184
# Setup the assignment block and initialize it with the
189185
# correct jump_targets and variable assignment.
190186
synth_assign_block = SyntheticAssignment(
191187
name=synth_assign,
192-
_jump_targets=(synth_exiting_latch,),
193-
backedges=(),
188+
_jump_targets=[synth_exiting_latch],
189+
backedges=[],
194190
variable_assignment=variable_assignment,
195191
)
196192
# Add the new block to the SCFG
@@ -199,22 +195,19 @@ def reverse_lookup(d, value):
199195
new_jt[new_jt.index(jt)] = synth_assign
200196
# finally, replace the jump_targets for this block with the new
201197
# ones
202-
scfg.add_block(
203-
scfg.graph.pop(name).replace_jump_targets(
204-
jump_targets=tuple(new_jt)
205-
)
206-
)
198+
scfg.graph[name].replace_jump_targets(jump_targets=new_jt)
199+
207200
# Add any new blocks to the loop.
208201
loop.update(new_blocks)
209202

210203
# Insert the exiting latch, add it to the loop and to the graph.
211204
synth_exiting_latch_block = SyntheticExitingLatch(
212205
name=synth_exiting_latch,
213-
_jump_targets=(
206+
_jump_targets=[
214207
synth_exit if needs_synth_exit else next(iter(exit_blocks)),
215208
loop_head,
216-
),
217-
backedges=(loop_head,),
209+
],
210+
backedges=[loop_head],
218211
variable=backedge_variable,
219212
branch_value_table=backedge_value_table,
220213
)
@@ -225,8 +218,8 @@ def reverse_lookup(d, value):
225218
if needs_synth_exit:
226219
synth_exit_block = SyntheticExitBranch(
227220
name=synth_exit,
228-
_jump_targets=tuple(exit_blocks),
229-
backedges=(),
221+
_jump_targets=list(exit_blocks),
222+
backedges=[],
230223
variable=exit_variable,
231224
branch_value_table=exit_value_table,
232225
)
@@ -329,29 +322,23 @@ def update_exiting(
329322
):
330323
# Recursively updates the exiting blocks of a regionblock
331324
region_exiting = region_block.exiting
332-
region_exiting_block: BasicBlock = region_block.subregion.graph.pop(
325+
region_exiting_block: BasicBlock = region_block.subregion.graph[
333326
region_exiting
334-
)
327+
]
335328
jt = list(region_exiting_block._jump_targets)
336329
for idx, s in enumerate(jt):
337330
if s is new_region_header:
338331
jt[idx] = new_region_name
339-
region_exiting_block = region_exiting_block.replace_jump_targets(
340-
jump_targets=tuple(jt)
341-
)
332+
region_exiting_block.replace_jump_targets(jump_targets=jt)
342333
be = list(region_exiting_block.backedges)
343334
for idx, s in enumerate(be):
344335
if s is new_region_header:
345336
be[idx] = new_region_name
346-
region_exiting_block = region_exiting_block.replace_backedges(
347-
backedges=tuple(be)
348-
)
337+
region_exiting_block.replace_backedges(backedges=be)
349338
if isinstance(region_exiting_block, RegionBlock):
350339
region_exiting_block = update_exiting(
351340
region_exiting_block, new_region_header, new_region_name
352341
)
353-
region_block.subregion.add_block(region_exiting_block)
354-
return region_block
355342

356343

357344
def extract_region(
@@ -381,27 +368,26 @@ def extract_region(
381368
# the SCFG represents should not be the meta region.
382369
assert scfg.region.kind != "meta"
383370
continue
384-
entry = scfg.graph.pop(name)
371+
entry = scfg.graph[name]
385372
jt = list(entry._jump_targets)
386373
for idx, s in enumerate(jt):
387374
if s is region_header:
388375
jt[idx] = region_name
389-
entry = entry.replace_jump_targets(jump_targets=tuple(jt))
376+
entry.replace_jump_targets(jump_targets=jt)
390377
be = list(entry.backedges)
391378
for idx, s in enumerate(be):
392379
if s is region_header:
393380
be[idx] = region_name
394-
entry = entry.replace_backedges(backedges=tuple(be))
381+
entry.replace_backedges(backedges=be)
395382
# If the entry itself is a region, update it's
396383
# exiting blocks too, recursively
397384
if isinstance(entry, RegionBlock):
398385
entry = update_exiting(entry, region_header, region_name)
399-
scfg.add_block(entry)
400386

401387
region = RegionBlock(
402388
name=region_name,
403389
_jump_targets=scfg[region_exiting].jump_targets,
404-
backedges=(),
390+
backedges=[],
405391
kind=region_kind,
406392
header=region_header,
407393
subregion=head_subgraph,

0 commit comments

Comments
 (0)