Skip to content

Commit 02e5fcb

Browse files
committed
last clean up from merge
Signed-off-by: Bettina Heim <[email protected]>
1 parent 5a82186 commit 02e5fcb

File tree

2 files changed

+258
-114
lines changed

2 files changed

+258
-114
lines changed

python/cudaq/kernel/ast_bridge.py

Lines changed: 107 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
266266
self.walkingReturnNode = False
267267
self.controlNegations = []
268268
self.pushPointerValue = False
269+
self.isSubscriptRoot = False
269270
self.verbose = 'verbose' in kwargs and kwargs['verbose']
270271
self.currentNode = None
271272

@@ -866,6 +867,34 @@ def bodyBuilder(iterVar):
866867
vecTy = cc.StdvecType.get(targetEleType) if not isTargetBool else cc.StdvecType.get(self.getIntegerType(1))
867868
return cc.StdvecInitOp(vecTy, targetPtr, length=sourceSize).result
868869

870+
def __migrateLists(self, value, migrate):
871+
"""
872+
Replaces all lists in the given value by the list returned
873+
by the `migrate` function, including inner lists. Does an
874+
in-place replacement for list elements.
875+
"""
876+
if cc.StdvecType.isinstance(value.type):
877+
eleTy = cc.StdvecType.getElementType(value.type)
878+
if self.containsList(eleTy):
879+
size = cc.StdvecSizeOp(self.getIntegerType(), value).result
880+
ptrTy = cc.PointerType.get(cc.ArrayType.get(eleTy))
881+
iterable = cc.StdvecDataOp(ptrTy, value).result
882+
def bodyBuilder(iterVar):
883+
eleAddr = cc.ComputePtrOp(
884+
cc.PointerType.get(eleTy), iterable, [iterVar],
885+
DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx))
886+
loadedEle = cc.LoadOp(eleAddr).result
887+
element = self.__migrateLists(loadedEle, migrate)
888+
cc.StoreOp(element, eleAddr)
889+
self.createInvariantForLoop(bodyBuilder, size)
890+
return migrate(value)
891+
if (cc.StructType.isinstance(value.type) and
892+
self.containsList(value.type)):
893+
return self.__copyStructAndConvertElements(
894+
value, conversion = lambda _, v: self.__migrateLists(v, migrate))
895+
assert not self.containsList(value.type)
896+
return value
897+
869898
def __insertDbgStmt(self, value, dbgStmt):
870899
"""
871900
Insert a debug print out statement if the programmer requested. Handles
@@ -1339,7 +1368,7 @@ def __validate_container_entry(self, mlirVal, pyVal):
13391368
if cc.StructType.isinstance(mlirVal.type):
13401369
structName = cc.StructType.getName(mlirVal.type)
13411370
# We need to give a proper error if we try to assign
1342-
# a mutable dataclass to an item in a list or a dataclass.
1371+
# a mutable dataclass to an item in another container.
13431372
# Allowing this would lead to incorrect behavior (i.e.
13441373
# inconsistent with Python) unless we change the
13451374
# representation of structs to be like `StdvecType`
@@ -2048,33 +2077,30 @@ def processFunctionCall(kernel):
20482077
if len(kernel.type.results) == 0:
20492078
func.CallOp(kernel, values)
20502079
return
2080+
# The logic for calls that return values must
2081+
# match the logic in `visit_Return`; anything
2082+
# copied to the heap during return must be copied
2083+
# back to the stack. Compiler optimizations should
2084+
# take care of eliminating unnecessary copies.
20512085
result = func.CallOp(kernel, values).result
2052-
# Copy to stack if necessary
2053-
# FIXME: needs to be updated to be recursive much like the matching return logic
2054-
if cc.StdvecType.isinstance(result.type):
2055-
elemTy = cc.StdvecType.getElementType(result.type)
2056-
if elemTy == self.getIntegerType(1):
2057-
elemTy = self.getIntegerType(8)
2058-
data = cc.StdvecDataOp(cc.PointerType.get(elemTy),
2059-
result).result
2060-
i64Ty = self.getIntegerType(64)
2061-
length = cc.StdvecSizeOp(i64Ty, result).result
2062-
elemSize = cc.SizeOfOp(i64Ty, TypeAttr.get(elemTy)).result
2063-
buffer = cc.AllocaOp(cc.PointerType.get(
2064-
cc.ArrayType.get(elemTy)),
2065-
TypeAttr.get(elemTy),
2066-
seqSize=length).result
2067-
i8PtrTy = cc.PointerType.get(self.getIntegerType(8))
2068-
cbuffer = cc.CastOp(i8PtrTy, buffer).result
2069-
cdata = cc.CastOp(i8PtrTy, data).result
2086+
def copy_list_to_stack(value):
20702087
symName = '__nvqpp_vectorCopyToStack'
20712088
load_intrinsic(self.module, symName)
2072-
sizeInBytes = arith.MulIOp(length, elemSize).result
2073-
func.CallOp([], symName, [cbuffer, cdata, sizeInBytes])
2074-
# Replace result with the stack buffer-backed vector
2075-
result = cc.StdvecInitOp(result.type, buffer,
2076-
length=length).result
2077-
return result
2089+
elemTy = cc.StdvecType.getElementType(value.type)
2090+
if elemTy == self.getIntegerType(1):
2091+
elemTy = self.getIntegerType(8)
2092+
ptrTy = cc.PointerType.get(self.getIntegerType(8))
2093+
resBuf = cc.StdvecDataOp(cc.PointerType.get(elemTy), value).result
2094+
eleSize = cc.SizeOfOp(self.getIntegerType(), TypeAttr.get(elemTy)).result
2095+
dynSize = cc.StdvecSizeOp(self.getIntegerType(), value).result
2096+
stackCopy = cc.AllocaOp(cc.PointerType.get(cc.ArrayType.get(elemTy)),
2097+
TypeAttr.get(elemTy),
2098+
seqSize=dynSize).result
2099+
func.CallOp([], symName, [cc.CastOp(ptrTy, stackCopy).result,
2100+
cc.CastOp(ptrTy, resBuf).result,
2101+
arith.MulIOp(dynSize, eleSize).result])
2102+
return cc.StdvecInitOp(value.type, stackCopy, length=dynSize).result
2103+
return self.__migrateLists(result, copy_list_to_stack)
20782104

20792105
def checkControlAndTargetTypes(controls, targets):
20802106
"""
@@ -3810,8 +3836,14 @@ def fix_negative_idx(idx, get_size):
38103836
var = self.popValue()
38113837
self.pushPointerValue = True
38123838
else:
3839+
# `isSubscriptRoot` is only used/needed to enable
3840+
# modification of items in lists and dataclasses
3841+
# contained in a tuple
3842+
subscriptRoot = self.isSubscriptRoot
3843+
self.isSubscriptRoot = True
38133844
self.visit(node.value)
38143845
var = self.popValue()
3846+
self.isSubscriptRoot = subscriptRoot
38153847

