Skip to content

Rust: Add predicate for certain type information #20155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 150 additions & 14 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,13 @@ private module M2 = Make2<Input2>;

private import M2

module Consistency = M2::Consistency;
module Consistency {
import M2::Consistency

query predicate nonUniqueCertainType(AstNode n, TypePath path) {
strictcount(CertainTypeInference::inferCertainType(n, path)) > 1
}
}

/** Gets the type annotation that applies to `n`, if any. */
private TypeMention getTypeAnnotation(AstNode n) {
Expand Down Expand Up @@ -249,6 +255,134 @@ private Type inferAnnotatedType(AstNode n, TypePath path) {
result = getTypeAnnotation(n).resolveTypeAt(path)
}

/** Module for inferring certain type information. */
private module CertainTypeInference {
/** Holds if the type mention does not contain any inferred types `_`. */
predicate typeMentionIsComplete(TypeMention tm) {
not exists(InferTypeRepr t | t.getParentNode*() = tm)
}

/**
* Holds if `ce` is a call where we can infer the type with certainty and if
* `f` is the target of the call and `p` the path invoked by the call.
*
* Necessary conditions for this are:
* - We are certain of the call target (i.e., the call target can not depend on type information).
* - The declared type of the function does not contain any generics that we
* need to infer.
* - The call does not contain any arguments, as arguments in calls are coercion sites.
*
* The current requirements are made to allow for call to `new` functions such
* as `Vec<Foo>::new()` but not much more.
*/
predicate certainCallExprTarget(CallExpr ce, Function f, Path p) {
p = CallExprImpl::getFunctionPath(ce) and
f = resolvePath(p) and
// The function is not in a trait
not any(TraitItemNode t).getAnAssocItem() = f and
// The function is not in a trait implementation
not any(ImplItemNode impl | impl.(Impl).hasTrait()).getAnAssocItem() = f and
// The function does not have parameters.
not f.getParamList().hasSelfParam() and
f.getParamList().getNumberOfParams() = 0 and
// The function is not async.
not f.isAsync() and
// For now, exclude functions in macro expansions.
not ce.isInMacroExpansion() and
// The function has no type parameters.
not f.hasGenericParamList() and
// The function does not have `impl` types among its parameters (these are type parameters).
not any(ImplTraitTypeRepr itt | not itt.isInReturnPos()).getFunction() = f and
(
not exists(ImplItemNode impl | impl.getAnAssocItem() = f)
or
// If the function is in an impl then the impl block has no type
// parameters or all the type parameters are given explicitly.
exists(ImplItemNode impl | impl.getAnAssocItem() = f |
not impl.(Impl).hasGenericParamList() or
impl.(Impl).getGenericParamList().getNumberOfGenericParams() =
p.getQualifier().getSegment().getGenericArgList().getNumberOfGenericArgs()
)
)
}

private ImplItemNode getFunctionImpl(FunctionItemNode f) { result.getAnAssocItem() = f }

Type inferCertainCallExprType(CallExpr ce, TypePath path) {
exists(Function f, Type ty, TypePath prefix, Path p |
certainCallExprTarget(ce, f, p) and
ty = f.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(prefix)
|
if ty.(TypeParamTypeParameter).getTypeParam() = getFunctionImpl(f).getTypeParam(_)
then
exists(TypePath pathToTp, TypePath suffix |
// For type parameters of the `impl` block we must resolve their
// instantiation from the path. For instance, for `impl<A> for Foo<A>`
// and the path `Foo<i64>::bar` we must resolve `A` to `i64`.
ty = getFunctionImpl(f).(Impl).getSelfTy().(TypeMention).resolveTypeAt(pathToTp) and
result = p.getQualifier().(TypeMention).resolveTypeAt(pathToTp.appendInverse(suffix)) and
path = prefix.append(suffix)
)
else (
result = ty and path = prefix
)
)
}

predicate certainTypeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
prefix1.isEmpty() and
prefix2.isEmpty() and
(
exists(Variable v | n1 = v.getAnAccess() |
n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam)
)
or
// A `let` statement with a type annotation is a coercion site and hence
// is not a certain type equality.
exists(LetStmt let | not let.hasTypeRepr() |
let.getPat() = n1 and
let.getInitializer() = n2
)
)
or
n1 =
any(IdentPat ip |
n2 = ip.getName() and
prefix1.isEmpty() and
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
)
}

pragma[nomagic]
private Type inferCertainTypeEquality(AstNode n, TypePath path) {
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
result = inferCertainType(n2, prefix2.appendInverse(suffix)) and
path = prefix1.append(suffix)
|
certainTypeEquality(n, prefix1, n2, prefix2)
or
certainTypeEquality(n2, prefix2, n, prefix1)
)
}

/**
* Holds if `n` has complete and certain type information and if `n` has the
* resulting type at `path`.
*/
pragma[nomagic]
Type inferCertainType(AstNode n, TypePath path) {
exists(TypeMention tm |
tm = getTypeAnnotation(n) and
typeMentionIsComplete(tm) and
result = tm.resolveTypeAt(path)
)
or
result = inferCertainCallExprType(n, path)
or
result = inferCertainTypeEquality(n, path)
}
}

