Skip to content

Commit 9653cc9

Browse files
[Python] [bug-fix] Incorrect operand type conversion for some arithmetic operations (#3346)
- Follow-up to PR #3084 - Operand type conversion when dealing with arithmetic operations like `add`, `sub`, `mult` and `div`, should not demote floating point type to integer types. Note that demotion is allowed when dealing with return types - which is explicitly specified by the type. - Renaming the function because `promoteOperandType` is a misnomer. Also added a flag to allow / disallow demotion of operand types. Signed-off-by: Pradnya Khalate <[email protected]>
1 parent 071a95e commit 9653cc9

File tree

2 files changed

+64
-37
lines changed

2 files changed

+64
-37
lines changed

python/cudaq/kernel/ast_bridge.py

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,12 @@ def getConstantInt(self, value, width=64):
363363
ty = self.getIntegerType(width)
364364
return arith.ConstantOp(ty, self.getIntegerAttr(ty, value)).result
365365

366-
def promoteOperandType(self, ty, operand):
366+
def changeOperandToType(self, ty, operand, allowDemotion=True):
367+
"""
368+
Change the type of an operand to a specified type. This function primarily
369+
handles type conversions and promotions to higher types (complex > float > int).
370+
Demotion of floating type to integer is allowed by default.
371+
"""
367372
if ty == operand.type:
368373
return operand
369374

@@ -374,13 +379,15 @@ def promoteOperandType(self, ty, operand):
374379
otherComplexType = ComplexType(operand.type)
375380
otherFloatType = otherComplexType.element_type
376381
if (floatType != otherFloatType):
377-
real = self.promoteOperandType(floatType,
378-
complex.ReOp(operand).result)
379-
imag = self.promoteOperandType(floatType,
380-
complex.ImOp(operand).result)
382+
real = self.changeOperandToType(
383+
floatType,
384+
complex.ReOp(operand).result)
385+
imag = self.changeOperandToType(
386+
floatType,
387+
complex.ImOp(operand).result)
381388
operand = complex.CreateOp(complexType, real, imag).result
382389
else:
383-
real = self.promoteOperandType(floatType, operand)
390+
real = self.changeOperandToType(floatType, operand)
384391
imag = self.getConstantFloatWithType(0.0, floatType)
385392
operand = complex.CreateOp(complexType, real, imag).result
386393

@@ -406,8 +413,8 @@ def promoteOperandType(self, ty, operand):
406413
zint=zeroext).result
407414

