Skip to content

Commit 5a18dcf

Browse files
committed
[py] Fix loop creation in PyASTBridge
According to our definitions: A "invariant loop" is a loop that is _guaranteed_ to execute the body of the loop `totalIterations` times. _Early exits are not allowed_. Loops created through `createInvariantForLoop` are not guaranteed to be invariant. This change renames `createInvariantForLoop` to `createForLoop` and add a flag parameter that signals whether the loop is indeed invariant. Signed-off-by: boschmitt <[email protected]>
1 parent 81714fd commit 5a18dcf

15 files changed

+73
-63
lines changed

python/cudaq/kernel/ast_bridge.py

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ def bodyBuilder(iterVar):
627627
[iterVar], rawIndex).result
628628
cc.StoreOp(castedEle, targetEleAddr)
629629

630-
self.createInvariantForLoop(sourceSize, bodyBuilder)
630+
self.createForLoop(sourceSize, bodyBuilder, invariant=True)
631631
return cc.StdvecInitOp(targetVecTy, targetPtr, length=sourceSize).result
632632

633633
def __insertDbgStmt(self, value, dbgStmt):
@@ -808,15 +808,16 @@ def checkControlAndTargetTypes(self, controls, targets):
808808
for i, target in enumerate(targets)
809809
]
810810

811-
def createInvariantForLoop(self,
812-
endVal,
813-
bodyBuilder,
814-
startVal=None,
815-
stepVal=None,
816-
isDecrementing=False,
817-
elseStmts=None):
811+
def createForLoop(self,
812+
endVal,
813+
bodyBuilder,
814+
startVal=None,
815+
stepVal=None,
816+
isDecrementing=False,
817+
elseStmts=None,
818+
invariant=False):
818819
"""
819-
Create an invariant loop using the CC dialect.
820+
Create a loop using the CC dialect.
820821
"""
821822
startVal = self.getConstantInt(0) if startVal == None else startVal
822823
stepVal = self.getConstantInt(1) if stepVal == None else stepVal
@@ -860,7 +861,9 @@ def createInvariantForLoop(self,
860861
cc.ContinueOp(elseBlock.arguments)
861862
self.symbolTable.popScope()
862863

863-
loop.attributes.__setitem__('invariant', UnitAttr.get())
864+
if invariant:
865+
loop.attributes['invariant'] = UnitAttr.get()
866+
864867
return
865868

866869
def __applyQuantumOperation(self, opName, parameters, targets):
@@ -877,7 +880,7 @@ def bodyBuilder(iterVal):
877880

878881
veqSize = quake.VeqSizeOp(self.getIntegerType(),
879882
quantumValue).result
880-
self.createInvariantForLoop(veqSize, bodyBuilder)
883+
self.createForLoop(veqSize, bodyBuilder, invariant=True)
881884
elif quake.RefType.isinstance(quantumValue.type):
882885
opCtor([], parameters, [], [quantumValue])
883886
else:
@@ -1711,11 +1714,12 @@ def bodyBuilder(iterVar):
17111714
incrementedCounter = arith.AddIOp(loadedCounter, one).result
17121715
cc.StoreOp(incrementedCounter, counter)
17131716

1714-
self.createInvariantForLoop(endVal,
1715-
bodyBuilder,
1716-
startVal=startVal,
1717-
stepVal=stepVal,
1718-
isDecrementing=isDecrementing)
1717+
self.createForLoop(endVal,
1718+
bodyBuilder,
1719+
startVal=startVal,
1720+
stepVal=stepVal,
1721+
isDecrementing=isDecrementing,
1722+
invariant=True)
17191723

17201724
self.pushValue(iterable)
17211725
self.pushValue(actualSize)
@@ -1812,7 +1816,7 @@ def bodyBuilder(iterVar):
18121816
DenseI64ArrayAttr.get([1], context=self.ctx)).result
18131817
cc.StoreOp(element, eleAddr)
18141818

1815-
self.createInvariantForLoop(totalSize, bodyBuilder)
1819+
self.createForLoop(totalSize, bodyBuilder, invariant=True)
18161820
self.pushValue(enumIterable)
18171821
self.pushValue(totalSize)
18181822
return
@@ -1921,7 +1925,7 @@ def bodyBuilder(iterVal):
19211925

19221926
veqSize = quake.VeqSizeOp(self.getIntegerType(),
19231927
target).result
1924-
self.createInvariantForLoop(veqSize, bodyBuilder)
1928+
self.createForLoop(veqSize, bodyBuilder, invariant=True)
19251929
return
19261930
elif quake.RefType.isinstance(target.type):
19271931
opCtor([], [], [], [target], is_adj=True)
@@ -2002,7 +2006,7 @@ def bodyBuilder(iterVal):
20022006

