@@ -363,7 +363,12 @@ def getConstantInt(self, value, width=64):
363363 ty = self .getIntegerType (width )
364364 return arith .ConstantOp (ty , self .getIntegerAttr (ty , value )).result
365365
366- def promoteOperandType (self , ty , operand ):
366+ def changeOperandToType (self , ty , operand , allowDemotion = True ):
367+ """
368+ Change the type of an operand to a specified type. This function primarily
369+ handles type conversions and promotions to higher types (complex > float > int).
370+ Demotion of floating type to integer is allowed by default.
371+ """
367372 if ty == operand .type :
368373 return operand
369374
@@ -374,13 +379,15 @@ def promoteOperandType(self, ty, operand):
374379 otherComplexType = ComplexType (operand .type )
375380 otherFloatType = otherComplexType .element_type
376381 if (floatType != otherFloatType ):
377- real = self .promoteOperandType (floatType ,
378- complex .ReOp (operand ).result )
379- imag = self .promoteOperandType (floatType ,
380- complex .ImOp (operand ).result )
382+ real = self .changeOperandToType (
383+ floatType ,
384+ complex .ReOp (operand ).result )
385+ imag = self .changeOperandToType (
386+ floatType ,
387+ complex .ImOp (operand ).result )
381388 operand = complex .CreateOp (complexType , real , imag ).result
382389 else :
383- real = self .promoteOperandType (floatType , operand )
390+ real = self .changeOperandToType (floatType , operand )
384391 imag = self .getConstantFloatWithType (0.0 , floatType )
385392 operand = complex .CreateOp (complexType , real , imag ).result
386393
@@ -406,8 +413,8 @@ def promoteOperandType(self, ty, operand):
406413 zint = zeroext ).result
407414
408415 if IntegerType .isinstance (ty ):
409- if F64Type .isinstance (operand .type ) or F32Type . isinstance (
410- operand .type ):
416+ if allowDemotion and ( F64Type .isinstance (operand .type ) or
417+ F32Type . isinstance ( operand .type ) ):
411418 operand = cc .CastOp (ty , operand , sint = True , zint = False ).result
412419 if IntegerType .isinstance (operand .type ):
413420 if IntegerType (ty ).width < IntegerType (operand .type ).width :
@@ -622,7 +629,7 @@ def bodyBuilder(iterVar):
622629 eleAddr = cc .ComputePtrOp (sourceElePtrTy , sourceDataPtr , [iterVar ],
623630 rawIndex ).result
624631 loadedEle = cc .LoadOp (eleAddr ).result
625- castedEle = self .promoteOperandType (targetEleType , loadedEle )
632+ castedEle = self .changeOperandToType (targetEleType , loadedEle )
626633 targetEleAddr = cc .ComputePtrOp (targetElePtrType , targetPtr ,
627634 [iterVar ], rawIndex ).result
628635 cc .StoreOp (castedEle , targetEleAddr )
@@ -762,7 +769,8 @@ def convertArithmeticToSuperiorType(self, values, type):
762769 """
763770 retValues = []
764771 for v in values :
765- retValues .append (self .promoteOperandType (type , v ))
772+ retValues .append (
773+ self .changeOperandToType (type , v , allowDemotion = False ))
766774
767775 return retValues
768776
@@ -1824,8 +1832,8 @@ def bodyBuilder(iterVar):
18241832 else :
18251833 imag = namedArgs ['imag' ]
18261834 real = namedArgs ['real' ]
1827- imag = self .promoteOperandType (self .getFloatType (), imag )
1828- real = self .promoteOperandType (self .getFloatType (), real )
1835+ imag = self .changeOperandToType (self .getFloatType (), imag )
1836+ real = self .changeOperandToType (self .getFloatType (), real )
18291837 self .pushValue (
18301838 complex .CreateOp (self .getComplexType (), real , imag ).result )
18311839 return
@@ -2141,8 +2149,8 @@ def bodyBuilder(iterVal):
21412149 elif node .func .id == 'int' :
21422150 # cast operation
21432151 value = self .popValue ()
2144- casted = self .promoteOperandType (IntegerType .get_signless (64 ),
2145- value )
2152+ casted = self .changeOperandToType (IntegerType .get_signless (64 ),
2153+ value )
21462154 self .pushValue (casted )
21472155 if not IntegerType .isinstance (casted .type ):
21482156 self .emitFatalError (
@@ -2306,15 +2314,15 @@ def bodyBuilder(iterVal):
23062314 ty = self .getComplexType (width = 32 )
23072315 eleTy = self .getFloatType (width = 32 )
23082316
2309- value = self .promoteOperandType (ty , value )
2317+ value = self .changeOperandToType (ty , value )
23102318 if (ty == value .type ):
23112319 self .pushValue (value )
23122320 return
23132321
23142322 real = complex .ReOp (value ).result
23152323 imag = complex .ImOp (value ).result
2316- real = self .promoteOperandType (eleTy , real )
2317- imag = self .promoteOperandType (eleTy , imag )
2324+ real = self .changeOperandToType (eleTy , real )
2325+ imag = self .changeOperandToType (eleTy , imag )
23182326
23192327 self .pushValue (complex .CreateOp (ty , real , imag ).result )
23202328 return
@@ -2325,18 +2333,18 @@ def bodyBuilder(iterVal):
23252333 if node .func .attr == 'float32' :
23262334 ty = self .getFloatType (width = 32 )
23272335
2328- value = self .promoteOperandType (ty , value )
2336+ value = self .changeOperandToType (ty , value )
23292337 self .pushValue (value )
23302338 return
23312339
23322340 # Promote argument's types for `numpy.func` calls to match python's semantics
23332341 if node .func .attr in ['sin' , 'cos' , 'sqrt' , 'ceil' , 'exp' ]:
23342342 if ComplexType .isinstance (value .type ):
2335- value = self .promoteOperandType ( self . getComplexType (),
2336- value )
2343+ value = self .changeOperandToType (
2344+ self . getComplexType (), value )
23372345 if IntegerType .isinstance (value .type ):
2338- value = self .promoteOperandType ( self . getFloatType (),
2339- value )
2346+ value = self .changeOperandToType (
2347+ self . getFloatType (), value )
23402348
23412349 if node .func .attr == 'cos' :
23422350 if ComplexType .isinstance (value .type ):
@@ -2367,8 +2375,8 @@ def bodyBuilder(iterVal):
23672375 floatType = complexType .element_type
23682376 real = complex .ReOp (value ).result
23692377 imag = complex .ImOp (value ).result
2370- left = self .promoteOperandType (complexType ,
2371- math .ExpOp (real ).result )
2378+ left = self .changeOperandToType (complexType ,
2379+ math .ExpOp (real ).result )
23722380 re2 = math .CosOp (imag ).result
23732381 im2 = math .SinOp (imag ).result
23742382 right = complex .CreateOp (ComplexType .get (floatType ),
@@ -4027,7 +4035,9 @@ def visit_Return(self, node):
40274035 result = self .ifPointerThenLoad (result )
40284036 if result .type != self .knownResultType :
40294037 # FIXME consider more auto-casting where possible
4030- result = self .promoteOperandType (self .knownResultType , result )
4038+ result = self .changeOperandToType (self .knownResultType ,
4039+ result ,
4040+ allowDemotion = True )
40314041
40324042 if result .type != self .knownResultType :
40334043 self .emitFatalError (
@@ -4211,8 +4221,12 @@ def visit_BinOp(self, node):
42114221
42124222 # Type promotion for addition, subtraction, multiplication, or division
42134223 if isinstance (node .op , (ast .Add , ast .Sub , ast .Mult , ast .Div )):
4214- right = self .promoteOperandType (left .type , right )
4215- left = self .promoteOperandType (right .type , left )
4224+ right = self .changeOperandToType (left .type ,
4225+ right ,
4226+ allowDemotion = False )
4227+ left = self .changeOperandToType (right .type ,
4228+ left ,
4229+ allowDemotion = False )
42164230
42174231 # Based on the op type and the leaf types, create the MLIR operator
42184232 if isinstance (node .op , ast .Add ):
@@ -4303,8 +4317,8 @@ def visit_BinOp(self, node):
43034317 if isinstance (node .op , ast .LShift ):
43044318 if IntegerType .isinstance (left .type ) and IntegerType .isinstance (
43054319 right .type ):
4306- left = self .promoteOperandType (self .getIntegerType (), left )
4307- right = self .promoteOperandType (self .getIntegerType (), right )
4320+ left = self .changeOperandToType (self .getIntegerType (), left )
4321+ right = self .changeOperandToType (self .getIntegerType (), right )
43084322 self .pushValue (arith .ShLIOp (left , right ).result )
43094323 return
43104324 else :
@@ -4315,8 +4329,8 @@ def visit_BinOp(self, node):
43154329 if isinstance (node .op , ast .RShift ):
43164330 if IntegerType .isinstance (left .type ) and IntegerType .isinstance (
43174331 right .type ):
4318- left = self .promoteOperandType (self .getIntegerType (), left )
4319- right = self .promoteOperandType (self .getIntegerType (), right )
4332+ left = self .changeOperandToType (self .getIntegerType (), left )
4333+ right = self .changeOperandToType (self .getIntegerType (), right )
43204334 self .pushValue (arith .ShRSIOp (left , right ).result )
43214335 return
43224336 else :
@@ -4327,8 +4341,8 @@ def visit_BinOp(self, node):
43274341 if isinstance (node .op , ast .BitAnd ):
43284342 if IntegerType .isinstance (left .type ) and IntegerType .isinstance (
43294343 right .type ):
4330- left = self .promoteOperandType (self .getIntegerType (), left )
4331- right = self .promoteOperandType (self .getIntegerType (), right )
4344+ left = self .changeOperandToType (self .getIntegerType (), left )
4345+ right = self .changeOperandToType (self .getIntegerType (), right )
43324346 self .pushValue (arith .AndIOp (left , right ).result )
43334347 return
43344348 else :
@@ -4339,8 +4353,8 @@ def visit_BinOp(self, node):
43394353 if isinstance (node .op , ast .BitOr ):
43404354 if IntegerType .isinstance (left .type ) and IntegerType .isinstance (
43414355 right .type ):
4342- left = self .promoteOperandType (self .getIntegerType (), left )
4343- right = self .promoteOperandType (self .getIntegerType (), right )
4356+ left = self .changeOperandToType (self .getIntegerType (), left )
4357+ right = self .changeOperandToType (self .getIntegerType (), right )
43444358 self .pushValue (arith .OrIOp (left , right ).result )
43454359 return
43464360 else :
@@ -4351,8 +4365,8 @@ def visit_BinOp(self, node):
43514365 if isinstance (node .op , ast .BitXor ):
43524366 if IntegerType .isinstance (left .type ) and IntegerType .isinstance (
43534367 right .type ):
4354- left = self .promoteOperandType (self .getIntegerType (), left )
4355- right = self .promoteOperandType (self .getIntegerType (), right )
4368+ left = self .changeOperandToType (self .getIntegerType (), left )
4369+ right = self .changeOperandToType (self .getIntegerType (), right )
43564370 self .pushValue (arith .XOrIOp (left , right ).result )
43574371 return
43584372 else :
0 commit comments