Skip to content

Commit 56950d6

Browse files
[swiftsrc2cpg] Proper closure call handling (#5685)
1 parent 01643e0 commit 56950d6

File tree

11 files changed

+408
-59
lines changed

11 files changed

+408
-59
lines changed

joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import org.apache.commons.lang3.StringUtils
99
object AstCreatorHelper {
1010

1111
private val TagsToKeepInFullName = List("<anonymous>", "<lambda>", "<global>", "<type>", "<extension>", "<wildcard>")
12+
private val ReturnTypeMatcher = """^\(.*\)(->|:)(.+)$""".r
13+
private val ClosureSignatureMatcher = """^(\(.*\))\s*(.*)\s*->(.+)$""".r
1214

1315
/** Removes generic type parameters from qualified names while preserving special tags.
1416
*
@@ -84,9 +86,15 @@ object AstCreatorHelper {
8486
case "Dictionary" => Defines.Dictionary
8587
case "Nil" => Defines.Nil
8688
// Special patterns with specific handling
87-
case t if t.startsWith("[") && t.endsWith("]") => Defines.Array
88-
case t if t.contains("=>") || t.contains("->") => Defines.Function
89-
case t if t.contains("( ") => t.substring(0, t.indexOf("( "))
89+
case t if t.startsWith("[") && t.endsWith("]") => Defines.Array
90+
case ClosureSignatureMatcher(params, mods, returnType) =>
91+
// "throws" is the only modifier that swiftc keeps
92+
// so we have to restore it here to keep signatures
93+
// consistent between runs with compiler support and without.
94+
val m = if (mods.contains("throws")) { "throws" }
95+
else ""
96+
s"${Defines.Function}<$params$m->$returnType>".replace(" ", "")
97+
case t if t.contains("( ") => t.substring(0, t.indexOf("( "))
9098
// Default case
9199
case typeStr => typeStr
92100
}
@@ -162,7 +170,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
162170
case None if identNode.typeFullName != Defines.Any => identNode.typeFullName
163171
case _ => Defines.Any
164172
}
165-
val typedIdentNode = identNode.typeFullName(tpe)
173+
identNode.typeFullName = tpe
166174
scope.addVariableReference(identifierName, identNode, tpe, EvaluationStrategies.BY_REFERENCE)
167175
Ast(identNode)
168176
}
@@ -191,8 +199,6 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
191199
}
192200
}
193201

194-
private val ReturnTypeMatcher = """\(.*\)(->|:)(.+)""".r
195-
196202
protected def methodInfoForFunctionDeclLike(node: FunctionDeclLike): MethodInfo = {
197203
val name = calcMethodName(node)
198204
fullnameProvider.declFullname(node) match {
@@ -226,9 +232,16 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
226232
val returnType = cleanType(code(s.returnClause.`type`))
227233
(s"${paramSignature(s.parameterClause)}->$returnType", returnType)
228234
case c: ClosureExprSyntax =>
229-
val returnType = c.signature.flatMap(_.returnClause).fold(Defines.Any)(r => cleanType(code(r.`type`)))
230-
val paramClauseCode = c.signature.flatMap(_.parameterClause).fold("()")(paramSignature)
231-
(s"$paramClauseCode->$returnType", returnType)
235+
fullnameProvider.typeFullnameRaw(node) match {
236+
case Some(tpe) =>
237+
val signature = tpe
238+
val returnType = ReturnTypeMatcher.findFirstMatchIn(signature).map(_.group(2)).getOrElse(Defines.Any)
239+
(signature, returnType)
240+
case _ =>
241+
val returnType = c.signature.flatMap(_.returnClause).fold(Defines.Any)(r => cleanType(code(r.`type`)))
242+
val paramClauseCode = c.signature.flatMap(_.parameterClause).fold("()")(paramSignature)
243+
(s"$paramClauseCode->$returnType", returnType)
244+
}
232245
}
233246
registerType(returnType)
234247
MethodInfo(methodName, methodFullName, signature, returnType)

joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForDeclSyntaxCreator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
705705
List.empty[Ast]
706706
}
707707

708-
val methodReturnNode_ = methodReturnNode(node, returnType, Some(returnType))
708+
val methodReturnNode_ = methodReturnNode(node, returnType)
709709

710710
val blockAst_ = blockAst(block, methodBlockContent ++ bodyStmtAsts)
711711
val astForMethod =

joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
210210
baseNode: NewNode,
211211
callName: String
212212
): Ast = {
213-
214213
val trailingClosureAsts = callExpr.trailingClosure.toList.map(astForNode)
215214
val additionalTrailingClosuresAsts = callExpr.additionalTrailingClosures.children.map(c => astForNode(c.closure))
216215

@@ -289,6 +288,8 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
289288
val thisTmpNode = identifierNode(callee, tmpVarName)
290289
(fieldAccessAst, thisTmpNode, memberCode)
291290
}
291+
case other if isRefToClosure(node, other) =>
292+
return astForClosureCall(node)
292293
case _ =>
293294
val receiverAst = astForNode(callee)
294295
val thisNode = identifierNode(callee, "this")
@@ -299,6 +300,48 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
299300
}
300301
}
301302

303+
private def astForClosureCall(expr: FunctionCallExprSyntax): Ast = {
304+
val tpe = fullnameProvider.typeFullname(expr).getOrElse(Defines.Any)
305+
registerType(tpe)
306+
val signature = fullnameProvider.typeFullnameRaw(expr.calledExpression).getOrElse(x2cpg.Defines.UnresolvedSignature)
307+
val callName = Defines.ClosureApplyMethodName
308+
val callMethodFullname = s"${Defines.Function}<$signature>.$callName:$signature"
309+
val baseAst = astForNode(expr.calledExpression)
310+
311+
val trailingClosureAsts = expr.trailingClosure.toList.map(astForNode)
312+
val additionalTrailingClosuresAsts = expr.additionalTrailingClosures.children.map(c => astForNode(c.closure))
313+
314+
val args = expr.arguments.children.map(astForNode) ++ trailingClosureAsts ++ additionalTrailingClosuresAsts
315+
316+
val callExprCode = code(expr)
317+
val callNode_ = callNode(
318+
expr,
319+
callExprCode,
320+
callName,
321+
callMethodFullname,
322+
DispatchTypes.DYNAMIC_DISPATCH,
323+
Option(signature),
324+
Option(tpe)
325+
)
326+
callAst(callNode_, args, Option(baseAst))
327+
}
328+
329+
private def isRefToClosure(func: FunctionCallExprSyntax, node: ExprSyntax): Boolean = {
330+
if (!config.swiftBuild) {
331+
// Early exit; without types from the compiler we will be unable to identify closure calls anyway.
332+
// This saves us the typeFullname lookup below.
333+
return false
334+
}
335+
node match {
336+
case refExpr: DeclReferenceExprSyntax
337+
if refExpr.baseName.isInstanceOf[identifier] && refExpr.argumentNames.isEmpty &&
338+
fullnameProvider.declFullname(func).isEmpty &&
339+
fullnameProvider.typeFullname(refExpr).exists(_.startsWith(s"${Defines.Function}<")) =>
340+
true
341+
case _ => false
342+
}
343+
}
344+
302345
private def astForGenericSpecializationExprSyntax(node: GenericSpecializationExprSyntax): Ast = {
303346
astForNode(node.expression)
304347
}

joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,22 +207,30 @@ trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode) { this: AstC
207207

208208
protected def createFunctionTypeAndTypeDecl(node: SwiftNode, methodNode: NewMethod): Unit = {
209209
registerType(methodNode.fullName)
210+
211+
val (inherits, bindingName) = if (node.isInstanceOf[ClosureExprSyntax]) {
212+
val inheritsFunctionFullName = s"${Defines.Function}<${methodNode.signature}>"
213+
registerType(inheritsFunctionFullName)
214+
(Seq(inheritsFunctionFullName), Defines.ClosureApplyMethodName)
215+
} else (Seq.empty, methodNode.name)
216+
210217
val (astParentType, astParentFullName) = astParentInfo()
211218
val methodTypeDeclNode = typeDeclNode(
212219
node,
213220
methodNode.name,
214221
methodNode.fullName,
215222
methodNode.filename,
216223
methodNode.fullName,
217-
astParentType,
218-
astParentFullName
224+
astParentType = astParentType,
225+
astParentFullName = astParentFullName,
226+
inherits = inherits
219227
)
220228

221229
methodNode.astParentFullName = astParentFullName
222230
methodNode.astParentType = astParentType
223231

224232
val functionBinding = NewBinding()
225-
.name(methodNode.name)
233+
.name(bindingName)
226234
.methodFullName(methodNode.fullName)
227235
.signature(methodNode.signature)
228236

joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/passes/SwiftTypeNodePass.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ object SwiftTypeNodePass {
1515
override def fullToShortName(typeName: String): String = {
1616
typeName match {
1717
case name if name.endsWith(NamespaceTraversal.globalNamespaceName) => NamespaceTraversal.globalNamespaceName
18-
case _ => typeName.split('.').lastOption.getOrElse(typeName)
18+
case _ => typeName.takeWhile(_ != '<').split('.').lastOption.getOrElse(typeName)
1919
}
2020
}
2121

joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/utils/FullnameProvider.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ private object FullnameProvider {
2020
}
2121

2222
// TODO: provide the actual mapping from SwiftNode.toString (nodeKind) to ResolvedTypeInfo.nodeKind
23-
private val NodeKindMapping = Map("DeclReferenceExprSyntax" -> "type_expr")
23+
private val NodeKindMapping = Map(
24+
"DeclReferenceExprSyntax" -> "type_expr",
25+
"VariableDeclSyntax" -> "var_decl",
26+
"PatternBindingSyntax" -> "var_decl",
27+
"IdentifierPatternSyntax" -> "var_decl"
28+
)
2429
}
2530

2631
/** Provides functionality to resolve and retrieve fullnames for Swift types and declarations. Uses a type mapping to
@@ -92,6 +97,12 @@ class FullnameProvider(typeMap: SwiftFileLocalTypeMapping) {
9297
* An optional String containing the type fullname if found
9398
*/
9499
protected def typeFullname(range: (Int, Int), nodeKind: String): Option[String] = {
100+
fullName(range, FullnameProvider.Kind.Type, nodeKind).map(AstCreatorHelper.cleanType)
101+
}
102+
103+
/** Same as FullnameProvider.typeFullname but does no type name sanitation.
104+
*/
105+
protected def typeFullnameRaw(range: (Int, Int), nodeKind: String): Option[String] = {
95106
fullName(range, FullnameProvider.Kind.Type, nodeKind).map(AstCreatorHelper.cleanName)
96107
}
97108

@@ -124,6 +135,16 @@ class FullnameProvider(typeMap: SwiftFileLocalTypeMapping) {
124135
}
125136
}
126137

138+
/** Same as FullnameProvider.typeFullname but does no type name sanitation.
139+
*/
140+
def typeFullnameRaw(node: SwiftNode): Option[String] = {
141+
if (typeMap.isEmpty) return None
142+
(node.startOffset, node.endOffset) match {
143+
case (Some(start), Some(end)) => typeFullnameRaw((start, end), node.toString)
144+
case _ => None
145+
}
146+
}
147+
127148
/** Retrieves the declaration fullname for a given Swift node. Extracts the start and end offsets from the node if
128149
* available. Returns None if typeMap is empty.
129150
*

joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AsyncTests.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ class AsyncTests extends AstSwiftSrc2CpgSuite {
2929
val cpg = code("func asyncGlobal3(fn: () throws -> Int) async rethrows { }")
3030
val List(asyncGlobal3) = cpg.method.internal.nameNot(NamespaceTraversal.globalNamespaceName).l
3131
asyncGlobal3.name shouldBe "asyncGlobal3"
32-
asyncGlobal3.fullName shouldBe s"Test0.swift:<global>.asyncGlobal3:(fn:${Defines.Function})->ANY"
32+
asyncGlobal3.fullName shouldBe s"Test0.swift:<global>.asyncGlobal3:(fn:Swift.Function<()throws->Int>)->ANY"
3333
val List(fn) = asyncGlobal3.parameter.l
3434
fn.name shouldBe "fn"
35-
fn.typeFullName shouldBe Defines.Function
35+
fn.typeFullName shouldBe "Swift.Function<()throws->Int>"
3636
}
3737

3838
"testAsync4" in {

0 commit comments

Comments
 (0)