408415
if IntegerType.isinstance(ty):
409-
if F64Type.isinstance(operand.type) or F32Type.isinstance(
410-
operand.type):
416+
if allowDemotion and (F64Type.isinstance(operand.type) or
417+
F32Type.isinstance(operand.type)):
411418
operand = cc.CastOp(ty, operand, sint=True, zint=False).result
412419
if IntegerType.isinstance(operand.type):
413420
if IntegerType(ty).width < IntegerType(operand.type).width:
@@ -622,7 +629,7 @@ def bodyBuilder(iterVar):
622629
eleAddr = cc.ComputePtrOp(sourceElePtrTy, sourceDataPtr, [iterVar],
623630
rawIndex).result
624631
loadedEle = cc.LoadOp(eleAddr).result
625-
castedEle = self.promoteOperandType(targetEleType, loadedEle)
632+
castedEle = self.changeOperandToType(targetEleType, loadedEle)
626633
targetEleAddr = cc.ComputePtrOp(targetElePtrType, targetPtr,
627634
[iterVar], rawIndex).result
628635
cc.StoreOp(castedEle, targetEleAddr)
@@ -762,7 +769,8 @@ def convertArithmeticToSuperiorType(self, values, type):
762769
"""
763770
retValues = []
764771
for v in values:
765-
retValues.append(self.promoteOperandType(type, v))
772+
retValues.append(
773+
self.changeOperandToType(type, v, allowDemotion=False))
766774

767775
return retValues
768776

@@ -1824,8 +1832,8 @@ def bodyBuilder(iterVar):
18241832
else:
18251833
imag = namedArgs['imag']
18261834
real = namedArgs['real']
1827-
imag = self.promoteOperandType(self.getFloatType(), imag)
1828-
real = self.promoteOperandType(self.getFloatType(), real)
1835+
imag = self.changeOperandToType(self.getFloatType(), imag)
1836+
real = self.changeOperandToType(self.getFloatType(), real)
18291837
self.pushValue(
18301838
complex.CreateOp(self.getComplexType(), real, imag).result)
18311839
return
@@ -2141,8 +2149,8 @@ def bodyBuilder(iterVal):
21412149
elif node.func.id == 'int':
21422150
# cast operation
21432151
value = self.popValue()
2144-
casted = self.promoteOperandType(IntegerType.get_signless(64),
2145-
value)
2152+
casted = self.changeOperandToType(IntegerType.get_signless(64),
2153+
value)
21462154
self.pushValue(casted)
21472155
if not IntegerType.isinstance(casted.type):
21482156
self.emitFatalError(
@@ -2306,15 +2314,15 @@ def bodyBuilder(iterVal):
23062314
ty = self.getComplexType(width=32)
23072315
eleTy = self.getFloatType(width=32)
23082316

2309-
value = self.promoteOperandType(ty, value)
2317+
value = self.changeOperandToType(ty, value)
23102318
if (ty == value.type):
23112319
self.pushValue(value)
23122320
return
23132321

23142322
real = complex.ReOp(value).result
23152323
imag = complex.ImOp(value).result
2316-
real = self.promoteOperandType(eleTy, real)
2317-
imag = self.promoteOperandType(eleTy, imag)
2324+
real = self.changeOperandToType(eleTy, real)
2325+
imag = self.changeOperandToType(eleTy, imag)
23182326

23192327
self.pushValue(complex.CreateOp(ty, real, imag).result)
23202328
return
@@ -2325,18 +2333,18 @@ def bodyBuilder(iterVal):
23252333
if node.func.attr == 'float32':
23262334
ty = self.getFloatType(width=32)
23272335

2328-
value = self.promoteOperandType(ty, value)
2336+
value = self.changeOperandToType(ty, value)
23292337
self.pushValue(value)
23302338
return
23312339

23322340
# Promote argument's types for `numpy.func` calls to match python's semantics
23332341
if node.func.attr in ['sin', 'cos', 'sqrt', 'ceil', 'exp']:
23342342
if ComplexType.isinstance(value.type):
2335-
value = self.promoteOperandType(self.getComplexType(),
2336-
value)
2343+
value = self.changeOperandToType(
2344+
self.getComplexType(), value)
23372345
if IntegerType.isinstance(value.type):
2338-
value = self.promoteOperandType(self.getFloatType(),
2339-
value)
2346+
value = self.changeOperandToType(
2347+
self.getFloatType(), value)
23402348

23412349
if node.func.attr == 'cos':
23422350
if ComplexType.isinstance(value.type):
@@ -2367,8 +2375,8 @@ def bodyBuilder(iterVal):
23672375
floatType = complexType.element_type
23682376
real = complex.ReOp(value).result
23692377
imag = complex.ImOp(value).result
2370-
left = self.promoteOperandType(complexType,
2371-
math.ExpOp(real).result)
2378+
left = self.changeOperandToType(complexType,
2379+
math.ExpOp(real).result)
23722380
re2 = math.CosOp(imag).result
23732381
im2 = math.SinOp(imag).result
23742382
right = complex.CreateOp(ComplexType.get(floatType),
@@ -4027,7 +4035,9 @@ def visit_Return(self, node):
40274035
result = self.ifPointerThenLoad(result)
40284036
if result.type != self.knownResultType:
40294037
# FIXME consider more auto-casting where possible
4030-
result = self.promoteOperandType(self.knownResultType, result)
4038+
result = self.changeOperandToType(self.knownResultType,
4039+
result,
4040+
allowDemotion=True)
40314041

40324042
if result.type != self.knownResultType:
40334043
self.emitFatalError(
@@ -4211,8 +4221,12 @@ def visit_BinOp(self, node):
42114221

42124222
# Type promotion for addition, subtraction, multiplication, or division
42134223
if isinstance(node.op, (ast.Add, ast.Sub, ast.Mult, ast.Div)):
4214-
right = self.promoteOperandType(left.type, right)
4215-
left = self.promoteOperandType(right.type, left)
4224+
right = self.changeOperandToType(left.type,
4225+
right,
4226+
allowDemotion=False)
4227+
left = self.changeOperandToType(right.type,
4228+
left,
4229+
allowDemotion=False)
42164230

42174231
# Based on the op type and the leaf types, create the MLIR operator
42184232
if isinstance(node.op, ast.Add):
@@ -4303,8 +4317,8 @@ def visit_BinOp(self, node):
43034317
if isinstance(node.op, ast.LShift):
43044318
if IntegerType.isinstance(left.type) and IntegerType.isinstance(
43054319
right.type):
4306-
left = self.promoteOperandType(self.getIntegerType(), left)
4307-
right = self.promoteOperandType(self.getIntegerType(), right)
4320+
left = self.changeOperandToType(self.getIntegerType(), left)
4321+
right = self.changeOperandToType(self.getIntegerType(), right)
43084322
self.pushValue(arith.ShLIOp(left, right).result)
43094323
return
43104324
else:
@@ -4315,8 +4329,8 @@ def visit_BinOp(self, node):
43154329
if isinstance(node.op, ast.RShift):
43164330
if IntegerType.isinstance(left.type) and IntegerType.isinstance(
43174331
right.type):
4318-
left = self.promoteOperandType(self.getIntegerType(), left)
4319-
right = self.promoteOperandType(self.getIntegerType(), right)
4332+
left = self.changeOperandToType(self.getIntegerType(), left)
4333+
right = self.changeOperandToType(self.getIntegerType(), right)
43204334
self.pushValue(arith.ShRSIOp(left, right).result)
43214335
return
43224336
else:
@@ -4327,8 +4341,8 @@ def visit_BinOp(self, node):
43274341
if isinstance(node.op, ast.BitAnd):
43284342
if IntegerType.isinstance(left.type) and IntegerType.isinstance(
43294343
right.type):
4330-
left = self.promoteOperandType(self.getIntegerType(), left)
4331-
right = self.promoteOperandType(self.getIntegerType(), right)
4344+
left = self.changeOperandToType(self.getIntegerType(), left)
4345+
right = self.changeOperandToType(self.getIntegerType(), right)
43324346
self.pushValue(arith.AndIOp(left, right).result)
43334347
return
43344348
else:
@@ -4339,8 +4353,8 @@ def visit_BinOp(self, node):
43394353
if isinstance(node.op, ast.BitOr):
43404354
if IntegerType.isinstance(left.type) and IntegerType.isinstance(
43414355
right.type):
4342-
left = self.promoteOperandType(self.getIntegerType(), left)
4343-
right = self.promoteOperandType(self.getIntegerType(), right)
4356+
left = self.changeOperandToType(self.getIntegerType(), left)
4357+
right = self.changeOperandToType(self.getIntegerType(), right)
43444358
self.pushValue(arith.OrIOp(left, right).result)
43454359
return
43464360
else:
@@ -4351,8 +4365,8 @@ def visit_BinOp(self, node):
43514365
if isinstance(node.op, ast.BitXor):
43524366
if IntegerType.isinstance(left.type) and IntegerType.isinstance(
43534367
right.type):
4354-
left = self.promoteOperandType(self.getIntegerType(), left)
4355-
right = self.promoteOperandType(self.getIntegerType(), right)
4368+
left = self.changeOperandToType(self.getIntegerType(), left)
4369+
right = self.changeOperandToType(self.getIntegerType(), right)
43564370
self.pushValue(arith.XOrIOp(left, right).result)
43574371
return
43584372
else:

python/tests/kernel/test_cast_kernel.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import cudaq
1010
import numpy as np
11+
import pytest
1112

1213

1314
# bool <-> int32
@@ -168,3 +169,15 @@ def kernelFloat64Float32(f: float) -> np.float32:
168169
return f
169170

170171
assert cudaq.run(kernelFloat64Float32, -2.0, shots_count=1) == [-2.0]
172+
173+
174+
def test_multiplication():
175+
176+
@cudaq.kernel
177+
def mult_check(angle: float) -> float:
178+
M_PI = 3.1415926536
179+
phase = 2 * M_PI * angle
180+
return phase
181+
182+
result = cudaq.run(mult_check, 0.1, shots_count=1)
183+
assert result[0] == pytest.approx(0.6283185307179586)

0 commit comments

Comments
 (0)