20032007
veqSize = quake.VeqSizeOp(self.getIntegerType(),
20042008
target).result
2005-
self.createInvariantForLoop(veqSize, bodyBuilder)
2009+
self.createForLoop(veqSize, bodyBuilder, invariant=True)
20062010
return
20072011
self.emitFatalError(
20082012
'reset quantum operation on incorrect type {}.'.format(
@@ -2772,7 +2776,7 @@ def bodyBuilder(iterVal):
27722776

27732777
veqSize = quake.VeqSizeOp(self.getIntegerType(),
27742778
target).result
2775-
self.createInvariantForLoop(veqSize, bodyBuilder)
2779+
self.createForLoop(veqSize, bodyBuilder, invariant=True)
27762780
return
27772781
elif quake.RefType.isinstance(target.type):
27782782
opCtor([], [], [], [target], is_adj=True)
@@ -2852,7 +2856,7 @@ def bodyBuilder(iterVal):
28522856

28532857
veqSize = quake.VeqSizeOp(self.getIntegerType(),
28542858
target).result
2855-
self.createInvariantForLoop(veqSize, bodyBuilder)
2859+
self.createForLoop(veqSize, bodyBuilder, invariant=True)
28562860
return
28572861
elif quake.RefType.isinstance(target.type):
28582862
opCtor([], [param], [], [target], is_adj=True)
@@ -2930,7 +2934,7 @@ def bodyBuilder(iterVal):
29302934

29312935
veqSize = quake.VeqSizeOp(self.getIntegerType(),
29322936
target).result
2933-
self.createInvariantForLoop(veqSize, bodyBuilder)
2937+
self.createForLoop(veqSize, bodyBuilder, invariant=True)
29342938
return
29352939
elif quake.RefType.isinstance(target.type):
29362940
opCtor([], params, [], [target], is_adj=True)
@@ -3031,14 +3035,21 @@ def visit_ListComp(self, node):
30313035
node.generators[0].iter)
30323036
if quake.VeqType.isinstance(
30333037
self.symbolTable[node.generators[0].iter.id].type):
3034-
# now we know we have `[expr(r) for r in iterable]`
3035-
# reuse what we do in `visit_For()`
3036-
forNode = ast.For()
3037-
forNode.iter = node.generators[0].iter
3038-
forNode.target = node.generators[0].target
3039-
forNode.body = [node.elt]
3040-
forNode.orelse = []
3041-
self.visit_For(forNode)
3038+
iterable = self.symbolTable[node.generators[0].iter.id]
3039+
totalSize = quake.VeqSizeOp(self.getIntegerType(),
3040+
iterable).result
3041+
3042+
def bodyBuilder(iterVar):
3043+
self.symbolTable.pushScope()
3044+
q = quake.ExtractRefOp(self.getRefType(),
3045+
iterable,
3046+
-1,
3047+
index=iterVar).result
3048+
self.symbolTable[node.generators[0].target.id] = q
3049+
self.visit(node.elt)
3050+
self.symbolTable.popScope()
3051+
3052+
self.createForLoop(totalSize, bodyBuilder, invariant=True)
30423053
return
30433054

30443055
# General case of
@@ -3107,7 +3118,7 @@ def bodyBuilder(iterVar):
31073118
cc.StoreOp(result, listValueAddr)
31083119
self.symbolTable.popScope()
31093120

3110-
self.createInvariantForLoop(iterableSize, bodyBuilder)
3121+
self.createForLoop(iterableSize, bodyBuilder, invariant=True)
31113122
self.pushValue(
31123123
cc.StdvecInitOp(cc.StdvecType.get(listComputePtrTy),
31133124
listValue,
@@ -3454,12 +3465,12 @@ def bodyBuilder(iterVar):
34543465
[self.visit(b) for b in node.body]
34553466
self.symbolTable.popScope()
34563467

3457-
self.createInvariantForLoop(endVal,
3458-
bodyBuilder,
3459-
startVal=startVal,
3460-
stepVal=stepVal,
3461-
isDecrementing=isDecrementing,
3462-
elseStmts=node.orelse)
3468+
self.createForLoop(endVal,
3469+
bodyBuilder,
3470+
startVal=startVal,
3471+
stepVal=stepVal,
3472+
isDecrementing=isDecrementing,
3473+
elseStmts=node.orelse)
34633474

34643475
return
34653476

@@ -3528,9 +3539,9 @@ def bodyBuilder(iterVar):
35283539
[self.visit(b) for b in node.body]
35293540
self.symbolTable.popScope()
35303541

3531-
self.createInvariantForLoop(totalSize,
3532-
bodyBuilder,
3533-
elseStmts=node.orelse)
3542+
self.createForLoop(totalSize,
3543+
bodyBuilder,
3544+
elseStmts=node.orelse)
35343545
return
35353546

35363547
self.visit(node.iter)
@@ -3648,9 +3659,7 @@ def bodyBuilder(iterVar):
36483659
[self.visit(b) for b in node.body]
36493660
self.symbolTable.popScope()
36503661

3651-
self.createInvariantForLoop(totalSize,
3652-
bodyBuilder,
3653-
elseStmts=node.orelse)
3662+
self.createForLoop(totalSize, bodyBuilder, elseStmts=node.orelse)
36543663

36553664
def visit_While(self, node):
36563665
"""
@@ -3895,8 +3904,9 @@ def check_element(idx):
38953904
current = cc.LoadOp(accumulator).result
38963905
cc.StoreOp(arith.OrIOp(current, cmp_result.result), accumulator)
38973906

3898-
self.createInvariantForLoop(self.__get_vector_size(right_val),
3899-
check_element)
3907+
self.createForLoop(self.__get_vector_size(right_val),
3908+
check_element,
3909+
invariant=True)
39003910

39013911
final_result = cc.LoadOp(accumulator).result
39023912
if isinstance(op, ast.NotIn):

python/tests/mlir/ast_break.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,6 @@ def kernel(x: float):
6060
# CHECK: ^bb0(%[[VAL_21:.*]]: i64):
6161
# CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : i64
6262
# CHECK: cc.continue %[[VAL_22]] : i64
63-
# CHECK: } {invariant}
63+
# CHECK: }
6464
# CHECK: return
6565
# CHECK: }

python/tests/mlir/ast_continue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,5 @@ def kernel(x: float):
6464
# CHECK: ^bb0(%[[VAL_23:.*]]: i64):
6565
# CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_3]] : i64
6666
# CHECK: cc.continue %[[VAL_24]] : i64
67-
# CHECK: } {invariant}
67+
# CHECK: }
6868
# CHECK: }

python/tests/mlir/ast_decrementing_range.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ def test(q: int, p: int):
4646
# CHECK: ^bb0(%[[VAL_13:.*]]: i64):
4747
# CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : i64
4848
# CHECK: cc.continue %[[VAL_14]] : i64
49-
# CHECK: } {invariant}
49+
# CHECK: }
5050
# CHECK: return
5151
# CHECK: }

python/tests/mlir/ast_elif.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def cost(thetas: np.ndarray): # can pass 1D ndarray or list
6666
# CHECK: ^bb0(%[[VAL_25:.*]]: i64):
6767
# CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_5]] : i64
6868
# CHECK: cc.continue %[[VAL_26]] : i64
69-
# CHECK: } {invariant}
69+
# CHECK: }
7070
# CHECK: return
7171
# CHECK: }
7272

@@ -103,6 +103,6 @@ def cost(thetas: np.ndarray): # can pass 1D ndarray or list
103103
# CHECK: ^bb0(%[[VAL_21:.*]]: i64):
104104
# CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : i64
105105
# CHECK: cc.continue %[[VAL_22]] : i64
106-
# CHECK: } {invariant}
106+
# CHECK: }
107107
# CHECK: return
108108
# CHECK: }

python/tests/mlir/ast_for_stdvec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,6 @@ def cost(thetas: np.ndarray): # can pass 1D ndarray or list
5454
# CHECK: ^bb0(%[[VAL_17:.*]]: i64):
5555
# CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_1]] : i64
5656
# CHECK: cc.continue %[[VAL_18]] : i64
57-
# CHECK: } {invariant}
57+
# CHECK: }
5858
# CHECK: return
5959
# CHECK: }

python/tests/mlir/ast_iterate_loop_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,6 @@ def kernel(x: float):
6363
# CHECK: ^bb0(%[[VAL_26:.*]]: i64):
6464
# CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_3]] : i64
6565
# CHECK: cc.continue %[[VAL_27]] : i64
66-
# CHECK: } {invariant}
66+
# CHECK: }
6767
# CHECK: return
6868
# CHECK: }

python/tests/mlir/ast_list_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,6 @@ def kernel():
6363
# CHECK: ^bb0(%[[VAL_26:.*]]: i64):
6464
# CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_0]] : i64
6565
# CHECK: cc.continue %[[VAL_27]] : i64
66-
# CHECK: } {invariant}
66+
# CHECK: }
6767
# CHECK: return
6868
# CHECK: }

python/tests/mlir/ast_list_int.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,6 @@ def oracle(register: cudaq.qview, auxillary_qubit: cudaq.qubit,
5252
# CHECK: ^bb0(%[[VAL_16:.*]]: i64):
5353
# CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_3]] : i64
5454
# CHECK: cc.continue %[[VAL_17]] : i64
55-
# CHECK: } {invariant}
55+
# CHECK: }
5656
# CHECK: return
5757
# CHECK: }

python/tests/mlir/ast_qreg_slice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def slice():
114114
# CHECK: ^bb0(%[[VAL_43:.*]]: i64):
115115
# CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : i64
116116
# CHECK: cc.continue %[[VAL_44]] : i64
117-
# CHECK: } {invariant}
117+
# CHECK: }
118118
# CHECK: %[[VAL_45:.*]] = quake.extract_ref %[[VAL_7]][3] : (!quake.veq<4>) -> !quake.ref
119119
# CHECK: quake.rz (%[[VAL_4]]) %[[VAL_45]] : (f64, !quake.ref) -> ()
120120
# CHECK: return

0 commit comments

Comments
 (0)