private Type inferLogicalOperationType(AstNode n, TypePath path) {
exists(Builtins::BuiltinType t, BinaryLogicalOperation be |
n = [be, be.getLhs(), be.getRhs()] and
Expand Down Expand Up @@ -288,15 +422,11 @@ private Struct getRangeType(RangeExpr re) {
* through the type equality.
*/
private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2)
or
prefix1.isEmpty() and
prefix2.isEmpty() and
(
exists(Variable v | n1 = v.getAnAccess() |
n2 = v.getPat().getName()
or
n2 = v.getParameter().(SelfParam)
)
or
exists(LetStmt let |
let.getPat() = n1 and
let.getInitializer() = n2
Expand Down Expand Up @@ -339,13 +469,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
)
or
n1 =
any(IdentPat ip |
n2 = ip.getName() and
prefix1.isEmpty() and
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
)
or
(
n1 = n2.(RefExpr).getExpr() or
n1 = n2.(RefPat).getPat()
Expand Down Expand Up @@ -408,6 +531,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat

pragma[nomagic]
private Type inferTypeEquality(AstNode n, TypePath path) {
// Don't propagate type information into a node for which we already have
// certain type information.
not exists(CertainTypeInference::inferCertainType(n, _)) and
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
result = inferType(n2, prefix2.appendInverse(suffix)) and
path = prefix1.append(suffix)
Expand Down Expand Up @@ -818,6 +944,8 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}

final class Access extends Call {
Access() { not CertainTypeInference::certainCallExprTarget(this, _, _) }

pragma[nomagic]
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
Expand Down Expand Up @@ -2150,6 +2278,8 @@ private module Cached {
cached
Type inferType(AstNode n, TypePath path) {
Stages::TypeInferenceStage::ref() and
result = CertainTypeInference::inferCertainType(n, path)
or
result = inferAnnotatedType(n, path)
or
result = inferLogicalOperationType(n, path)
Expand Down Expand Up @@ -2305,4 +2435,10 @@ private module Debug {
c = countTypePaths(n, path, t) and
c = max(countTypePaths(_, _, _))
}

Type debugInferCertainNonUniqueType(AstNode n, TypePath path) {
n = getRelevantLocatable() and
Consistency::nonUniqueCertainType(n, path) and
result = CertainTypeInference::inferCertainType(n, path)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
* Provides classes for recognizing type inference inconsistencies.
*/

private import rust
private import Type
private import TypeMention
private import TypeInference
private import TypeInference::Consistency as Consistency
import TypeInference::Consistency

Expand All @@ -27,4 +29,7 @@ int getTypeInferenceInconsistencyCounts(string type) {
or
type = "Ill-formed type mention" and
result = count(TypeMention tm | illFormedTypeMention(tm) | tm)
or
type = "Non-unique certain type information" and
result = count(AstNode n, TypePath path | nonUniqueCertainType(n, path) | n)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
nonUniqueCertainType
| web_frameworks.rs:139:30:139:39 | ...::get(...) | |
| web_frameworks.rs:140:34:140:43 | ...::get(...) | |
| web_frameworks.rs:141:30:141:39 | ...::get(...) | |
36 changes: 36 additions & 0 deletions rust/ql/test/library-tests/type-inference/dereference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,46 @@ fn implicit_dereference() {
let _y = x.is_positive(); // $ MISSING: target=is_positive type=_y:bool
}

mod implicit_deref_coercion_cycle {
use std::collections::HashMap;

#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
pub struct Key {}

// This example can trigger a cycle in type inference due to an implicit
// dereference if we are not careful and accurate enough.
//
// To explain how a cycle might happen, we let `[V]` denote the type of the
// type parameter `V` of `key_to_key` (i.e., the type of the values in the
// map) and `[key]` denote the type of `key`.
//
// 1. From the first two lines we infer `[V] = &Key` and `[key] = &Key`
// 2. At the 3. line we infer the type of `ref_key` to be `&[V]`.
// 3. At the 4. line we impose the equality `[key] = &[V]`, not accounting
// for the implicit deref caused by a coercion.
// 4. At the last line we infer `[key] = [V]`.
//
// Putting the above together we have `[V] = [key] = &[V]` which is a cycle.
// This means that `[key]` is both `&Key`, `&&Key`, `&&&Key`, and so on ad
// infinitum.

#[rustfmt::skip]
pub fn test() {
let mut key_to_key = HashMap::<&Key, &Key>::new(); // $ target=new
let mut key = &Key {}; // Initialize key2 to a reference
if let Some(ref_key) = key_to_key.get(key) { // $ target=get
// Below `ref_key` is implicitly dereferenced from `&&Key` to `&Key`
key = ref_key;
}
key_to_key.insert(key, key); // $ target=insert
}
}

pub fn test() {
explicit_monomorphic_dereference(); // $ target=explicit_monomorphic_dereference
explicit_polymorphic_dereference(); // $ target=explicit_polymorphic_dereference
explicit_ref_dereference(); // $ target=explicit_ref_dereference
explicit_box_dereference(); // $ target=explicit_box_dereference
implicit_dereference(); // $ target=implicit_dereference
implicit_deref_coercion_cycle::test(); // $ target=test
}
2 changes: 1 addition & 1 deletion rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2352,7 +2352,7 @@ mod loops {
#[rustfmt::skip]
let _ = while a < 10 // $ target=lt type=a:i64
{
a += 1; // $ type=a:i64 target=add_assign
a += 1; // $ type=a:i64 MISSING: target=add_assign
};
}
}
Expand Down
Loading