diff --git a/src/h2/events.py b/src/h2/events.py
index b81fd1a63..7a22f152c 100644
--- a/src/h2/events.py
+++ b/src/h2/events.py
@@ -16,7 +16,7 @@
from .settings import ChangedSetting, SettingCodes, Settings, _setting_code_from_int
if TYPE_CHECKING: # pragma: no cover
- from hpack import HeaderTuple
+ from hpack.struct import Header
from hyperframe.frame import Frame
from .errors import ErrorCodes
@@ -52,7 +52,7 @@ def __init__(self) -> None:
self.stream_id: int | None = None
#: The request headers.
- self.headers: list[HeaderTuple] | None = None
+ self.headers: list[Header] | None = None
#: If this request also ended the stream, the associated
#: :class:`StreamEnded
` event will be available
@@ -91,7 +91,7 @@ def __init__(self) -> None:
self.stream_id: int | None = None
#: The response headers.
- self.headers: list[HeaderTuple] | None = None
+ self.headers: list[Header] | None = None
#: If this response also ended the stream, the associated
#: :class:`StreamEnded ` event will be available
@@ -133,7 +133,7 @@ def __init__(self) -> None:
self.stream_id: int | None = None
#: The trailers themselves.
- self.headers: list[HeaderTuple] | None = None
+ self.headers: list[Header] | None = None
#: Trailers always end streams. This property has the associated
#: :class:`StreamEnded ` in it.
@@ -237,7 +237,7 @@ def __init__(self) -> None:
self.stream_id: int | None = None
#: The headers for this informational response.
- self.headers: list[HeaderTuple] | None = None
+ self.headers: list[Header] | None = None
#: If this response also had associated priority information, the
#: associated :class:`PriorityUpdated `
@@ -436,7 +436,7 @@ def __init__(self) -> None:
#: The error code given. Either one of :class:`ErrorCodes
#: ` or ``int``
- self.error_code: ErrorCodes | None = None
+ self.error_code: ErrorCodes | int | None = None
#: Whether the remote peer sent a RST_STREAM or we did.
self.remote_reset = True
@@ -460,7 +460,7 @@ def __init__(self) -> None:
self.parent_stream_id: int | None = None
#: The request headers, sent by the remote party in the push.
- self.headers: list[HeaderTuple] | None = None
+ self.headers: list[Header] | None = None
def __repr__(self) -> str:
return (
diff --git a/src/h2/stream.py b/src/h2/stream.py
index 7d4a12e35..3f6c97cd1 100644
--- a/src/h2/stream.py
+++ b/src/h2/stream.py
@@ -7,7 +7,7 @@
from __future__ import annotations
from enum import Enum, IntEnum
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union, cast
from hpack import HeaderTuple
from hyperframe.frame import AltSvcFrame, ContinuationFrame, DataFrame, Frame, HeadersFrame, PushPromiseFrame, RstStreamFrame, WindowUpdateFrame
@@ -46,7 +46,7 @@
from .windows import WindowManager
if TYPE_CHECKING: # pragma: no cover
- from collections.abc import Generator, Iterable
+ from collections.abc import Callable, Generator, Iterable
from hpack.hpack import Encoder
from hpack.struct import Header, HeaderWeaklyTyped
@@ -131,7 +131,7 @@ def __init__(self, stream_id: int) -> None:
# How the stream was closed. One of StreamClosedBy.
self.stream_closed_by: StreamClosedBy | None = None
- def process_input(self, input_: StreamInputs) -> Any:
+ def process_input(self, input_: StreamInputs) -> list[Event]:
"""
Process a specific input in the state machine.
"""
@@ -315,21 +315,23 @@ def recv_push_promise(self, previous_state: StreamState) -> list[Event]:
event.parent_stream_id = self.stream_id
return [event]
- def send_end_stream(self, previous_state: StreamState) -> None:
+ def send_end_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an attempt is made to send END_STREAM in the
HALF_CLOSED_REMOTE state.
"""
self.stream_closed_by = StreamClosedBy.SEND_END_STREAM
+ return []
- def send_reset_stream(self, previous_state: StreamState) -> None:
+ def send_reset_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an attempt is made to send RST_STREAM in a non-closed
stream state.
"""
self.stream_closed_by = StreamClosedBy.SEND_RST_STREAM
+ return []
- def reset_stream_on_error(self, previous_state: StreamState) -> None:
+ def reset_stream_on_error(self, previous_state: StreamState) -> list[Event]:
"""
Called when we need to forcefully emit another RST_STREAM frame on
behalf of the state machine.
@@ -350,7 +352,7 @@ def reset_stream_on_error(self, previous_state: StreamState) -> None:
error._events = [event]
raise error
- def recv_on_closed_stream(self, previous_state: StreamState) -> None:
+ def recv_on_closed_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an unexpected frame is received on an already-closed
stream.
@@ -362,7 +364,7 @@ def recv_on_closed_stream(self, previous_state: StreamState) -> None:
"""
raise StreamClosedError(self.stream_id)
- def send_on_closed_stream(self, previous_state: StreamState) -> None:
+ def send_on_closed_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an attempt is made to send data on an already-closed
stream.
@@ -374,7 +376,7 @@ def send_on_closed_stream(self, previous_state: StreamState) -> None:
"""
raise StreamClosedError(self.stream_id)
- def recv_push_on_closed_stream(self, previous_state: StreamState) -> None:
+ def recv_push_on_closed_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when a PUSH_PROMISE frame is received on a full stop
stream.
@@ -393,7 +395,7 @@ def recv_push_on_closed_stream(self, previous_state: StreamState) -> None:
msg = "Attempted to push on closed stream."
raise ProtocolError(msg)
- def send_push_on_closed_stream(self, previous_state: StreamState) -> None:
+ def send_push_on_closed_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an attempt is made to push on an already-closed stream.
@@ -473,7 +475,7 @@ def recv_alt_svc(self, previous_state: StreamState) -> list[Event]:
# the event and let it get populated.
return [AlternativeServiceAvailable()]
- def send_alt_svc(self, previous_state: StreamState) -> None:
+ def send_alt_svc(self, previous_state: StreamState) -> list[Event]:
"""
Called when sending an ALTSVC frame on this stream.
@@ -489,6 +491,7 @@ def send_alt_svc(self, previous_state: StreamState) -> None:
if self.headers_sent:
msg = "Cannot send ALTSVC after sending response headers."
raise ProtocolError(msg)
+ return []
@@ -561,7 +564,10 @@ def send_alt_svc(self, previous_state: StreamState) -> None:
# (state, input) to tuples of (side_effect_function, end_state). This
# map contains all allowed transitions: anything not in this map is
# invalid and immediately causes a transition to ``closed``.
-_transitions = {
+_transitions: dict[
+ tuple[StreamState, StreamInputs],
+ tuple[Callable[[H2StreamStateMachine, StreamState], list[Event]] | None, StreamState],
+] = {
# State: idle
(StreamState.IDLE, StreamInputs.SEND_HEADERS):
(H2StreamStateMachine.request_sent, StreamState.OPEN),
@@ -1040,10 +1046,11 @@ def receive_push_promise_in_band(self,
events = self.state_machine.process_input(
StreamInputs.RECV_PUSH_PROMISE,
)
- events[0].pushed_stream_id = promised_stream_id
+ push_event = cast(PushedStreamReceived, events[0])
+ push_event.pushed_stream_id = promised_stream_id
hdr_validation_flags = self._build_hdr_validation_flags(events)
- events[0].headers = self._process_received_headers(
+ push_event.headers = self._process_received_headers(
headers, hdr_validation_flags, header_encoding,
)
return [], events
@@ -1077,22 +1084,30 @@ def receive_headers(self,
input_ = StreamInputs.RECV_HEADERS
events = self.state_machine.process_input(input_)
+ headers_event = cast(
+ Union[RequestReceived, ResponseReceived, TrailersReceived, InformationalResponseReceived],
+ events[0],
+ )
if end_stream:
es_events = self.state_machine.process_input(
StreamInputs.RECV_END_STREAM,
)
- events[0].stream_ended = es_events[0]
+ # We ensured it's not an information response at the beginning of the method.
+ cast(
+ Union[RequestReceived, ResponseReceived, TrailersReceived],
+ headers_event,
+ ).stream_ended = cast(StreamEnded, es_events[0])
events += es_events
self._initialize_content_length(headers)
- if isinstance(events[0], TrailersReceived) and not end_stream:
+ if isinstance(headers_event, TrailersReceived) and not end_stream:
msg = "Trailers must have END_STREAM set"
raise ProtocolError(msg)
hdr_validation_flags = self._build_hdr_validation_flags(events)
- events[0].headers = self._process_received_headers(
+ headers_event.headers = self._process_received_headers(
headers, hdr_validation_flags, header_encoding,
)
return [], events
@@ -1106,6 +1121,7 @@ def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) ->
"set to %d", self, end_stream, flow_control_len,
)
events = self.state_machine.process_input(StreamInputs.RECV_DATA)
+ data_event = cast(DataReceived, events[0])
self._inbound_window_manager.window_consumed(flow_control_len)
self._track_content_length(len(data), end_stream)
@@ -1113,11 +1129,11 @@ def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) ->
es_events = self.state_machine.process_input(
StreamInputs.RECV_END_STREAM,
)
- events[0].stream_ended = es_events[0]
+ data_event.stream_ended = cast(StreamEnded, es_events[0])
events.extend(es_events)
- events[0].data = data
- events[0].flow_controlled_length = flow_control_len
+ data_event.data = data
+ data_event.flow_controlled_length = flow_control_len
return [], events
def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event]]:
@@ -1137,7 +1153,7 @@ def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event
# this should be treated as a *stream* error, not a *connection* error.
# That means we need to catch the error and forcibly close the stream.
if events:
- events[0].delta = increment
+ cast(WindowUpdated, events[0]).delta = increment
try:
self.outbound_flow_control_window = guard_increment_window(
self.outbound_flow_control_window,
@@ -1220,7 +1236,7 @@ def stream_reset(self, frame: RstStreamFrame) -> tuple[list[Frame], list[Event]]
if events:
# We don't fire an event if this stream is already closed.
- events[0].error_code = _error_code_from_int(frame.error_code)
+ cast(StreamReset, events[0]).error_code = _error_code_from_int(frame.error_code)
return [], events
@@ -1322,7 +1338,7 @@ def _build_headers_frames(self,
def _process_received_headers(self,
headers: Iterable[Header],
header_validation_flags: HeaderValidationFlags,
- header_encoding: bool | str | None) -> Iterable[Header]:
+ header_encoding: bool | str | None) -> list[Header]:
"""
When headers have been received from the remote peer, run a processing
pipeline on them to transform them into the appropriate form for