diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index df3652ce0..5991715a4 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -1,9 +1,10 @@ """`Connection` — per-client connection state and the standalone outbound channel. Always present on `Context` (never ``None``), even in stateless deployments. -Holds peer info populated at ``initialize`` time, the per-connection lifespan -output, and an `Outbound` for the standalone stream (the SSE GET stream in -streamable HTTP, or the single duplex stream in stdio). +Holds peer info populated at ``initialize`` time, per-connection scratch +``state`` and an ``exit_stack`` for teardown, and an `Outbound` for the +standalone stream (the SSE GET stream in streamable HTTP, or the single duplex +stream in stdio). `notify` is best-effort: it never raises. If there's no standalone channel (stateless HTTP) or the stream has been dropped, the notification is @@ -14,6 +15,7 @@ import logging from collections.abc import Mapping +from contextlib import AsyncExitStack from typing import Any import anyio @@ -44,17 +46,27 @@ class Connection(TypedServerRequestMixin): ``None`` until ``initialize`` completes; ``initialized`` is set then. """ - def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None: + def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None: self._outbound = outbound self.has_standalone_channel = has_standalone_channel + self.session_id: str | None = session_id self.client_info: Implementation | None = None self.client_capabilities: ClientCapabilities | None = None self.protocol_version: str | None = None self.initialized: anyio.Event = anyio.Event() - # TODO: make this generic (Connection[StateT]) once connection_lifespan - # wiring lands in ServerRunner. - self.state: Any = None + + self.state: dict[str, Any] = {} + """Per-connection scratch state. Handlers and middleware may read and + write freely; persists across requests on this connection.""" + + self.exit_stack: AsyncExitStack = AsyncExitStack() + """Cleanup stack unwound by `ServerRunner` when the connection closes. + + Push context managers (``await exit_stack.enter_async_context(...)``) + or callbacks (``exit_stack.push_async_callback(...)``) from handlers or + middleware to register per-connection teardown. Unwound LIFO after + `dispatcher.run()` returns, shielded from cancellation.""" async def send_raw_request( self, diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 1c855ae48..d1514a9ad 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass from typing import Any, Generic, Protocol @@ -33,10 +33,9 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex LifespanT = TypeVar("LifespanT", default=Any, covariant=True) -TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) -class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]): +class Context(BaseContext[TransportContext], PeerMixin, TypedServerRequestMixin, Generic[LifespanT]): """Server-side per-request context. Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`), @@ -50,7 +49,7 @@ class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Gener def __init__( self, - dctx: DispatchContext[TransportT], + dctx: DispatchContext[TransportContext], *, lifespan: LifespanT, connection: Connection, @@ -70,6 +69,23 @@ def connection(self) -> Connection: """The per-client `Connection` for this request's connection.""" return self._connection + @property + def session_id(self) -> str | None: + """The transport's session id for this connection, when one exists. + + Convenience for ``ctx.connection.session_id``. ``None`` on stdio and + stateless HTTP. + """ + return self._connection.session_id + + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them. + + Convenience for ``ctx.transport.headers``. ``None`` on stdio. + """ + return self.transport.headers + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: """Send a request-scoped ``notifications/message`` log entry. @@ -94,7 +110,7 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * _MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) -class ContextMiddleware(Protocol[_MwLifespanT]): +class ServerMiddleware(Protocol[_MwLifespanT]): """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. Runs *inside* `ServerRunner._on_request` after params validation and @@ -102,15 +118,15 @@ class ContextMiddleware(Protocol[_MwLifespanT]): not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed outermost-first on `Server.middleware`. - `Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific - middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific - types) can be typed `ContextMiddleware[object]` — `Context` is covariant in - `LifespanT`, so it registers on any `Server[L]`. + `Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan: L`. A reusable middleware can be typed + `ServerMiddleware[object]` — `Context` is covariant in `LifespanT`, so it + registers on any `Server[L]`. """ async def __call__( self, - ctx: Context[_MwLifespanT, TransportContext], + ctx: Context[_MwLifespanT], method: str, params: BaseModel, call_next: CallNext, diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 9dc44708f..fcb77500b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -41,11 +41,13 @@ async def main(): import warnings from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from dataclasses import dataclass from importlib.metadata import version as importlib_version from typing import Any, Generic, cast import anyio from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -58,7 +60,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ContextMiddleware, ServerRequestContext +from mcp.server.context import HandlerResult, ServerMiddleware, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions @@ -76,6 +78,30 @@ async def main(): LifespanResultT = TypeVar("LifespanResultT", default=Any) +_ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel) + +RequestHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[HandlerResult]] +"""A registered request handler: ``(ctx, params) -> result``.""" + +NotificationHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[None]] +"""A registered notification handler: ``(ctx, params) -> None``.""" + + +@dataclass(frozen=True, slots=True) +class HandlerEntry(Generic[LifespanResultT]): + """A registered handler and the params model to validate incoming params against. + + Stored in `Server._request_handlers` / `_notification_handlers` and consumed + by `ServerRunner` to validate, build `Context`, and invoke. The handler's + second-argument type is erased to ``Any`` in storage (each entry has a + different concrete params type and `Callable` parameters are contravariant); + the precise type is recoverable via `params_type`. The correlation is + enforced at registration time by `Server.add_request_handler`. + """ + + params_type: type[BaseModel] + handler: RequestHandler[LifespanResultT, Any] + class NotificationOptions: def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): @@ -85,7 +111,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. Returns: @@ -109,6 +135,8 @@ def __init__( instructions: str | None = None, website_url: str | None = None, icons: list[types.Icon] | None = None, + notification_options: NotificationOptions | None = None, + experimental_capabilities: dict[str, dict[str, Any]] | None = None, lifespan: Callable[ [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], @@ -193,57 +221,77 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} - self._notification_handlers: dict[ - str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] - ] = {} + self._notification_options = notification_options or NotificationOptions() + self._experimental_capabilities = experimental_capabilities or {} + self._request_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} + self._notification_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None # Context-tier middleware consumed by `ServerRunner`. Additive; the # existing `run()` path ignores it. - self.middleware: list[ContextMiddleware[LifespanResultT]] = [] + self.middleware: list[ServerMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) - # Populate internal handler dicts from on_* kwargs - self._request_handlers.update( - { - method: handler - for method, handler in { - "ping": on_ping, - "prompts/list": on_list_prompts, - "prompts/get": on_get_prompt, - "resources/list": on_list_resources, - "resources/templates/list": on_list_resource_templates, - "resources/read": on_read_resource, - "resources/subscribe": on_subscribe_resource, - "resources/unsubscribe": on_unsubscribe_resource, - "tools/list": on_list_tools, - "tools/call": on_call_tool, - "logging/setLevel": on_set_logging_level, - "completion/complete": on_completion, - }.items() - if handler is not None - } - ) + _spec_requests: list[tuple[str, type[BaseModel], RequestHandler[LifespanResultT, Any] | None]] = [ + ("ping", types.RequestParams, on_ping), + ("prompts/list", types.PaginatedRequestParams, on_list_prompts), + ("prompts/get", types.GetPromptRequestParams, on_get_prompt), + ("resources/list", types.PaginatedRequestParams, on_list_resources), + ("resources/templates/list", types.PaginatedRequestParams, on_list_resource_templates), + ("resources/read", types.ReadResourceRequestParams, on_read_resource), + ("resources/subscribe", types.SubscribeRequestParams, on_subscribe_resource), + ("resources/unsubscribe", types.UnsubscribeRequestParams, on_unsubscribe_resource), + ("tools/list", types.PaginatedRequestParams, on_list_tools), + ("tools/call", types.CallToolRequestParams, on_call_tool), + ("logging/setLevel", types.SetLevelRequestParams, on_set_logging_level), + ("completion/complete", types.CompleteRequestParams, on_completion), + ] + self._request_handlers.update({m: HandlerEntry(pt, h) for m, pt, h in _spec_requests if h is not None}) + _spec_notifications: list[tuple[str, type[BaseModel], NotificationHandler[LifespanResultT, Any] | None]] = [ + ("notifications/roots/list_changed", types.NotificationParams, on_roots_list_changed), + ("notifications/progress", types.ProgressNotificationParams, on_progress), + ] self._notification_handlers.update( - { - method: handler - for method, handler in { - "notifications/roots/list_changed": on_roots_list_changed, - "notifications/progress": on_progress, - }.items() - if handler is not None - } + {m: HandlerEntry(pt, h) for m, pt, h in _spec_notifications if h is not None} ) + def add_request_handler( + self, + method: str, + params_type: type[_ParamsT], + handler: RequestHandler[LifespanResultT, _ParamsT], + ) -> None: + """Register a request handler for ``method``. + + ``params_type`` is the model incoming params are validated against + before the handler is invoked. It should subclass `RequestParams` so + ``_meta`` parses uniformly. Replaces any existing handler for the same + method (no collision guard against spec methods). + """ + self._request_handlers[method] = HandlerEntry(params_type, handler) + + def add_notification_handler( + self, + method: str, + params_type: type[_ParamsT], + handler: NotificationHandler[LifespanResultT, _ParamsT], + ) -> None: + """Register a notification handler for ``method``. + + ``params_type`` should subclass `NotificationParams` so ``_meta`` + parses uniformly. Replaces any existing handler. + """ + self._notification_handlers[method] = HandlerEntry(params_type, handler) + def _add_request_handler( self, method: str, - handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + handler: RequestHandler[LifespanResultT, Any], ) -> None: - """Add a request handler, silently replacing any existing handler for the same method.""" - self._request_handlers[method] = handler + # TODO: remove once experimental tasks plumbing and remaining callers + # migrate to `add_request_handler` with an explicit params_type. + self.add_request_handler(method, types.RequestParams, handler) def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" @@ -251,14 +299,18 @@ def _has_handler(self, method: str) -> bool: # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ - def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: - """Return the handler for a request method, or ``None``.""" + def get_request_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a request method, or ``None``.""" return self._request_handlers.get(method) - def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: - """Return the handler for a notification method, or ``None``.""" + def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a notification method, or ``None``.""" return self._notification_handlers.get(method) + def capabilities(self) -> types.ServerCapabilities: + """Derive `ServerCapabilities` from registered handlers and constructor options.""" + return self.get_capabilities(self._notification_options, self._experimental_capabilities) + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities @@ -474,7 +526,8 @@ async def _handle_request( attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id}, context=parent_context, ) as span: - if handler := self._request_handlers.get(req.method): + if entry := self._request_handlers.get(req.method): + handler = entry.handler logger.debug("Dispatching request of type %s", type(req).__name__) try: @@ -533,7 +586,8 @@ async def _handle_request( span.set_status(StatusCode.ERROR, response.message) try: - await message.respond(response) + # TODO: cast goes away when `_handle_request` is deleted. + await message.respond(cast(types.ServerResult | types.ErrorData, response)) except (anyio.BrokenResourceError, anyio.ClosedResourceError): # Transport closed between handler unblocking and respond. Happens # when _receive_loop's finally wakes a handler blocked on @@ -552,7 +606,8 @@ async def _handle_notification( session: ServerSession, lifespan_context: LifespanResultT, ) -> None: - if handler := self._notification_handlers.get(notify.method): + if entry := self._notification_handlers.get(notify.method): + handler = entry.handler logger.debug("Dispatching notification of type %s", type(notify).__name__) try: diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index bb3af0443..1ba732ec4 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -10,17 +10,16 @@ `Context`, runs the middleware chain, returns the result dict * drives ``dispatcher.run()`` and the per-connection lifespan -`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies -it via additive methods so the existing ``Server.run()`` path is unaffected. +`ServerRunner` holds a `Server` directly — `Server` is the registry. """ from __future__ import annotations import logging -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import Mapping from dataclasses import dataclass, field from functools import partial, reduce -from typing import Any, Generic, Protocol, cast +from typing import Any, Generic, cast import anyio.abc from opentelemetry.trace import SpanKind, StatusCode @@ -28,8 +27,8 @@ from typing_extensions import TypeVar from mcp.server.connection import Connection -from mcp.server.context import CallNext, Context, ContextMiddleware -from mcp.server.lowlevel.server import NotificationOptions +from mcp.server.context import CallNext, Context, ServerMiddleware +from mcp.server.lowlevel.server import Server from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError @@ -38,87 +37,20 @@ INVALID_REQUEST, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, - CallToolRequestParams, - CompleteRequestParams, - GetPromptRequestParams, Implementation, InitializeRequestParams, InitializeResult, - NotificationParams, - PaginatedRequestParams, - ProgressNotificationParams, - ReadResourceRequestParams, - RequestParams, - ServerCapabilities, - SetLevelRequestParams, - SubscribeRequestParams, - UnsubscribeRequestParams, ) -__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"] +__all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"] logger = logging.getLogger(__name__) LifespanT = TypeVar("LifespanT", default=Any) -ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) - -Handler = Callable[..., Awaitable[Any]] -"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely -so the existing `ServerRequestContext`-based handlers and the new -`Context`-based handlers both fit during the transition. -""" _INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) -# TODO: remove this lookup once `Server` stores (params_type, handler) in its -# registry directly. This is scaffolding so ServerRunner can validate params -# without changing the existing `_request_handlers` dict shape. -_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { - "ping": RequestParams, - "tools/list": PaginatedRequestParams, - "tools/call": CallToolRequestParams, - "prompts/list": PaginatedRequestParams, - "prompts/get": GetPromptRequestParams, - "resources/list": PaginatedRequestParams, - "resources/templates/list": PaginatedRequestParams, - "resources/read": ReadResourceRequestParams, - "resources/subscribe": SubscribeRequestParams, - "resources/unsubscribe": UnsubscribeRequestParams, - "logging/setLevel": SetLevelRequestParams, - "completion/complete": CompleteRequestParams, -} -"""Spec method → params model. Scaffolding while the lowlevel `Server`'s -`_request_handlers` stores handler-only; the registry refactor should make this -the registry's responsibility (or store params types alongside handlers).""" - -_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { - "notifications/initialized": NotificationParams, - "notifications/roots/list_changed": NotificationParams, - "notifications/progress": ProgressNotificationParams, -} - - -class ServerRegistry(Protocol): - """The handler registry `ServerRunner` consumes. - - The lowlevel `Server` satisfies this via additive methods. - """ - - @property - def name(self) -> str: ... - @property - def version(self) -> str | None: ... - - @property - def middleware(self) -> Sequence[ContextMiddleware[Any]]: ... - - def get_request_handler(self, method: str) -> Handler | None: ... - def get_notification_handler(self, method: str) -> Handler | None: ... - def get_capabilities( - self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]] - ) -> ServerCapabilities: ... - def otel_middleware(next_on_request: OnRequest) -> OnRequest: """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. @@ -177,13 +109,14 @@ def _dump_result(result: Any) -> dict[str, Any]: @dataclass -class ServerRunner(Generic[LifespanT, ServerTransportT]): +class ServerRunner(Generic[LifespanT]): """Per-connection orchestrator. One instance per client connection.""" - server: ServerRegistry - dispatcher: Dispatcher[ServerTransportT] + server: Server[LifespanT] + dispatcher: Dispatcher[TransportContext] lifespan_state: LifespanT has_standalone_channel: bool + session_id: str | None = None stateless: bool = False dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) @@ -192,7 +125,9 @@ class ServerRunner(Generic[LifespanT, ServerTransportT]): def __post_init__(self) -> None: self._initialized = self.stateless - self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + self.connection = Connection( + self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id + ) async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: """Drive the dispatcher until the underlying channel closes. @@ -200,9 +135,15 @@ async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STAT Composes `dispatch_middleware` over `_on_request` and hands the result to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers can ``await tg.start(runner.run)`` and resume once the dispatcher is - ready to accept requests. + ready to accept requests. Once the dispatcher exits, + `connection.exit_stack` is unwound (shielded) so any per-connection + cleanup registered by handlers or middleware runs to completion. """ - await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + try: + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + finally: + with anyio.CancelScope(shield=True): + await self.connection.exit_stack.aclose() def _compose_on_request(self) -> OnRequest: """Wrap `_on_request` in `dispatch_middleware`, outermost-first. @@ -227,17 +168,15 @@ async def _on_request( code=INVALID_REQUEST, message=f"Received {method!r} before initialization was complete", ) - handler = self.server.get_request_handler(method) - if handler is None: + entry = self.server.get_request_handler(method) + if entry is None: raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") - # TODO: scaffolding — params_type comes from a static lookup until the - # registry stores it alongside the handler. - params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) # ValidationError propagates; the dispatcher's exception boundary maps # it to INVALID_PARAMS. - typed_params = params_type.model_validate(params or {}) + typed_params = entry.params_type.model_validate(params or {}) ctx = self._make_context(dctx, typed_params) - call: CallNext = partial(handler, ctx, typed_params) + # TODO: cast goes away when `ServerRequestContext = Context` lands. + call: CallNext = partial(cast(Any, entry.handler), ctx, typed_params) for mw in reversed(self.server.middleware): call = partial(mw, ctx, method, typed_params, call) return _dump_result(await call()) @@ -255,24 +194,18 @@ async def _on_notify( if not self._initialized: logger.debug("dropped %s: received before initialization", method) return - handler = self.server.get_notification_handler(method) - if handler is None: + entry = self.server.get_notification_handler(method) + if entry is None: logger.debug("no handler for notification %s", method) return - params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) - typed_params = params_type.model_validate(params or {}) + typed_params = entry.params_type.model_validate(params or {}) ctx = self._make_context(dctx, typed_params) - await handler(ctx, typed_params) - - def _make_context( - self, dctx: DispatchContext[TransportContext], typed_params: BaseModel - ) -> Context[LifespanT, ServerTransportT]: - # `OnRequest` delivers `DispatchContext[TransportContext]`; this - # ServerRunner instance was constructed for a specific - # `ServerTransportT`, so the narrow is safe by construction. - narrowed = cast(DispatchContext[ServerTransportT], dctx) + # TODO: cast goes away when `ServerRequestContext = Context` lands. + await cast(Any, entry.handler)(ctx, typed_params) + + def _make_context(self, dctx: DispatchContext[TransportContext], typed_params: BaseModel) -> Context[LifespanT]: meta = getattr(typed_params, "meta", None) - return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + return Context(dctx, lifespan=self.lifespan_state, connection=self.connection, meta=meta) def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: init = InitializeRequestParams.model_validate(params or {}) @@ -289,7 +222,7 @@ def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any] self.connection.initialized.set() result = InitializeResult( protocol_version=self.connection.protocol_version, - capabilities=self.server.get_capabilities(NotificationOptions(), {}), + capabilities=self.server.capabilities(), server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), ) return _dump_result(result) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 27443ec87..1842cf8ab 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -162,18 +162,20 @@ async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) def create_direct_dispatcher_pair( *, can_send_request: bool = True, + headers: Mapping[str, str] | None = None, ) -> tuple[DirectDispatcher, DirectDispatcher]: """Create two `DirectDispatcher` instances wired to each other. Args: can_send_request: Sets `TransportContext.can_send_request` on both sides. Pass ``False`` to simulate a transport with no back-channel. + headers: Sets `TransportContext.headers` on both sides. Returns: A ``(left, right)`` pair. Conventionally ``left`` is the client side and ``right`` is the server side, but the wiring is symmetric. """ - ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request) + ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request, headers=headers) left = DirectDispatcher(ctx) right = DirectDispatcher(ctx) left.connect_to(right) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index f1e7b3675..b450bb66d 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -76,6 +76,21 @@ `TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" +def _coerce_id(request_id: RequestId) -> RequestId: + """Coerce a string request ID to int when it's a valid int literal. + + `_allocate_id` only ever produces ``int`` keys for ``_pending``, but a peer + may echo the ID back as a JSON string. The TypeScript SDK and `BaseSession` + both perform this coercion at lookup time so the response still correlates. + """ + if isinstance(request_id, str): + try: + return int(request_id) + except ValueError: + pass + return request_id + + @dataclass(slots=True) class _Pending: """An outbound request awaiting its response.""" @@ -409,7 +424,7 @@ def _dispatch_notification( if msg.method == "notifications/progress": match msg.params: case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( - pending := self._pending.get(token) + pending := self._pending.get(_coerce_id(token)) ) is not None and pending.on_progress is not None: total = msg.params.get("total") message = msg.params.get("message") @@ -428,7 +443,7 @@ def _dispatch_notification( self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: - pending = self._pending.get(request_id) if request_id is not None else None + pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None if pending is None: logger.debug("dropping response for unknown/late request id %r", request_id) return diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py index 832cead51..934611670 100644 --- a/src/mcp/shared/transport_context.py +++ b/src/mcp/shared/transport_context.py @@ -6,6 +6,7 @@ dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. """ +from collections.abc import Mapping from dataclasses import dataclass __all__ = ["TransportContext"] @@ -28,3 +29,10 @@ class TransportContext: stdio, SSE, and stateful streamable HTTP. When ``False``, `DispatchContext.send_raw_request` raises `NoBackChannelError`. """ + + headers: Mapping[str, str] | None = None + """Request headers carried by this message, when the transport has them. + + Populated by HTTP-based transports; ``None`` on stdio. Handlers should + None-check before use. + """ diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 843b0ae8b..33df234db 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -6,9 +6,9 @@ under test. """ -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from typing import Any +from collections.abc import AsyncIterator, Mapping +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, cast import anyio import anyio.lowlevel @@ -17,20 +17,25 @@ from mcp.server.connection import Connection from mcp.server.context import Context -from mcp.server.lowlevel.server import Server +from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.runner import ServerRunner, otel_middleware from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchMiddleware from mcp.shared.exceptions import MCPError -from mcp.shared.transport_context import TransportContext from mcp.types import ( INTERNAL_ERROR, INVALID_REQUEST, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, + CallToolRequestParams, ClientCapabilities, Implementation, InitializeRequestParams, + ListToolsResult, + NotificationParams, + PaginatedRequestParams, + RequestParams, + SetLevelRequestParams, Tool, ) @@ -46,7 +51,7 @@ def _initialize_params() -> dict[str, Any]: ).model_dump(by_alias=True, exclude_none=True) -_seen_ctx: list[Context[Any, TransportContext]] = [] +_seen_ctx: list[Context[Any]] = [] SrvT = Server[dict[str, Any]] @@ -55,12 +60,11 @@ def server() -> SrvT: """A lowlevel Server with one tools/list handler registered.""" _seen_ctx.clear() - async def list_tools(ctx: Any, params: Any) -> Any: - # ctx is typed `Any` because Server's on_list_tools kwarg expects the - # legacy ServerRequestContext shape; ServerRunner passes the new - # `Context`. The transition is intentional — Handler is loosely typed. + async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + # ctx is `Any` while `on_*` kwargs are typed against `ServerRequestContext` + # but `ServerRunner` passes the new `Context`; tightens once the alias lands. _seen_ctx.append(ctx) - return {"tools": [Tool(name="t", input_schema={"type": "object"}).model_dump(by_alias=True)]} + return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]) return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) @@ -72,8 +76,10 @@ async def connected_runner( initialized: bool = True, stateless: bool = False, has_standalone_channel: bool = True, + session_id: str | None = None, + headers: Mapping[str, str] | None = None, dispatch_middleware: list[DispatchMiddleware] | None = None, -) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[None, TransportContext]]]: +) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[dict[str, Any]]]]: """Yield ``(client, runner)`` running over an in-memory dispatcher pair. Starts the client (echo handlers) and `runner.run()` in a task group, wraps @@ -81,12 +87,13 @@ async def connected_runner( ``initialized`` is true the helper performs the real ``initialize`` request before yielding, so tests start past the init-gate via the public path. """ - client, server_d = create_direct_dispatcher_pair() + client, server_d = create_direct_dispatcher_pair(headers=headers) runner = ServerRunner( server=server, dispatcher=server_d, - lifespan_state=None, + lifespan_state={}, has_standalone_channel=has_standalone_channel, + session_id=session_id, stateless=stateless, dispatch_middleware=dispatch_middleware or [], ) @@ -147,7 +154,7 @@ async def test_runner_routes_to_handler_and_builds_context(server: SrvT): assert result["tools"][0]["name"] == "t" ctx = _seen_ctx[0] assert isinstance(ctx, Context) - assert ctx.lifespan is None + assert ctx.lifespan == {} assert isinstance(ctx.connection, Connection) assert ctx.transport.kind == "direct" @@ -175,7 +182,7 @@ async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): async def on_roots_changed(ctx: Any, params: Any) -> None: seen.append((ctx, params)) - server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots_changed) async with connected_runner(server) as (client, _): await client.notify("notifications/roots/list_changed", None) # DirectDispatcher delivers synchronously; one yield is enough. @@ -249,7 +256,7 @@ async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): async def set_level(ctx: Any, params: Any) -> None: return None - server._request_handlers["logging/setLevel"] = set_level + server.add_request_handler("logging/setLevel", SetLevelRequestParams, set_level) async with connected_runner(server) as (client, _): result = await client.send_raw_request("logging/setLevel", {"level": "info"}) assert result == {} @@ -260,7 +267,9 @@ async def test_runner_handler_returning_unsupported_type_surfaces_as_internal_er async def bad_return(ctx: Any, params: Any) -> int: return 42 - server._request_handlers["tools/list"] = bad_return + # cast: deliberately registering a handler with a bad return type to + # exercise the runtime check; pyright would (correctly) reject it otherwise. + server.add_request_handler("tools/list", PaginatedRequestParams, cast(Any, bad_return)) async with connected_runner(server) as (client, _): with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) @@ -275,12 +284,48 @@ async def test_runner_stateless_skips_init_gate(server: SrvT): assert result["tools"][0]["name"] == "t" +@pytest.mark.anyio +async def test_server_add_request_handler_routes_custom_method_with_validated_params(server: SrvT): + class GreetParams(RequestParams): + name: str + + received: list[GreetParams] = [] + + async def greet(ctx: Any, params: GreetParams) -> dict[str, Any]: + received.append(params) + return {"greeting": f"hello {params.name}"} + + server.add_request_handler("custom/greet", GreetParams, greet) + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("custom/greet", {"name": "world"}) + assert result == {"greeting": "hello world"} + assert isinstance(received[0], GreetParams) + assert received[0].name == "world" + + +@pytest.mark.anyio +async def test_server_capabilities_reflects_ctor_options_in_initialize_result(): + async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + server: SrvT = Server( + name="caps-test", + on_list_tools=list_tools, + notification_options=NotificationOptions(tools_changed=True), + experimental_capabilities={"ext": {"k": "v"}}, + ) + async with connected_runner(server, initialized=False) as (client, _): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["capabilities"]["tools"]["listChanged"] is True + assert result["capabilities"]["experimental"] == {"ext": {"k": "v"}} + + @pytest.mark.anyio async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: return {"content": [], "isError": False} - server._request_handlers["tools/call"] = call_tool + server.add_request_handler("tools/call", CallToolRequestParams, call_tool) async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) @@ -326,7 +371,7 @@ async def test_otel_middleware_records_error_status_on_handler_exception(server: async def failing(ctx: Any, params: Any) -> Any: raise ValueError("handler blew up") - server._request_handlers["tools/list"] = failing + server.add_request_handler("tools/list", PaginatedRequestParams, failing) async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() with pytest.raises(MCPError) as exc: @@ -338,3 +383,100 @@ async def failing(ctx: Any, params: Any) -> Any: [event] = [e for e in span.events if e.name == "exception"] assert event.attributes is not None assert event.attributes["exception.type"] == "ValueError" + + +@pytest.mark.anyio +async def test_connection_state_persists_across_requests_on_same_connection(server: SrvT) -> None: + async def count(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + ctx.connection.state["n"] = ctx.connection.state.get("n", 0) + 1 + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, count) + async with connected_runner(server) as (client, runner): + await client.send_raw_request("tools/list", None) + await client.send_raw_request("tools/list", None) + assert runner.connection.state == {"n": 2} + + +@pytest.mark.anyio +async def test_connection_exit_stack_runs_pushed_callback_after_close(server: SrvT) -> None: + cleaned: list[str] = [] + + async def push(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + async def _cleanup() -> None: + cleaned.append("done") + + ctx.connection.exit_stack.push_async_callback(_cleanup) + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, push) + async with connected_runner(server) as (client, _runner): + await client.send_raw_request("tools/list", None) + assert cleaned == [] + assert cleaned == ["done"] + + +@pytest.mark.anyio +async def test_connection_exit_stack_unwinds_entered_context_manager_after_close(server: SrvT) -> None: + events: list[str] = [] + + class _Tracker(AbstractAsyncContextManager[str]): + async def __aenter__(self) -> str: + events.append("enter") + return "resource" + + async def __aexit__(self, *exc: object) -> None: + events.append("exit") + + async def acquire(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + res = await ctx.connection.exit_stack.enter_async_context(_Tracker()) + ctx.connection.state["res"] = res + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, acquire) + async with connected_runner(server) as (client, runner): + await client.send_raw_request("tools/list", None) + assert events == ["enter"] + assert runner.connection.state["res"] == "resource" + assert events == ["enter", "exit"] + + +@pytest.mark.anyio +async def test_connection_exit_stack_runs_callbacks_lifo_after_handler_error(server: SrvT) -> None: + cleaned: list[int] = [] + + async def push_then_fail(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + for i in (1, 2, 3): + ctx.connection.exit_stack.push_async_callback(_append, i) + raise RuntimeError("boom") + + async def _append(i: int) -> None: + cleaned.append(i) + + server.add_request_handler("tools/list", PaginatedRequestParams, push_then_fail) + async with connected_runner(server) as (client, _runner): + with pytest.raises(MCPError) as ei: + await client.send_raw_request("tools/list", None) + assert ei.value.error.code == INTERNAL_ERROR + assert cleaned == [] + assert cleaned == [3, 2, 1] + + +@pytest.mark.anyio +async def test_context_session_id_and_headers_expose_connection_and_transport(server: SrvT) -> None: + async with connected_runner(server, session_id="sess-abc", headers={"authorization": "Bearer t"}) as (client, _r): + await client.send_raw_request("tools/list", None) + [ctx] = _seen_ctx + assert ctx.session_id == "sess-abc" + assert ctx.session_id == ctx.connection.session_id + assert ctx.headers == {"authorization": "Bearer t"} + assert ctx.headers is ctx.transport.headers + + +@pytest.mark.anyio +async def test_context_session_id_and_headers_default_none(server: SrvT) -> None: + async with connected_runner(server) as (client, _r): + await client.send_raw_request("tools/list", None) + [ctx] = _seen_ctx + assert ctx.session_id is None + assert ctx.headers is None diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index e01de34d3..43c2069a8 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -31,11 +31,11 @@ class _Lifespan: @pytest.mark.anyio async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): - captured: list[Context[_Lifespan, TransportContext]] = [] + captured: list[Context[_Lifespan]] = [] conn = Connection.__new__(Connection) # placeholder until running_pair gives us the dispatcher async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) + ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) captured.append(ctx) return {} @@ -62,7 +62,7 @@ async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | results: list[CreateMessageResult] = [] async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) results.append( @@ -92,7 +92,7 @@ async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | results: list[ListRootsResult] = [] async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) results.append(await ctx.send_request(ListRootsRequest())) @@ -113,7 +113,7 @@ async def test_context_log_sends_request_scoped_message_notification(): _, c_notify = echo_handlers(crec) async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) await ctx.log("debug", "hello") @@ -137,7 +137,7 @@ async def test_context_log_includes_logger_and_meta_when_supplied(): _, c_notify = echo_handlers(crec) async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) await ctx.log("info", "x", logger="my.log", meta={"traceId": "t"}) diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 7f9f11718..5755b55d1 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -18,6 +18,7 @@ from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] JSONRPCDispatcher, + _coerce_id, _outbound_metadata, _Pending, ) @@ -29,6 +30,7 @@ INVALID_PARAMS, ErrorData, JSONRPCError, + JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, Tool, @@ -511,6 +513,87 @@ def test_outbound_metadata_with_resumption_token_returns_client_metadata(): assert _outbound_metadata(None, {}) is None +@pytest.mark.anyio +async def test_response_with_string_id_correlates_to_int_keyed_pending_request(): + """A peer that echoes the request ID as a JSON string still resolves the waiter.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_stringly() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=str(rid), result={"ok": True})) + ) + + tg.start_soon(respond_stringly) + result = await client.send_raw_request("ping", None) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_with_string_token_reaches_callback_for_int_keyed_request(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + seen: list[float] = [] + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_with_string_token_progress() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params={"progressToken": str(rid), "progress": 0.5}, + ) + ) + ) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=rid, result={"ok": True})) + ) + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + seen.append(progress) + + tg.start_soon(respond_with_string_token_progress) + result = await client.send_raw_request("ping", None, {"on_progress": on_progress}) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert seen == [0.5] + + +def test_coerce_id_passes_through_non_numeric_string_and_int(): + assert _coerce_id("7") == 7 + assert _coerce_id("not-an-int") == "not-an-int" + assert _coerce_id(42) == 42 + + @pytest.mark.anyio async def test_jsonrpc_error_response_with_null_id_is_dropped(): """Parse-error responses (id=null) have no waiter; they're logged and dropped."""