@@ -266,6 +266,7 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
266266 self .walkingReturnNode = False
267267 self .controlNegations = []
268268 self .pushPointerValue = False
269+ self .isSubscriptRoot = False
269270 self .verbose = 'verbose' in kwargs and kwargs ['verbose' ]
270271 self .currentNode = None
271272
@@ -866,6 +867,34 @@ def bodyBuilder(iterVar):
866867 vecTy = cc .StdvecType .get (targetEleType ) if not isTargetBool else cc .StdvecType .get (self .getIntegerType (1 ))
867868 return cc .StdvecInitOp (vecTy , targetPtr , length = sourceSize ).result
868869
870+ def __migrateLists (self , value , migrate ):
871+ """
872+ Replaces all lists in the given value by the list returned
873+ by the `migrate` function, including inner lists. Does an
874+ in-place replacement for list elements.
875+ """
876+ if cc .StdvecType .isinstance (value .type ):
877+ eleTy = cc .StdvecType .getElementType (value .type )
878+ if self .containsList (eleTy ):
879+ size = cc .StdvecSizeOp (self .getIntegerType (), value ).result
880+ ptrTy = cc .PointerType .get (cc .ArrayType .get (eleTy ))
881+ iterable = cc .StdvecDataOp (ptrTy , value ).result
882+ def bodyBuilder (iterVar ):
883+ eleAddr = cc .ComputePtrOp (
884+ cc .PointerType .get (eleTy ), iterable , [iterVar ],
885+ DenseI32ArrayAttr .get ([kDynamicPtrIndex ], context = self .ctx ))
886+ loadedEle = cc .LoadOp (eleAddr ).result
887+ element = self .__migrateLists (loadedEle , migrate )
888+ cc .StoreOp (element , eleAddr )
889+ self .createInvariantForLoop (bodyBuilder , size )
890+ return migrate (value )
891+ if (cc .StructType .isinstance (value .type ) and
892+ self .containsList (value .type )):
893+ return self .__copyStructAndConvertElements (
894+ value , conversion = lambda _ , v : self .__migrateLists (v , migrate ))
895+ assert not self .containsList (value .type )
896+ return value
897+
869898 def __insertDbgStmt (self , value , dbgStmt ):
870899 """
871900 Insert a debug print out statement if the programmer requested. Handles
@@ -1339,7 +1368,7 @@ def __validate_container_entry(self, mlirVal, pyVal):
13391368 if cc .StructType .isinstance (mlirVal .type ):
13401369 structName = cc .StructType .getName (mlirVal .type )
13411370 # We need to give a proper error if we try to assign
1342- # a mutable dataclass to an item in a list or a dataclass .
1371+ # a mutable dataclass to an item in another container .
13431372 # Allowing this would lead to incorrect behavior (i.e.
13441373 # inconsistent with Python) unless we change the
13451374 # representation of structs to be like `StdvecType`
@@ -2048,33 +2077,30 @@ def processFunctionCall(kernel):
20482077 if len (kernel .type .results ) == 0 :
20492078 func .CallOp (kernel , values )
20502079 return
2080+ # The logic for calls that return values must
2081+ # match the logic in `visit_Return`; anything
2082+ # copied to the heap during return must be copied
2083+ # back to the stack. Compiler optimizations should
2084+ # take care of eliminating unnecessary copies.
20512085 result = func .CallOp (kernel , values ).result
2052- # Copy to stack if necessary
2053- # FIXME: needs to be updated to be recursive much like the matching return logic
2054- if cc .StdvecType .isinstance (result .type ):
2055- elemTy = cc .StdvecType .getElementType (result .type )
2056- if elemTy == self .getIntegerType (1 ):
2057- elemTy = self .getIntegerType (8 )
2058- data = cc .StdvecDataOp (cc .PointerType .get (elemTy ),
2059- result ).result
2060- i64Ty = self .getIntegerType (64 )
2061- length = cc .StdvecSizeOp (i64Ty , result ).result
2062- elemSize = cc .SizeOfOp (i64Ty , TypeAttr .get (elemTy )).result
2063- buffer = cc .AllocaOp (cc .PointerType .get (
2064- cc .ArrayType .get (elemTy )),
2065- TypeAttr .get (elemTy ),
2066- seqSize = length ).result
2067- i8PtrTy = cc .PointerType .get (self .getIntegerType (8 ))
2068- cbuffer = cc .CastOp (i8PtrTy , buffer ).result
2069- cdata = cc .CastOp (i8PtrTy , data ).result
2086+ def copy_list_to_stack (value ):
20702087 symName = '__nvqpp_vectorCopyToStack'
20712088 load_intrinsic (self .module , symName )
2072- sizeInBytes = arith .MulIOp (length , elemSize ).result
2073- func .CallOp ([], symName , [cbuffer , cdata , sizeInBytes ])
2074- # Replace result with the stack buffer-backed vector
2075- result = cc .StdvecInitOp (result .type , buffer ,
2076- length = length ).result
2077- return result
2089+ elemTy = cc .StdvecType .getElementType (value .type )
2090+ if elemTy == self .getIntegerType (1 ):
2091+ elemTy = self .getIntegerType (8 )
2092+ ptrTy = cc .PointerType .get (self .getIntegerType (8 ))
2093+ resBuf = cc .StdvecDataOp (cc .PointerType .get (elemTy ), value ).result
2094+ eleSize = cc .SizeOfOp (self .getIntegerType (), TypeAttr .get (elemTy )).result
2095+ dynSize = cc .StdvecSizeOp (self .getIntegerType (), value ).result
2096+ stackCopy = cc .AllocaOp (cc .PointerType .get (cc .ArrayType .get (elemTy )),
2097+ TypeAttr .get (elemTy ),
2098+ seqSize = dynSize ).result
2099+ func .CallOp ([], symName , [cc .CastOp (ptrTy , stackCopy ).result ,
2100+ cc .CastOp (ptrTy , resBuf ).result ,
2101+ arith .MulIOp (dynSize , eleSize ).result ])
2102+ return cc .StdvecInitOp (value .type , stackCopy , length = dynSize ).result
2103+ return self .__migrateLists (result , copy_list_to_stack )
20782104
20792105 def checkControlAndTargetTypes (controls , targets ):
20802106 """
@@ -3810,8 +3836,14 @@ def fix_negative_idx(idx, get_size):
38103836 var = self .popValue ()
38113837 self .pushPointerValue = True
38123838 else :
3839+ # `isSubscriptRoot` is only used/needed to enable
3840+ # modification of items in lists and dataclasses
3841+ # contained in a tuple
3842+ subscriptRoot = self .isSubscriptRoot
3843+ self .isSubscriptRoot = True
38133844 self .visit (node .value )
38143845 var = self .popValue ()
3846+ self .isSubscriptRoot = subscriptRoot
38153847
38163848 pushPtr = self .pushPointerValue
38173849 self .pushPointerValue = False
@@ -3838,7 +3870,7 @@ def fix_negative_idx(idx, get_size):
38383870 if cc .PointerType .isinstance (var .type ):
38393871 # We should only ever get a pointer if we
38403872 # explicitly asked for it.
3841- assert self .pushPointerValue == True
3873+ assert self .pushPointerValue
38423874 varType = cc .PointerType .getElementType (var .type )
38433875 if cc .StdvecType .isinstance (varType ):
38443876 # We can get a pointer to a vector (only) if we
@@ -3848,14 +3880,26 @@ def fix_negative_idx(idx, get_size):
38483880 # the vector, since the underlying data is
38493881 # not loaded.
38503882 var = cc .LoadOp (var ).result
3883+
38513884 if cc .StructType .isinstance (varType ):
3885+ structName = cc .StructType .getName (varType )
3886+ if not self .isSubscriptRoot and structName == 'tuple' :
3887+ self .emitFatalError ("tuple value cannot be modified" , node )
3888+ if not isinstance (node .slice , ast .Constant ):
3889+ if self .pushPointerValue :
3890+ if structName == 'tuple' :
3891+ self .emitFatalError ("tuple value cannot be modified via non-constant subscript" , node )
3892+ self .emitFatalError (
3893+ f"{ structName } value cannot be modified via non-constant subscript - use attribute access instead" ,
3894+ node )
3895+
3896+ idxVal = node .slice .value
3897+ structTys = cc .StructType .getTypes (varType )
3898+ eleAddr = cc .ComputePtrOp (cc .PointerType .get (structTys [idxVal ]), var , [],
3899+ DenseI32ArrayAttr .get ([idxVal ])).result
38523900 if self .pushPointerValue :
3853- structName = cc .StructType .getName (varType )
3854- if structName == 'tuple' :
3855- self .emitFatalError ("tuple value cannot be modified" , node )
3856- self .emitFatalError (
3857- f"{ structName } value cannot be modified - use `.copy(deep)` to create a new value that can be modified" ,
3858- node )
3901+ self .pushValue (eleAddr )
3902+ return
38593903
38603904 if cc .StdvecType .isinstance (var .type ):
38613905 idx = fix_negative_idx (idx , lambda : get_size (var ))
@@ -4338,97 +4382,63 @@ def visit_Return(self, node):
43384382 result ,
43394383 allowDemotion = True )
43404384
4341- def getSize (ty ):
4342- fp_width = lambda t : 4 if F32Type .isinstance (t ) else 8
4343- if cc .StructType .isinstance (ty ):
4344- return cc .SizeOfOp (self .getIntegerType (), TypeAttr .get (ty )).result
4345- if ComplexType .isinstance (ty ):
4346- fType = ComplexType (ty ).element_type
4347- return self .getConstantInt (2 * fp_width (fType ))
4348- elif F32Type .isinstance (ty ) or F64Type .isinstance (ty ):
4349- return self .getConstantInt (fp_width (ty ))
4350- elif IntegerType .isinstance (ty ):
4351- width = IntegerType (ty ).width
4352- return self .getConstantInt ((width + 7 ) // 8 )
4353- return self .getConstantInt (8 )
4354-
4385+ # Generally, anything that was allocated locally on the stack
4386+ # needs to be copied to the heap to ensure it lives past
4387+ # the function. This holds recursively; if we have a struct
4388+ # that contains a list, then the list data may need to be
4389+ # copied if it was allocated inside the function.
43554390 def copy_list_to_heap (value ):
43564391 symName = '__nvqpp_vectorCopyCtor'
43574392 load_intrinsic (self .module , symName )
4393+ elemTy = cc .StdvecType .getElementType (value .type )
4394+ if elemTy == self .getIntegerType (1 ):
4395+ elemTy = self .getIntegerType (8 )
43584396 ptrTy = cc .PointerType .get (self .getIntegerType (8 ))
43594397 arrTy = cc .ArrayType .get (self .getIntegerType (8 ))
43604398 ptrArrTy = cc .PointerType .get (arrTy )
43614399 resBuf = cc .StdvecDataOp (ptrArrTy , value ).result
4362- eleSize = getSize ( cc .StdvecType . getElementType ( value . type ))
4400+ eleSize = cc .SizeOfOp ( self . getIntegerType (), TypeAttr . get ( elemTy )). result
43634401 dynSize = cc .StdvecSizeOp (self .getIntegerType (), value ).result
43644402 resBuf = cc .CastOp (ptrTy , resBuf )
43654403 heapCopy = func .CallOp ([ptrTy ], symName ,
43664404 [resBuf , dynSize , eleSize ]).result
43674405 return cc .StdvecInitOp (value .type , heapCopy , length = dynSize ).result
43684406
4369- # Generally, anything that was allocated locally on the stack
4370- # needs to be copied to the heap to ensure it lives past the
4371- # the function. This holds recursively; if we have a struct
4372- # that contains a list, then the list data may need to be
4373- # copied if it was allocated inside the function.
4374- def create_return_value (res ):
4375- if cc .StdvecType .isinstance (res .type ):
4376- eleTy = cc .StdvecType .getElementType (res .type ) # iterty
4377- if self .containsList (eleTy ):
4378- # Need to make sure all inner lists live past
4379- # this function as well.
4380- size = cc .StdvecSizeOp (self .getIntegerType (), res ).result
4381- ptrTy = cc .PointerType .get (cc .ArrayType .get (eleTy ))
4382- iterable = cc .StdvecDataOp (ptrTy , res ).result
4383- def bodyBuilder (iterVar ):
4384- eleAddr = cc .ComputePtrOp (
4385- cc .PointerType .get (eleTy ), iterable , [iterVar ],
4386- DenseI32ArrayAttr .get ([kDynamicPtrIndex ], context = self .ctx ))
4387- loadedEle = cc .LoadOp (eleAddr ).result
4388- element = create_return_value (loadedEle )
4389- cc .StoreOp (element , eleAddr )
4390- self .createInvariantForLoop (bodyBuilder , size )
4391- return res if self .isFunctionArgument (res ) else copy_list_to_heap (res )
4392- if not cc .StructType .isinstance (res .type ):
4393- assert not cc .PointerType .isinstance (res .type )
4394- return res
4395- if rootVal and self .isFunctionArgument (rootVal ):
4396- # see comment below
4397- return res
4398- return self .__copyStructAndConvertElements (
4399- res , conversion = lambda _ , v : create_return_value (v ))
4400-
44014407 rootVal = self .__get_root_value (node .value )
44024408 if rootVal and self .isFunctionArgument (rootVal ):
4403- # If we allow assigning a value that contains a list to an
4404- # item of a function argument (which we do - caveat below),
4405- # then we necessarily need to make a copy when we return
4406- # function arguments, or function argument elements, that
4407- # contain lists, since we have to assume that their data may
4409+ # If we allow assigning a value that contains a list to an item
4410+ # of a function argument (which we do with the exceptions
4411+ # commented below), then we necessarily need to make a copy when
4412+ # we return function arguments, or function argument elements,
4413+ # that contain lists, since we have to assume that their data may
44084414 # be allocated on the stack. However, this leads to incorrect
44094415 # behavior if a returned list was indeed caller-side allocated
44104416 # (and should correspondingly have been returned by reference).
44114417 # Rather than preventing that lists in function arguments can be
44124418 # updated, we instead ensure that lists contained in function
44134419 # arguments stay recognizable as such, and prevent that function
44144420 # arguments that contain list are returned.
4415- # Caveat: We prevent that dataclass or tuple arguments or
4416- # argument items that contain lists are ever assigned to local
4417- # variables. We hence can savely return them and omit any
4418- # copies in this case.
44194421 # NOTE: While it seems straightforward in principle to fail only
44204422 # when we return *inner* lists of function arguments, this
4421- # is still not desirable; even if we return the reference to
4422- # the outer list correctly, any callerside assignment of the
4423- # return value would no longer be recognizable as being the
4424- # same reference given as argument, which is a problem if the
4425- # list was an argument to the caller. I.e. while this works for
4426- # one function indirection, it does not for two.
4427- if not cc .StructType .isinstance (result .type ) and \
4428- self .containsList (result .type ):
4423+ # is still not a good option for two reasons:
4424+ # 1) Even if we return the reference to the outer list correctly,
4425+ # any caller-side assignment of the return value would no longer
4426+ # be recognizable as being the same reference given as argument,
4427+ # which is a problem if the list was an argument to the caller.
4428+ # I.e. while this works for one function indirection, it does
4429+ # not work for two (see assignment tests).
4430+ # 2) To ensure that we don't have any memory leaks, we copy any
4431+ # lists returned from function calls to the stack. This copy (as
4432+ # of the time of writing this) results in a segfault when the
4433+ # list is not on the heap. As it is, we hence indeed have to copy
4434+ # every returned list to the heap, followed by a copy to the stack
4435+ # in the caller. Subsequent optimization passes should largely
4436+ # eliminate unnecessary copies.
4437+ if (self .containsList (result .type )):
44294438 self .emitFatalError ("return value must not contain a list that is a function argument or an item in a function argument - for device kernels, lists passed as arguments will be modified in place" , node )
4439+ else :
4440+ result = self .__migrateLists (result , copy_list_to_heap )
44304441
4431- result = create_return_value (result )
44324442 if self .symbolTable .numLevels () > 1 :
44334443 # We are in an inner scope, release all scopes before returning
44344444 cc .UnwindReturnOp ([result ])
0 commit comments