@@ -627,7 +627,7 @@ def bodyBuilder(iterVar):
627627 [iterVar ], rawIndex ).result
628628 cc .StoreOp (castedEle , targetEleAddr )
629629
630- self .createInvariantForLoop (sourceSize , bodyBuilder )
630+ self .createForLoop (sourceSize , bodyBuilder , invariant = True )
631631 return cc .StdvecInitOp (targetVecTy , targetPtr , length = sourceSize ).result
632632
633633 def __insertDbgStmt (self , value , dbgStmt ):
@@ -808,15 +808,16 @@ def checkControlAndTargetTypes(self, controls, targets):
808808 for i , target in enumerate (targets )
809809 ]
810810
811- def createInvariantForLoop (self ,
812- endVal ,
813- bodyBuilder ,
814- startVal = None ,
815- stepVal = None ,
816- isDecrementing = False ,
817- elseStmts = None ):
811+ def createForLoop (self ,
812+ endVal ,
813+ bodyBuilder ,
814+ startVal = None ,
815+ stepVal = None ,
816+ isDecrementing = False ,
817+ elseStmts = None ,
818+ invariant = False ):
818819 """
819- Create an invariant loop using the CC dialect.
820+ Create a loop using the CC dialect.
820821 """
821822 startVal = self .getConstantInt (0 ) if startVal == None else startVal
822823 stepVal = self .getConstantInt (1 ) if stepVal == None else stepVal
@@ -860,7 +861,9 @@ def createInvariantForLoop(self,
860861 cc .ContinueOp (elseBlock .arguments )
861862 self .symbolTable .popScope ()
862863
863- loop .attributes .__setitem__ ('invariant' , UnitAttr .get ())
864+ if invariant :
865+ loop .attributes ['invariant' ] = UnitAttr .get ()
866+
864867 return
865868
866869 def __applyQuantumOperation (self , opName , parameters , targets ):
@@ -877,7 +880,7 @@ def bodyBuilder(iterVal):
877880
878881 veqSize = quake .VeqSizeOp (self .getIntegerType (),
879882 quantumValue ).result
880- self .createInvariantForLoop (veqSize , bodyBuilder )
883+ self .createForLoop (veqSize , bodyBuilder , invariant = True )
881884 elif quake .RefType .isinstance (quantumValue .type ):
882885 opCtor ([], parameters , [], [quantumValue ])
883886 else :
@@ -1711,11 +1714,12 @@ def bodyBuilder(iterVar):
17111714 incrementedCounter = arith .AddIOp (loadedCounter , one ).result
17121715 cc .StoreOp (incrementedCounter , counter )
17131716
1714- self .createInvariantForLoop (endVal ,
1715- bodyBuilder ,
1716- startVal = startVal ,
1717- stepVal = stepVal ,
1718- isDecrementing = isDecrementing )
1717+ self .createForLoop (endVal ,
1718+ bodyBuilder ,
1719+ startVal = startVal ,
1720+ stepVal = stepVal ,
1721+ isDecrementing = isDecrementing ,
1722+ invariant = True )
17191723
17201724 self .pushValue (iterable )
17211725 self .pushValue (actualSize )
@@ -1812,7 +1816,7 @@ def bodyBuilder(iterVar):
18121816 DenseI64ArrayAttr .get ([1 ], context = self .ctx )).result
18131817 cc .StoreOp (element , eleAddr )
18141818
1815- self .createInvariantForLoop (totalSize , bodyBuilder )
1819+ self .createForLoop (totalSize , bodyBuilder , invariant = True )
18161820 self .pushValue (enumIterable )
18171821 self .pushValue (totalSize )
18181822 return
@@ -1921,7 +1925,7 @@ def bodyBuilder(iterVal):
19211925
19221926 veqSize = quake .VeqSizeOp (self .getIntegerType (),
19231927 target ).result
1924- self .createInvariantForLoop (veqSize , bodyBuilder )
1928+ self .createForLoop (veqSize , bodyBuilder , invariant = True )
19251929 return
19261930 elif quake .RefType .isinstance (target .type ):
19271931 opCtor ([], [], [], [target ], is_adj = True )
@@ -2002,7 +2006,7 @@ def bodyBuilder(iterVal):
20022006
20032007 veqSize = quake .VeqSizeOp (self .getIntegerType (),
20042008 target ).result
2005- self .createInvariantForLoop (veqSize , bodyBuilder )
2009+ self .createForLoop (veqSize , bodyBuilder , invariant = True )
20062010 return
20072011 self .emitFatalError (
20082012 'reset quantum operation on incorrect type {}.' .format (
@@ -2772,7 +2776,7 @@ def bodyBuilder(iterVal):
27722776
27732777 veqSize = quake .VeqSizeOp (self .getIntegerType (),
27742778 target ).result
2775- self .createInvariantForLoop (veqSize , bodyBuilder )
2779+ self .createForLoop (veqSize , bodyBuilder , invariant = True )
27762780 return
27772781 elif quake .RefType .isinstance (target .type ):
27782782 opCtor ([], [], [], [target ], is_adj = True )
@@ -2852,7 +2856,7 @@ def bodyBuilder(iterVal):
28522856
28532857 veqSize = quake .VeqSizeOp (self .getIntegerType (),
28542858 target ).result
2855- self .createInvariantForLoop (veqSize , bodyBuilder )
2859+ self .createForLoop (veqSize , bodyBuilder , invariant = True )
28562860 return
28572861 elif quake .RefType .isinstance (target .type ):
28582862 opCtor ([], [param ], [], [target ], is_adj = True )
@@ -2930,7 +2934,7 @@ def bodyBuilder(iterVal):
29302934
29312935 veqSize = quake .VeqSizeOp (self .getIntegerType (),
29322936 target ).result
2933- self .createInvariantForLoop (veqSize , bodyBuilder )
2937+ self .createForLoop (veqSize , bodyBuilder , invariant = True )
29342938 return
29352939 elif quake .RefType .isinstance (target .type ):
29362940 opCtor ([], params , [], [target ], is_adj = True )
@@ -3031,14 +3035,21 @@ def visit_ListComp(self, node):
30313035 node .generators [0 ].iter )
30323036 if quake .VeqType .isinstance (
30333037 self .symbolTable [node .generators [0 ].iter .id ].type ):
3034- # now we know we have `[expr(r) for r in iterable]`
3035- # reuse what we do in `visit_For()`
3036- forNode = ast .For ()
3037- forNode .iter = node .generators [0 ].iter
3038- forNode .target = node .generators [0 ].target
3039- forNode .body = [node .elt ]
3040- forNode .orelse = []
3041- self .visit_For (forNode )
3038+ iterable = self .symbolTable [node .generators [0 ].iter .id ]
3039+ totalSize = quake .VeqSizeOp (self .getIntegerType (),
3040+ iterable ).result
3041+
3042+ def bodyBuilder (iterVar ):
3043+ self .symbolTable .pushScope ()
3044+ q = quake .ExtractRefOp (self .getRefType (),
3045+ iterable ,
3046+ - 1 ,
3047+ index = iterVar ).result
3048+ self .symbolTable [node .generators [0 ].target .id ] = q
3049+ self .visit (node .elt )
3050+ self .symbolTable .popScope ()
3051+
3052+ self .createForLoop (totalSize , bodyBuilder , invariant = True )
30423053 return
30433054
30443055 # General case of
@@ -3107,7 +3118,7 @@ def bodyBuilder(iterVar):
31073118 cc .StoreOp (result , listValueAddr )
31083119 self .symbolTable .popScope ()
31093120
3110- self .createInvariantForLoop (iterableSize , bodyBuilder )
3121+ self .createForLoop (iterableSize , bodyBuilder , invariant = True )
31113122 self .pushValue (
31123123 cc .StdvecInitOp (cc .StdvecType .get (listComputePtrTy ),
31133124 listValue ,
@@ -3454,12 +3465,12 @@ def bodyBuilder(iterVar):
34543465 [self .visit (b ) for b in node .body ]
34553466 self .symbolTable .popScope ()
34563467
3457- self .createInvariantForLoop (endVal ,
3458- bodyBuilder ,
3459- startVal = startVal ,
3460- stepVal = stepVal ,
3461- isDecrementing = isDecrementing ,
3462- elseStmts = node .orelse )
3468+ self .createForLoop (endVal ,
3469+ bodyBuilder ,
3470+ startVal = startVal ,
3471+ stepVal = stepVal ,
3472+ isDecrementing = isDecrementing ,
3473+ elseStmts = node .orelse )
34633474
34643475 return
34653476
@@ -3528,9 +3539,9 @@ def bodyBuilder(iterVar):
35283539 [self .visit (b ) for b in node .body ]
35293540 self .symbolTable .popScope ()
35303541
3531- self .createInvariantForLoop (totalSize ,
3532- bodyBuilder ,
3533- elseStmts = node .orelse )
3542+ self .createForLoop (totalSize ,
3543+ bodyBuilder ,
3544+ elseStmts = node .orelse )
35343545 return
35353546
35363547 self .visit (node .iter )
@@ -3648,9 +3659,7 @@ def bodyBuilder(iterVar):
36483659 [self .visit (b ) for b in node .body ]
36493660 self .symbolTable .popScope ()
36503661
3651- self .createInvariantForLoop (totalSize ,
3652- bodyBuilder ,
3653- elseStmts = node .orelse )
3662+ self .createForLoop (totalSize , bodyBuilder , elseStmts = node .orelse )
36543663
36553664 def visit_While (self , node ):
36563665 """
@@ -3895,8 +3904,9 @@ def check_element(idx):
38953904 current = cc .LoadOp (accumulator ).result
38963905 cc .StoreOp (arith .OrIOp (current , cmp_result .result ), accumulator )
38973906
3898- self .createInvariantForLoop (self .__get_vector_size (right_val ),
3899- check_element )
3907+ self .createForLoop (self .__get_vector_size (right_val ),
3908+ check_element ,
3909+ invariant = True )
39003910
39013911 final_result = cc .LoadOp (accumulator ).result
39023912 if isinstance (op , ast .NotIn ):
0 commit comments