diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala index 8f3e98743388..3feea533f8bc 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala @@ -32,15 +32,22 @@ class TypeInferencePass(cpg: Cpg) extends ForkJoinParallelCpgPass[Call](cpg) { } private def isMatchingMethod(method: Method, call: Call, callNameParts: NameParts): Boolean = { + if (call.name == "createScript") { + println(s"\n Checking method: ${method.fullName}") + } // An erroneous `this` argument is added for unresolved calls to static methods. val argSizeMod = if (method.modifier.modifierType.iterator.contains(ModifierTypes.STATIC)) 1 else 0 lazy val methodNameParts = getNameParts(method.name, method.fullName) val parameterSizesMatch = (method.parameter.size == (call.argument.size - argSizeMod)) - + if (call.name == "createScript") { + println(s" parameterSizesMatch: $parameterSizesMatch (${method.parameter.size} == ${call.argument.size} - $argSizeMod)") + } lazy val argTypesMatch = doArgumentTypesMatch(method: Method, call: Call, skipCallThis = argSizeMod == 1) - + if (call.name == "createScript" && parameterSizesMatch) { + println(s" argTypesMatch: $argTypesMatch") + } lazy val typeDeclMatches = (callNameParts.typeDecl == methodNameParts.typeDecl) parameterSizesMatch && argTypesMatch && typeDeclMatches @@ -52,18 +59,76 @@ class TypeInferencePass(cpg: Cpg) extends ForkJoinParallelCpgPass[Call](cpg) { */ private def doArgumentTypesMatch(method: Method, call: Call, skipCallThis: Boolean): Boolean = { val callArgs = if (skipCallThis) call.argument.toList.tail else call.argument.toList - - val hasDifferingArg = method.parameter.zip(callArgs).exists { case (parameter, argument) => +// if (call.name == "createScript") { +// println(s"\n === doArgumentTypesMatch ===") +// println(s" skipCallThis: $skipCallThis") +// println(s" call.argument.size: ${call.argument.size}") +// println(s" callArgs.size after skip: ${callArgs.size}") +// println(s" method.parameter.size: ${method.parameter.size}") +// } + val hasDifferingArg = method.parameter.zip(callArgs).zipWithIndex.exists { case ((parameter, argument), idx) => val maybeArgumentType = argument.propertyOption(Properties.TypeFullName).getOrElse(TypeConstants.Any) + val argMatches = maybeArgumentType == TypeConstants.Any || maybeArgumentType == parameter.typeFullName || (maybeArgumentType == TypeConstants.NULL && !isPrimitiveType(parameter.typeFullName)) || isSubtypeOf(maybeArgumentType, parameter.typeFullName) +// if (call.name == "createScript") { +// println(s" [$idx] Param: ${parameter.name} (${parameter.typeFullName}) vs Arg: ${argument.code} (${maybeArgumentType})") +// println(s" Match: $argMatches") +// } val argMatches = maybeArgumentType == TypeConstants.Any || maybeArgumentType == parameter.typeFullName || (maybeArgumentType == TypeConstants.Null && !PrimitiveTypes .contains(parameter.typeFullName)) !argMatches } - + if (call.name == "createScript") { + println(s" hasDifferingArg: $hasDifferingArg") + println(s" === End doArgumentTypesMatch ===\n") + } !hasDifferingArg } + private def isPrimitiveType(typeName: String): Boolean = { + Set("byte", "short", "int", "long", "float", "double", "boolean", "char").contains(typeName) + } + +// private def isSubtypeOf(argumentType: String, parameterType: String): Boolean = { +// if (argumentType == parameterType) { +// return true +// } +// +// cpg.typeDecl.fullNameExact(argumentType).headOption match { +// case Some(typeDecl) => +// typeDecl.inheritsFromTypeFullName.contains(parameterType) +// case None => +// false +// } +// } + private def isSubtypeOf( + argumentType: String, + parameterType: String, + visited: Set[String] = Set.empty + ): Boolean = { + if (argumentType == parameterType) { + return true + } + + if (visited.contains(argumentType)) { + return false + } + + cpg.typeDecl.fullNameExact(argumentType).headOption match { + case Some(typeDecl) => + val parents = typeDecl.inheritsFromTypeFullName.l + + if (parents.contains(parameterType)) { + return true + } + + val newVisited = visited + argumentType + parents.exists(parent => isSubtypeOf(parent, parameterType, newVisited)) + + case None => + false + } + } private def getNameParts(name: String, fullName: String): NameParts = { val Array(qualifiedName, signature) = fullName.split(":", 2)