38163848
pushPtr = self.pushPointerValue
38173849
self.pushPointerValue = False
@@ -3838,7 +3870,7 @@ def fix_negative_idx(idx, get_size):
38383870
if cc.PointerType.isinstance(var.type):
38393871
# We should only ever get a pointer if we
38403872
# explicitly asked for it.
3841-
assert self.pushPointerValue == True
3873+
assert self.pushPointerValue
38423874
varType = cc.PointerType.getElementType(var.type)
38433875
if cc.StdvecType.isinstance(varType):
38443876
# We can get a pointer to a vector (only) if we
@@ -3848,14 +3880,26 @@ def fix_negative_idx(idx, get_size):
38483880
# the vector, since the underlying data is
38493881
# not loaded.
38503882
var = cc.LoadOp(var).result
3883+
38513884
if cc.StructType.isinstance(varType):
3885+
structName = cc.StructType.getName(varType)
3886+
if not self.isSubscriptRoot and structName == 'tuple':
3887+
self.emitFatalError("tuple value cannot be modified", node)
3888+
if not isinstance(node.slice, ast.Constant):
3889+
if self.pushPointerValue:
3890+
if structName == 'tuple':
3891+
self.emitFatalError("tuple value cannot be modified via non-constant subscript", node)
3892+
self.emitFatalError(
3893+
f"{structName} value cannot be modified via non-constant subscript - use attribute access instead",
3894+
node)
3895+
3896+
idxVal = node.slice.value
3897+
structTys = cc.StructType.getTypes(varType)
3898+
eleAddr = cc.ComputePtrOp(cc.PointerType.get(structTys[idxVal]), var, [],
3899+
DenseI32ArrayAttr.get([idxVal])).result
38523900
if self.pushPointerValue:
3853-
structName = cc.StructType.getName(varType)
3854-
if structName == 'tuple':
3855-
self.emitFatalError("tuple value cannot be modified", node)
3856-
self.emitFatalError(
3857-
f"{structName} value cannot be modified - use `.copy(deep)` to create a new value that can be modified",
3858-
node)
3901+
self.pushValue(eleAddr)
3902+
return
38593903

