diff --git a/pyproject.toml b/pyproject.toml index c1a15bf60..fbdffff92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ testing = [ linting = [ "ruff>=0.8.0,<1", - "mypy>=1.13.0,<2", + "mypy>=1.16.0,<2", "typing_extensions>=4.12.2", ] @@ -149,6 +149,7 @@ testpaths = [ "tests" ] [tool.coverage.run] branch = true source = [ "h2" ] +omit = [ "*/h2/_typing.py" ] [tool.coverage.report] fail_under = 100 @@ -190,7 +191,7 @@ commands = [ dependency_groups = ["linting"] commands = [ ["ruff", "check", "src/"], - ["mypy", "src/"], + ["mypy", "--strict-bytes", "src/", "tests/typing/strict_bytes.py"], ] [tool.tox.env.docs] diff --git a/src/h2/_typing.py b/src/h2/_typing.py new file mode 100644 index 000000000..64592fd73 --- /dev/null +++ b/src/h2/_typing.py @@ -0,0 +1,21 @@ +""" +h2/_typing +~~~~~~~~~~ + +Shared typing helpers. +""" +from __future__ import annotations + +from typing import Protocol + + +class Buffer(Protocol): + """ + An object implementing the PEP 688 buffer protocol. + """ + + def __buffer__(self, flags: int, /) -> memoryview: + """ + Return a memoryview over this object's bytes. + """ + ... diff --git a/src/h2/connection.py b/src/h2/connection.py index 6c5122024..1de3e9fb5 100644 --- a/src/h2/connection.py +++ b/src/h2/connection.py @@ -70,6 +70,8 @@ from hpack.struct import Header, HeaderWeaklyTyped + from ._typing import Buffer + class ConnectionState(Enum): IDLE = 0 @@ -1496,12 +1498,12 @@ def _inbound_flow_control_change_from_settings(self, old_value: int | None, new_ for stream in self.streams.values(): stream._inbound_flow_control_change_from_settings(delta) - def receive_data(self, data: bytes) -> list[Event]: + def receive_data(self, data: Buffer) -> list[Event]: """ Pass some received HTTP/2 data to the connection for handling. :param data: The data received from the remote peer on the network. - :type data: ``bytes`` + :type data: An object implementing the buffer protocol. :returns: A list of events that the remote peer triggered by sending this data. """ diff --git a/src/h2/frame_buffer.py b/src/h2/frame_buffer.py index 2555cdce5..c2f8a5411 100644 --- a/src/h2/frame_buffer.py +++ b/src/h2/frame_buffer.py @@ -7,11 +7,16 @@ """ from __future__ import annotations +from typing import TYPE_CHECKING + from hyperframe.exceptions import InvalidDataError, InvalidFrameError from hyperframe.frame import ContinuationFrame, Frame, HeadersFrame, PushPromiseFrame from .exceptions import FrameDataMissingError, FrameTooLargeError, ProtocolError +if TYPE_CHECKING: # pragma: no cover + from ._typing import Buffer + # To avoid a DOS attack based on sending loads of continuation frames, we limit # the maximum number we're prepared to receive. In this case, we'll set the # limit to 64, which means the largest encoded header block we can receive by @@ -36,25 +41,27 @@ def __init__(self, server: bool = False) -> None: self._preamble_len = len(self._preamble) self._headers_buffer: list[HeadersFrame | ContinuationFrame | PushPromiseFrame] = [] - def add_data(self, data: bytes) -> None: + def add_data(self, data: Buffer) -> None: """ Add more data to the frame buffer. :param data: A bytestring containing the byte buffer. """ + data_view = memoryview(data) + if self._preamble_len: - data_len = len(data) + data_len = len(data_view) of_which_preamble = min(self._preamble_len, data_len) - if self._preamble[:of_which_preamble] != data[:of_which_preamble]: + if self._preamble[:of_which_preamble] != data_view[:of_which_preamble]: msg = "Invalid HTTP/2 preamble." raise ProtocolError(msg) - data = data[of_which_preamble:] + data_view = data_view[of_which_preamble:] self._preamble_len -= of_which_preamble self._preamble = self._preamble[of_which_preamble:] - self._data += data + self._data += data_view def _validate_frame_length(self, length: int) -> None: """ diff --git a/tests/typing/strict_bytes.py b/tests/typing/strict_bytes.py new file mode 100644 index 000000000..9bfc35f00 --- /dev/null +++ b/tests/typing/strict_bytes.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from h2.connection import H2Connection +from h2.frame_buffer import FrameBuffer + + +def receive_data_accepts_buffer_types( + connection: H2Connection, + frame_buffer: FrameBuffer, +) -> None: + bytearray_data = bytearray(b"") + memoryview_data = memoryview(b"") + + connection.receive_data(bytearray_data) + connection.receive_data(memoryview_data) + frame_buffer.add_data(bytearray_data) + frame_buffer.add_data(memoryview_data)