Skip to content

feat: index-range err code #19513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
freshen_all_functions_type_vars,
freshen_function_type_vars,
)
from mypy.exprlength import get_static_expr_length
from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments
from mypy.literals import literal
from mypy.maptype import map_instance_to_supertype
Expand Down Expand Up @@ -4462,6 +4463,7 @@ def visit_index_with_type(
# Allow special forms to be indexed and used to create union types
return self.named_type("typing._SpecialForm")
else:
self.static_index_range_check(left_type, e, index)
result, method_type = self.check_method_call_by_name(
"__getitem__",
left_type,
Expand All @@ -4484,6 +4486,46 @@ def min_tuple_length(self, left: TupleType) -> int:
return left.length() - 1 + unpack.type.min_len
return left.length() - 1

def static_index_range_check(self, left_type: Type, e: IndexExpr, index: Expression) -> None:
if isinstance(left_type, Instance) and left_type.type.fullname in (
"builtins.list",
"builtins.tuple",
"builtins.str",
"builtins.bytes",
):
idx_val = None
# Try to extract integer literal index
if isinstance(index, IntExpr):
idx_val = index.value
elif isinstance(index, UnaryExpr):
if index.op == "-":
operand = index.expr
if isinstance(operand, IntExpr):
idx_val = -operand.value
elif index.op == "+":
operand = index.expr
if isinstance(operand, IntExpr):
idx_val = operand.value
# Could add more cases (e.g. LiteralType) if desired
if idx_val is not None:
length = get_static_expr_length(e.base)
if length is not None:
# For negative indices, Python counts from the end
check_idx = idx_val
if check_idx < 0:
check_idx += length
if not (0 <= check_idx < length):
name = ""
if isinstance(e.base, NameExpr):
name = e.base.name
self.chk.fail(
message_registry.SEQUENCE_INDEX_OUT_OF_RANGE.format(
name=name or "<expr>", length=length
),
e,
code=message_registry.SEQUENCE_INDEX_OUT_OF_RANGE.code,
)

def visit_tuple_index_helper(self, left: TupleType, n: int) -> Type | None:
unpack_index = find_unpack_in_list(left.items)
if unpack_index is None:
Expand Down
4 changes: 4 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,5 +326,9 @@ def __hash__(self) -> int:
default_enabled=False,
)

INDEX_RANGE: Final[ErrorCode] = ErrorCode(
"index-range", "index out of statically known range", "Index Range", True
)

# This copy will not include any error codes defined later in the plugins.
mypy_error_codes = error_codes.copy()
174 changes: 174 additions & 0 deletions mypy/exprlength.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Static expression length analysis utilities for mypy.

Provides helpers for statically determining the length of expressions,
when possible.
"""

from typing import List, Optional, Tuple

from mypy.nodes import (
ARG_POS,
AssignmentStmt,
Block,
BytesExpr,
CallExpr,
ClassDef,
DictExpr,
Expression,
ExpressionStmt,
ForStmt,
FuncDef,
GeneratorExpr,
GlobalDecl,
IfStmt,
ListComprehension,
ListExpr,
MemberExpr,
NameExpr,
NonlocalDecl,
OverloadedFuncDef,
SetExpr,
StarExpr,
StrExpr,
TryStmt,
TupleExpr,
WhileStmt,
WithStmt,
is_IntExpr_list,
)


def get_static_expr_length(expr: Expression, context: Optional[Block] = None) -> Optional[int]:
"""Try to statically determine the length of an expression.

Returns the length if it can be determined at type-check time,
otherwise returns None.

If context is provided, will attempt to resolve NameExpr/Var assignments.
"""
# NOTE: currently only used for indexing but could be extended to flag
# fun things like list.pop or to allow len([1, 2, 3]) to type check as Literal[3]

# List, tuple literals (with possible star expressions)
if isinstance(expr, (ListExpr, TupleExpr)):
stars = [get_static_expr_length(i, context) for i in expr.items if isinstance(i, StarExpr)]
if None not in stars:
# if there are no star expressions, or we know the
# length of them, we know the length of the expression
other = sum(not isinstance(i, StarExpr) for i in expr.items)
return other + sum(star for star in stars if star is not None)
elif isinstance(expr, SetExpr):
# TODO: set expressions are more complicated, you need to know the
# actual value of each item in order to confidently state its length
pass
elif isinstance(expr, DictExpr):
# TODO: same as with sets, dicts are more complicated since you need
# to know the specific value of each key, and ensure they don't collide
pass
# String or bytes literal
elif isinstance(expr, (StrExpr, BytesExpr)):
return len(expr.value)
elif isinstance(expr, ListComprehension):
# If the generator's length is known, the list's length is known
return get_static_expr_length(expr.generator, context)
elif isinstance(expr, GeneratorExpr):
# If there is only one sequence and no conditions, and we know
# the sequence length, we know the max number of items yielded
# from the genexp and can pass that info forward
if len(expr.sequences) == 1 and len(expr.condlists) == 0:
return get_static_expr_length(expr.sequences[0], context)
# range() with constant arguments
elif isinstance(expr, CallExpr):
callee = expr.callee
if isinstance(callee, NameExpr) and callee.fullname == "builtins.range":
args = expr.args
if is_IntExpr_list(args) and all(kind == ARG_POS for kind in expr.arg_kinds):
if len(args) == 1:
# range(stop)
stop = args[0].value
return max(0, stop)
elif len(args) == 2:
# range(start, stop)
start, stop = args[0].value, args[1].value
return max(0, stop - start)
elif len(args) == 3:
# range(start, stop, step)
start, stop, step = args[0].value, args[1].value, args[2].value
if step == 0:
return None
n = (stop - start + (step - (1 if step > 0 else -1))) // step
return max(0, n)
# We have a big spaghetti monster of special case logic to resolve name expressions
elif isinstance(expr, NameExpr):
# Try to resolve the value of a local variable if possible
if context is None:
# Cannot resolve without context
return None
assignments: List[Tuple[AssignmentStmt, int]] = []

# Iterate thru all statements in the block
for stmt in context.body:
if isinstance(
stmt,
(
IfStmt,
ForStmt,
WhileStmt,
TryStmt,
WithStmt,
FuncDef,
OverloadedFuncDef,
ClassDef,
),
):
# These statements complicate things and render the whole block useless
return None
elif isinstance(stmt, (GlobalDecl, NonlocalDecl)) and expr.name in stmt.names:
# We cannot assure the value of a global or nonlocal
return None
elif stmt.line >= expr.line:
# We can stop our analysis at the line where the name is used
break
# Check for any assignments
elif isinstance(stmt, AssignmentStmt):
# First, exit if any assignment has a rhs expression that
# could mutate the name
# TODO Write logic to recursively unwrap statements to see
# if any internal statements mess with our var

# Iterate thru lvalues in the assignment
for idx, lval in enumerate(stmt.lvalues):
# Check if any of them matches our variable
if isinstance(lval, NameExpr) and lval.name == expr.name:
assignments.append((stmt, idx))
elif isinstance(stmt, ExpressionStmt):
if isinstance(stmt.expr, CallExpr):
callee = stmt.expr.callee
for arg in stmt.expr.args:
if isinstance(arg, NameExpr) and arg.name == expr.name:
# our var was passed to a function as an input,
# it could be mutated now
return None
if (
isinstance(callee, MemberExpr)
and isinstance(callee.expr, NameExpr)
and callee.expr.name == expr.name
):
return None

# For now, we only attempt to resolve the length
# when the name was only ever assigned to once
if len(assignments) != 1:
return None
stmt, idx = assignments[0]
rvalue = stmt.rvalue
# If single lvalue, just use rvalue
if len(stmt.lvalues) == 1:
return get_static_expr_length(rvalue, context)
# If multiple lvalues, try to extract the corresponding value
elif isinstance(rvalue, (TupleExpr, ListExpr)):
if len(rvalue.items) == len(stmt.lvalues):
return get_static_expr_length(rvalue.items[idx], context)
# Otherwise, cannot determine
# Could add more cases (e.g. dicts, sets) in the future
return None
4 changes: 4 additions & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
TYPE_ALIAS_WITH_AWAIT_EXPRESSION: Final = ErrorMessage(
"Await expression cannot be used within a type alias", codes.SYNTAX
)

SEQUENCE_INDEX_OUT_OF_RANGE = ErrorMessage(
"Sequence index out of range: {name!r} only has {length} items", code=codes.INDEX_RANGE
)
6 changes: 5 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import defaultdict
from collections.abc import Iterator, Sequence
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Callable, Final, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Final, List, Optional, TypeVar, Union, cast
from typing_extensions import TypeAlias as _TypeAlias, TypeGuard

from mypy_extensions import trait
Expand Down Expand Up @@ -1761,6 +1761,10 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_int_expr(self)


def is_IntExpr_list(items: List[Expression]) -> TypeGuard[List[IntExpr]]:
return all(isinstance(item, IntExpr) for item in items)


# How mypy uses StrExpr and BytesExpr:
#
# b'x' -> BytesExpr
Expand Down
12 changes: 6 additions & 6 deletions mypyc/test-data/run-exceptions.test
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ def g(b: bool) -> None:
try:
if b:
x = [0]
x[1]
x[1] # type: ignore [index-range]
else:
raise Exception('hi')
except:
print("caught!")

def r(x: int) -> None:
if x == 0:
[0][1]
[0][1] # type: ignore [index-range]
elif x == 1:
raise Exception('hi')
elif x == 2:
Expand Down Expand Up @@ -263,7 +263,7 @@ Traceback (most recent call last):
File "native.py", line 44, in i
r(0)
File "native.py", line 15, in r
[0][1]
[0][1] # type: ignore [index-range]
IndexError: list index out of range
== k ==
Traceback (most recent call last):
Expand All @@ -281,7 +281,7 @@ Traceback (most recent call last):
File "native.py", line 61, in k
r(0)
File "native.py", line 15, in r
[0][1]
[0][1] # type: ignore [index-range]
IndexError: list index out of range
== g ==
caught!
Expand Down Expand Up @@ -330,7 +330,7 @@ Traceback (most recent call last):
File "native.py", line 61, in k
r(0)
File "native.py", line 15, in r
[0][1]
[0][1] # type: ignore [index-range]
IndexError: list index out of range
== g ==
caught!
Expand Down Expand Up @@ -371,7 +371,7 @@ def b(b1: int, b2: int) -> str:
if b1 == 1:
raise Exception('hi')
elif b1 == 2:
[0][1]
[0][1] # type: ignore [index-range]
elif b1 == 3:
return 'try'
except IndexError:
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/run-misc.test
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def f(a: bool, b: bool) -> None:

def g() -> None:
try:
[0][1]
[0][1] # type: ignore [index-range]
y = 1
except Exception:
pass
Expand Down
Loading