38603904
if cc.StdvecType.isinstance(var.type):
38613905
idx = fix_negative_idx(idx, lambda: get_size(var))
@@ -4338,97 +4382,63 @@ def visit_Return(self, node):
43384382
result,
43394383
allowDemotion=True)
43404384

4341-
def getSize(ty):
4342-
fp_width = lambda t: 4 if F32Type.isinstance(t) else 8
4343-
if cc.StructType.isinstance(ty):
4344-
return cc.SizeOfOp(self.getIntegerType(), TypeAttr.get(ty)).result
4345-
if ComplexType.isinstance(ty):
4346-
fType = ComplexType(ty).element_type
4347-
return self.getConstantInt(2 * fp_width(fType))
4348-
elif F32Type.isinstance(ty) or F64Type.isinstance(ty):
4349-
return self.getConstantInt(fp_width(ty))
4350-
elif IntegerType.isinstance(ty):
4351-
width = IntegerType(ty).width
4352-
return self.getConstantInt((width + 7) // 8)
4353-
return self.getConstantInt(8)
4354-
4385+
# Generally, anything that was allocated locally on the stack
4386+
# needs to be copied to the heap to ensure it lives past the
4387+
# the function. This holds recursively; if we have a struct
4388+
# that contains a list, then the list data may need to be
4389+
# copied if it was allocated inside the function.
43554390
def copy_list_to_heap(value):
43564391
symName = '__nvqpp_vectorCopyCtor'
43574392
load_intrinsic(self.module, symName)
4393+
elemTy = cc.StdvecType.getElementType(value.type)
4394+
if elemTy == self.getIntegerType(1):
4395+
elemTy = self.getIntegerType(8)
43584396
ptrTy = cc.PointerType.get(self.getIntegerType(8))
43594397
arrTy = cc.ArrayType.get(self.getIntegerType(8))
43604398
ptrArrTy = cc.PointerType.get(arrTy)
43614399
resBuf = cc.StdvecDataOp(ptrArrTy, value).result
4362-
eleSize = getSize(cc.StdvecType.getElementType(value.type))
4400+
eleSize = cc.SizeOfOp(self.getIntegerType(), TypeAttr.get(elemTy)).result
43634401
dynSize = cc.StdvecSizeOp(self.getIntegerType(), value).result
43644402
resBuf = cc.CastOp(ptrTy, resBuf)
43654403
heapCopy = func.CallOp([ptrTy], symName,
43664404
[resBuf, dynSize, eleSize]).result
43674405
return cc.StdvecInitOp(value.type, heapCopy, length=dynSize).result
43684406

4369-
# Generally, anything that was allocated locally on the stack
4370-
# needs to be copied to the heap to ensure it lives past the
4371-
# the function. This holds recursively; if we have a struct
4372-
# that contains a list, then the list data may need to be
4373-
# copied if it was allocated inside the function.
4374-
def create_return_value(res):
4375-
if cc.StdvecType.isinstance(res.type):
4376-
eleTy = cc.StdvecType.getElementType(res.type) # iterty
4377-
if self.containsList(eleTy):
4378-
# Need to make sure all inner lists live past
4379-
# this function as well.
4380-
size = cc.StdvecSizeOp(self.getIntegerType(), res).result
4381-
ptrTy = cc.PointerType.get(cc.ArrayType.get(eleTy))
4382-
iterable = cc.StdvecDataOp(ptrTy, res).result
4383-
def bodyBuilder(iterVar):
4384-
eleAddr = cc.ComputePtrOp(
4385-
cc.PointerType.get(eleTy), iterable, [iterVar],
4386-
DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx))
4387-
loadedEle = cc.LoadOp(eleAddr).result
4388-
element = create_return_value(loadedEle)
4389-
cc.StoreOp(element, eleAddr)
4390-
self.createInvariantForLoop(bodyBuilder, size)
4391-
return res if self.isFunctionArgument(res) else copy_list_to_heap(res)
4392-
if not cc.StructType.isinstance(res.type):
4393-
assert not cc.PointerType.isinstance(res.type)
4394-
return res
4395-
if rootVal and self.isFunctionArgument(rootVal):
4396-
# see comment below
4397-
return res
4398-
return self.__copyStructAndConvertElements(
4399-
res, conversion = lambda _, v: create_return_value(v))
4400-
44014407
rootVal = self.__get_root_value(node.value)
44024408
if rootVal and self.isFunctionArgument(rootVal):
4403-
# If we allow assigning a value that contains a list to an
4404-
# item of a function argument (which we do - caveat below),
4405-
# then we necessarily need to make a copy when we return
4406-
# function arguments, or function argument elements, that
4407-
# contain lists, since we have to assume that their data may
4409+
# If we allow assigning a value that contains a list to an item
4410+
# of a function argument (which we do with the exceptions
4411+
# commented below), then we necessarily need to make a copy when
4412+
# we return function arguments, or function argument elements,
4413+
# that contain lists, since we have to assume that their data may
44084414
# be allocated on the stack. However, this leads to incorrect
44094415
# behavior if a returned list was indeed caller-side allocated
44104416
# (and should correspondingly have been returned by reference).
44114417
# Rather than preventing that lists in function arguments can be
44124418
# updated, we instead ensure that lists contained in function
44134419
# arguments stay recognizable as such, and prevent that function
44144420
# arguments that contain list are returned.
4415-
# Caveat: We prevent that dataclass or tuple arguments or
4416-
# argument items that contain lists are ever assigned to local
4417-
# variables. We hence can savely return them and omit any
4418-
# copies in this case.
44194421
# NOTE: Why is seems straightforward in principle to fail only
44204422
# for when we return *inner* lists of function arguments, this
4421-
# is still not desirable; even if we return the reference to
4422-
# the outer list correctly, any callerside assignment of the
4423-
# return value would no longer be recognizable as being the
4424-
# same reference given as argument, which is a problem if the
4425-
# list was an argument to the caller. I.e. while this works for
4426-
# one function indirection, it does not for two.
4427-
if not cc.StructType.isinstance(result.type) and \
4428-
self.containsList(result.type):
4423+
# is still not a good option for two reasons:
4424+
# 1) Even if we return the reference to the outer list correctly,
4425+
# any caller-side assignment of the return value would no longer
4426+
# be recognizable as being the same reference given as argument,
4427+
# which is a problem if the list was an argument to the caller.
4428+
# I.e. while this works for one function indirection, it does
4429+
# not work for two (see assignment tests).
4430+
# 2) To ensure that we don't have any memory leaks, we copy any
4431+
# lists returned from function calls to the stack. This copy (as
4432+
# of the time of writing this) results in a segfault when the
4433+
# list is not on the heap. As it is, we hence indeed have to copy
4434+
# every returned list to the heap, followed by a copy to the stack
4435+
# in the caller. Subsequent optimization passes should largely
4436+
# eliminate unnecessary copies.
4437+
if (self.containsList(result.type)):
44294438
self.emitFatalError("return value must not contain a list that is a function argument or an item in a function argument - for device kernels, lists passed as arguments will be modified in place", node)
4439+
else:
4440+
result = self.__migrateLists(result, copy_list_to_heap)
44304441

4431-
result = create_return_value(result)
44324442
if self.symbolTable.numLevels() > 1:
44334443
# We are in an inner scope, release all scopes before returning
44344444
cc.UnwindReturnOp([result])

0 commit comments

Comments
 (0)