Source code for litestar_mcp.routes

# ruff: noqa: PLR0915, PLR0911, C901
"""MCP JSON-RPC 2.0 Streamable HTTP transport for Litestar applications."""

import asyncio
import contextlib
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any

from litestar import Controller, MediaType, Request, Response, delete, get, post
from litestar.exceptions import SerializationException
from litestar.handlers import BaseRouteHandler
from litestar.response import ServerSentEvent, ServerSentEventMessage
from litestar.serialization import decode_json, encode_json
from litestar.status_codes import (
    HTTP_200_OK,
    HTTP_204_NO_CONTENT,
    HTTP_400_BAD_REQUEST,
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
    HTTP_405_METHOD_NOT_ALLOWED,
    HTTP_503_SERVICE_UNAVAILABLE,
)

from litestar_mcp.config import MCPConfig
from litestar_mcp.executor import MCPToolErrorResult, execute_tool
from litestar_mcp.jsonrpc import (
    INTERNAL_ERROR,
    INVALID_PARAMS,
    INVALID_REQUEST,
    METHOD_NOT_FOUND,
    PARSE_ERROR,
    JSONRPCError,
    JSONRPCErrorException,
    JSONRPCRouter,
    error_response,
    parse_request,
)
from litestar_mcp.registry import Registry
from litestar_mcp.schema_builder import generate_schema_for_handler
from litestar_mcp.sessions import MCPSessionManager, SessionTerminated
from litestar_mcp.sse import StreamLimitExceeded
from litestar_mcp.tasks import InMemoryTaskStore, TaskLookupError, TaskRecord, TaskStateError
from litestar_mcp.utils import (
    get_handler_function,
    get_mcp_metadata,
    match_uri,
    render_description,
    should_include_handler,
)

MCP_PROTOCOL_VERSION = "2025-11-25"
MCP_SESSION_HEADER = "Mcp-Session-Id"

SESSION_ERROR = -32000
SESSION_NOT_INITIALIZED = -32002

_SESSION_EXEMPT_METHODS = frozenset({"initialize", "ping"})
_PRE_INIT_ALLOWED_METHODS = frozenset({"initialize", "ping", "notifications/initialized"})


@dataclass
class RequestContext:
    """Request context threaded through tool and task execution.

    Authentication lives in Litestar middleware; ``request.user`` and
    ``request.auth`` are the per-request source of truth for tool handlers.
    This struct only carries the scope identifiers used by MCP itself
    (client id, task-owner id, and the live request handle).
    """

    client_id: str
    owner_id: str
    request: "Request[Any, Any, Any] | None" = None


def _validate_origin(request: Request[Any, Any, Any], config: MCPConfig) -> Response[Any] | None:
    """Validate the Origin header if allowed_origins is configured."""
    if not config.allowed_origins:
        return None

    origin = request.headers.get("origin")
    if origin and origin not in config.allowed_origins:
        return Response(
            content={"error": "Origin not allowed"},
            status_code=HTTP_403_FORBIDDEN,
            media_type=MediaType.JSON,
        )
    return None


def _add_protocol_headers(response: Response[Any]) -> Response[Any]:
    """Add standard MCP protocol headers to a response."""
    response.headers["mcp-protocol-version"] = MCP_PROTOCOL_VERSION
    return response


def _request_subject(request: Request[Any, Any, Any]) -> str | None:
    """Best-effort ``sub``-like identifier from ``request.auth`` claims dict.

    Middleware populates ``scope["auth"]`` with whatever shape it sets — this
    helper reads the raw scope value (avoiding the ``request.auth`` property
    which raises when no auth middleware is installed) and treats it as a
    mapping, pulling ``"sub"`` if present. Non-mapping values are ignored.
    """
    auth = request.scope.get("auth")
    if isinstance(auth, dict):
        sub = auth.get("sub")
        if isinstance(sub, str) and sub:
            return sub
    return None


def _resolve_client_id(request: Request[Any, Any, Any]) -> str:
    explicit_client_id = (
        request.headers.get("x-mcp-client-id")
        or request.headers.get("mcp-client-id")
        or request.query_params.get("clientId")
        or request.query_params.get("client_id")
    )
    if explicit_client_id:
        return explicit_client_id
    sub = _request_subject(request)
    if sub is not None:
        return f"user:{sub}"
    if request.client and request.client.host:
        return f"remote:{request.client.host}"
    return "anonymous"


def _build_request_context(request: Request[Any, Any, Any]) -> RequestContext:
    client_id = _resolve_client_id(request)
    sub = _request_subject(request)
    owner_id = f"user:{sub}" if sub is not None else f"client:{client_id}"
    return RequestContext(client_id=client_id, owner_id=owner_id, request=request)


def _serialize_tool_content(value: Any) -> str:
    if isinstance(value, str):
        return value
    return encode_json(value).decode("utf-8")


def _build_tool_result(value: Any, *, is_error: bool, task_id: str | None = None) -> dict[str, Any]:
    result: dict[str, Any] = {
        "content": [{"type": "text", "text": _serialize_tool_content(value)}],
        "isError": is_error,
    }
    if task_id is not None:
        result["_meta"] = {"io.modelcontextprotocol/related-task": {"taskId": task_id}}
    return result


_VALIDATION_CONTEXT_PARAMS = {
    "request",
    "socket",
    "state",
    "scope",
    "headers",
    "cookies",
    "query",
    "body",
    "data",
}


def _to_pointer(name: str, msgspec_path: str) -> str:
    """Turn ``name`` + ``$.age.limit`` into ``/arguments/age/limit`` JSON Pointer.

    ``msgspec.ValidationError`` messages include a trailing ``$.<path>`` marker
    indicating which nested field failed validation. We translate that into a
    JSON Pointer rooted at ``/arguments/<name>`` so downstream UIs can render
    field-level errors.
    """
    suffix = msgspec_path.removeprefix("$").lstrip(".")
    parts = ["arguments", name]
    if suffix:
        parts.extend(p for p in suffix.split(".") if p and p != name)
    return "/" + "/".join(parts)


def _split_msgspec_error(exc: "Exception") -> tuple[str, str]:
    """Split a ``msgspec.ValidationError`` string into (reason, path).

    msgspec formats messages as ``"<reason> - at `$.path`"``. When no path is
    present we return an empty path.
    """
    text = str(exc)
    marker = " - at `"
    if marker in text and text.endswith("`"):
        reason, _, tail = text.rpartition(marker)
        path = tail[:-1]
        return reason, path
    return text, ""


def _resolve_annotated_types(handler: "BaseRouteHandler") -> dict[str, Any]:
    """Return ``{param_name: annotated_type}`` from the original handler function.

    Litestar's ``signature_model`` strips user-supplied ``msgspec.Meta``
    constraints (``ge``/``le``/``pattern``/``min_length`` …) and replaces them
    with its own ``KwargDefinition`` metadata. To enforce those constraints via
    ``msgspec.convert`` we resolve type hints directly off the original
    function, preserving the full ``Annotated[...]`` chain.
    """
    import typing as _typing

    fn = get_handler_function(handler)
    try:
        return _typing.get_type_hints(fn, include_extras=True)
    except Exception:  # noqa: BLE001
        return {}


def _validate_tool_arguments(handler: "BaseRouteHandler", tool_args: dict[str, Any]) -> list[dict[str, str]]:
    """Validate ``tool_args`` against the handler's Litestar signature.

    Matches the executor's partitioning (Ch2): if the handler declares a
    ``data`` parameter, unrecognized tool_args are validated as fields of
    that struct type; path params are matched against the route's declared
    path variables; remaining scalars are matched against the handler's
    non-DI signature fields.

    Returns a list of ``{"path": <json-pointer>, "message": <reason>}`` dicts,
    sorted by path for deterministic output.
    """
    import msgspec

    signature_model = getattr(handler, "signature_model", None)
    if signature_model is None:
        return []

    try:
        fields = msgspec.structs.fields(signature_model)
    except TypeError:
        return []

    di_params: set[str] = set()
    with contextlib.suppress(AttributeError, TypeError):
        di_params = set(handler.resolve_dependencies().keys())

    declared_by_name = {field.name: field for field in fields}
    annotated_types = _resolve_annotated_types(handler)
    errors: list[dict[str, str]] = []

    data_field = declared_by_name.get("data")
    data_type = annotated_types.get("data") if data_field is not None else None
    recognized_scalar_names = {
        name for name in declared_by_name if name not in di_params and name not in _VALIDATION_CONTEXT_PARAMS
    }

    # When the handler has a ``data`` param, tool_args keys that aren't
    # recognized scalar fields are treated as members of the data struct.
    # Validate them by building a mapping and converting it to the struct.
    if data_type is not None:
        data_payload = {k: v for k, v in tool_args.items() if k not in recognized_scalar_names}
        if data_payload:
            try:
                msgspec.convert(data_payload, data_type, strict=False)
            except msgspec.ValidationError as exc:
                reason, path = _split_msgspec_error(exc)
                errors.append({"path": _to_pointer("data", path), "message": reason})
            except TypeError:
                pass

    for field in fields:
        if field.name in di_params or field.name in _VALIDATION_CONTEXT_PARAMS:
            continue
        if field.name == "data":
            # Presence of ``data`` is implied by any matching struct field
            # in tool_args; we don't require callers to pass ``data`` as a
            # literal key.
            continue
        if field.name in tool_args:
            continue
        if field.default is msgspec.NODEFAULT and field.default_factory is msgspec.NODEFAULT:
            errors.append({"path": _to_pointer(field.name, ""), "message": "Missing required argument"})

    for name, value in tool_args.items():
        if name not in recognized_scalar_names:
            if data_type is not None:
                # Unknown-to-scalars: assumed to belong to the ``data`` struct
                # and already validated above.
                continue
            # No ``data`` parameter → unknown keys are genuinely unexpected.
            errors.append({"path": "/arguments", "message": f"Unexpected argument: {name}"})
            continue
        declared = declared_by_name[name]
        convert_type = annotated_types.get(name, declared.type)
        try:
            msgspec.convert(value, convert_type, strict=False)
        except msgspec.ValidationError as exc:
            reason, path = _split_msgspec_error(exc)
            errors.append({"path": _to_pointer(name, path), "message": reason})
        except TypeError:
            continue

    return sorted(errors, key=lambda entry: (entry["path"], entry["message"]))


def build_jsonrpc_router(
    config: MCPConfig,
    discovered_tools: dict[str, BaseRouteHandler],
    discovered_resources: dict[str, BaseRouteHandler],
    *,
    app_ref: Any,
    request_context: RequestContext,
    task_store: InMemoryTaskStore | None = None,
    registry: Registry | None = None,
) -> JSONRPCRouter:
    """Build and return a JSONRPCRouter wired to MCP method handlers."""
    router = JSONRPCRouter()
    task_config = config.task_config

    async def execute_tool_call(
        handler: BaseRouteHandler,
        tool_args: dict[str, Any],
        *,
        task_id: str | None = None,
    ) -> dict[str, Any]:
        validation_errors = _validate_tool_arguments(handler, tool_args)
        if validation_errors:
            return _build_tool_result(
                {"error": "Invalid tool arguments", "errors": validation_errors},
                is_error=True,
                task_id=task_id,
            )

        try:
            result = await execute_tool(handler, app_ref, tool_args, request=request_context.request)
        except MCPToolErrorResult as err:
            return _build_tool_result(err.content, is_error=True, task_id=task_id)
        except Exception as exc:  # noqa: BLE001
            return _build_tool_result({"error": str(exc)}, is_error=True, task_id=task_id)

        return _build_tool_result(result, is_error=False, task_id=task_id)

    async def run_task(
        record: TaskRecord,
        handler: "BaseRouteHandler",
        tool_args: dict[str, Any],
    ) -> None:
        try:
            result = await execute_tool_call(handler, tool_args, task_id=record.task_id)
            await task_store.complete(record.task_id, result)  # type: ignore[union-attr]
        except JSONRPCErrorException as exc:
            await task_store.fail(record.task_id, exc.error)  # type: ignore[union-attr]
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001
            await task_store.fail(  # type: ignore[union-attr]
                record.task_id,
                JSONRPCError(code=INTERNAL_ERROR, message=str(exc)),
                status_message=str(exc),
            )

    async def handle_initialize(params: dict[str, Any]) -> dict[str, Any]:  # noqa: ARG001
        server_name = config.name or "Litestar MCP Server"
        server_version = "1.0.0"

        if app_ref is not None:
            openapi_config = app_ref.openapi_config
            if openapi_config:
                server_name = config.name or openapi_config.title
                server_version = openapi_config.version

        capabilities: dict[str, Any] = {
            "tools": {"listChanged": True},
            "resources": {"subscribe": True, "listChanged": True},
        }
        if task_config is not None:
            task_capabilities: dict[str, Any] = {"requests": {"tools": {"call": {}}}}
            if task_config.list_enabled:
                task_capabilities["list"] = {}
            if task_config.cancel_enabled:
                task_capabilities["cancel"] = {}
            capabilities["tasks"] = task_capabilities

        return {
            "protocolVersion": MCP_PROTOCOL_VERSION,
            "capabilities": capabilities,
            "serverInfo": {"name": server_name, "version": server_version},
        }

    router.register("initialize", handle_initialize)

    async def handle_initialized(params: dict[str, Any]) -> dict[str, Any]:  # noqa: ARG001
        return {}

    router.register("notifications/initialized", handle_initialized)

    async def handle_ping(params: dict[str, Any]) -> dict[str, Any]:  # noqa: ARG001
        return {}

    router.register("ping", handle_ping)

    async def handle_tools_list(params: dict[str, Any]) -> dict[str, Any]:  # noqa: ARG001
        tools = []
        for name, handler in discovered_tools.items():
            handler_tags = set(getattr(handler, "tags", None) or [])
            if not should_include_handler(name, handler_tags, config):
                continue

            fn = get_handler_function(handler)
            metadata = get_mcp_metadata(handler) or get_mcp_metadata(fn) or {}
            tool_entry: dict[str, Any] = {
                "name": name,
                "description": render_description(
                    handler, fn, kind="tool", fallback_name=name, opt_keys=config.opt_keys
                ),
                "inputSchema": generate_schema_for_handler(handler),
            }
            if "output_schema" in metadata:
                tool_entry["outputSchema"] = metadata["output_schema"]
            if "annotations" in metadata:
                tool_entry["annotations"] = metadata["annotations"]
            if "scopes" in metadata:
                annotations = tool_entry.get("annotations") or {}
                # Explicit annotations.scopes wins when both are supplied.
                annotations.setdefault("scopes", list(metadata["scopes"]))
                tool_entry["annotations"] = annotations
            if task_config is not None and metadata.get("task_support") is not None:
                tool_entry["execution"] = {"taskSupport": metadata["task_support"]}
            tools.append(tool_entry)
        return {"tools": tools}

    router.register("tools/list", handle_tools_list)

    async def handle_tools_call(params: dict[str, Any]) -> dict[str, Any]:
        tool_name = params.get("name")
        if not tool_name:
            raise JSONRPCErrorException(JSONRPCError(code=INVALID_PARAMS, message="Missing required param: 'name'"))

        handler = discovered_tools.get(tool_name)
        if handler is None:
            raise JSONRPCErrorException(JSONRPCError(code=INVALID_PARAMS, message=f"Tool not found: {tool_name}"))

        fn = get_handler_function(handler)
        metadata = get_mcp_metadata(handler) or get_mcp_metadata(fn) or {}
        tool_args = params.get("arguments", {})
        if not isinstance(tool_args, dict):
            return _build_tool_result({"error": "Tool arguments must be an object"}, is_error=True)

        task_request = params.get("task")
        task_support = metadata.get("task_support")

        if task_request is None:
            if task_support == "required" and task_config is not None:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INVALID_REQUEST, message="Task augmentation required for tools/call requests")
                )
            return await execute_tool_call(handler, tool_args)

        if task_config is None:
            raise JSONRPCErrorException(
                JSONRPCError(code=METHOD_NOT_FOUND, message=f"Task augmentation is not supported for tool: {tool_name}")
            )
        if task_support not in {"optional", "required"}:
            raise JSONRPCErrorException(
                JSONRPCError(code=METHOD_NOT_FOUND, message=f"Task augmentation is not supported for tool: {tool_name}")
            )
        if not isinstance(task_request, dict):
            raise JSONRPCErrorException(
                JSONRPCError(code=INVALID_PARAMS, message="The 'task' parameter must be an object")
            )

        record = await task_store.create(request_context.owner_id, task_request.get("ttl"))  # type: ignore[union-attr]
        background_task = asyncio.create_task(run_task(record, handler, tool_args))
        await task_store.attach_background_task(record.task_id, background_task)  # type: ignore[union-attr]
        return {"task": record.to_dict()}

    router.register("tools/call", handle_tools_call)

    async def handle_resources_list(params: dict[str, Any]) -> dict[str, Any]:  # noqa: ARG001
        resources = [
            {
                "uri": "litestar://openapi",
                "name": "openapi",
                "description": "OpenAPI schema for this Litestar application",
                "mimeType": "application/json",
            }
        ]
        for name, handler in discovered_resources.items():
            handler_tags = set(getattr(handler, "tags", None) or [])
            if not should_include_handler(name, handler_tags, config):
                continue

            fn = get_handler_function(handler)
            resources.append(
                {
                    "uri": f"litestar://{name}",
                    "name": name,
                    "description": render_description(
                        handler, fn, kind="resource", fallback_name=name, opt_keys=config.opt_keys
                    ),
                    "mimeType": "application/json",
                }
            )
        return {"resources": resources}

    router.register("resources/list", handle_resources_list)

    async def handle_resources_templates_list(params: dict[str, Any]) -> dict[str, Any]:  # noqa: ARG001
        if registry is None:
            return {"resourceTemplates": []}
        templates = []
        for entry in registry.templates.values():
            handler_tags = set(getattr(entry.handler, "tags", None) or [])
            if not should_include_handler(entry.name, handler_tags, config):
                continue
            fn = get_handler_function(entry.handler)
            templates.append(
                {
                    "uriTemplate": entry.template,
                    "name": entry.name,
                    "description": render_description(
                        entry.handler, fn, kind="resource", fallback_name=entry.name, opt_keys=config.opt_keys
                    ),
                    "mimeType": "application/json",
                }
            )
        return {"resourceTemplates": templates}

    router.register("resources/templates/list", handle_resources_templates_list)

    async def handle_resources_read(params: dict[str, Any]) -> dict[str, Any]:
        uri = params.get("uri", "")
        if not isinstance(uri, str) or not uri:
            raise JSONRPCErrorException(JSONRPCError(code=INVALID_PARAMS, message=f"Invalid resource URI: {uri}"))

        if uri.startswith("litestar://"):
            resource_name = uri[len("litestar://") :]
            if resource_name == "openapi" and app_ref is not None:
                return {
                    "contents": [
                        {
                            "uri": "litestar://openapi",
                            "mimeType": "application/json",
                            "text": encode_json(app_ref.openapi_schema).decode("utf-8"),
                        }
                    ]
                }

            handler = discovered_resources.get(resource_name)
            if handler is None:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INVALID_PARAMS, message=f"Resource not found: {resource_name}")
                )

            try:
                result = await execute_tool(
                    handler,
                    app_ref,
                    {},
                    request=request_context.request,
                )
            except MCPToolErrorResult as err:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INTERNAL_ERROR, message=f"Resource read failed: {err.content!s}")
                ) from err
            except Exception as exc:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INTERNAL_ERROR, message=f"Resource read failed: {exc!s}")
                ) from exc

            return {
                "contents": [
                    {
                        "uri": uri,
                        "mimeType": "application/json",
                        "text": encode_json(result).decode("utf-8"),
                    }
                ]
            }

        # Non-``litestar://`` URIs: match against registered templates.
        # First template that matches wins (documented: registration order).
        template_entries = registry.templates.values() if registry is not None else ()
        for entry in template_entries:
            extracted = match_uri(entry.template, uri)
            if extracted is None:
                continue
            try:
                result = await execute_tool(
                    entry.handler,
                    app_ref,
                    dict(extracted),
                    request=request_context.request,
                )
            except MCPToolErrorResult as err:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INTERNAL_ERROR, message=f"Resource read failed: {err.content!s}")
                ) from err
            except Exception as exc:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INTERNAL_ERROR, message=f"Resource read failed: {exc!s}")
                ) from exc

            return {
                "contents": [
                    {
                        "uri": uri,
                        "mimeType": "application/json",
                        "text": encode_json(result).decode("utf-8"),
                    }
                ]
            }

        raise JSONRPCErrorException(JSONRPCError(code=INVALID_PARAMS, message=f"Resource not found: {uri}"))

    router.register("resources/read", handle_resources_read)

    async def handle_completion_complete(params: dict[str, Any]) -> dict[str, Any]:  # noqa: ARG001
        # v0.5.0 default: every ref returns an empty completion. A future
        # ``@mcp_resource_completion`` decorator will dispatch through this
        # method; for now, unknown refs must not error per MCP spec.
        return {"completion": {"values": [], "total": 0, "hasMore": False}}

    router.register("completion/complete", handle_completion_complete)

    if task_store is not None:

        async def handle_tasks_get(params: dict[str, Any]) -> dict[str, Any]:
            task_id = params.get("taskId")
            if not task_id:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INVALID_PARAMS, message="Missing required param: 'taskId'")
                )
            try:
                record = await task_store.get(task_id, request_context.owner_id)
            except TaskLookupError as exc:
                raise JSONRPCErrorException(JSONRPCError(code=INVALID_PARAMS, message=str(exc))) from exc
            return record.to_dict()

        async def handle_tasks_result(params: dict[str, Any]) -> dict[str, Any]:
            task_id = params.get("taskId")
            if not task_id:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INVALID_PARAMS, message="Missing required param: 'taskId'")
                )
            try:
                record = await task_store.wait_for_terminal(task_id, request_context.owner_id)
            except TaskLookupError as exc:
                raise JSONRPCErrorException(JSONRPCError(code=INVALID_PARAMS, message=str(exc))) from exc

            if record.result is not None:
                meta = record.result.setdefault("_meta", {})
                meta["io.modelcontextprotocol/related-task"] = {"taskId": task_id}
                return record.result
            if record.error is not None:
                raise JSONRPCErrorException(record.error)
            raise JSONRPCErrorException(
                JSONRPCError(code=INTERNAL_ERROR, message="Task did not produce a final result")
            )

        async def handle_tasks_list(params: dict[str, Any]) -> dict[str, Any]:
            limit = params.get("limit", 50)
            if not isinstance(limit, int) or limit <= 0:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INVALID_PARAMS, message="The 'limit' parameter must be a positive integer")
                )
            try:
                tasks, next_cursor = await task_store.list(
                    request_context.owner_id,
                    cursor=params.get("cursor"),
                    limit=limit,
                )
            except ValueError as exc:
                raise JSONRPCErrorException(JSONRPCError(code=INVALID_PARAMS, message=str(exc))) from exc
            result: dict[str, Any] = {"tasks": [task.to_dict() for task in tasks]}
            if next_cursor is not None:
                result["nextCursor"] = next_cursor
            return result

        async def handle_tasks_cancel(params: dict[str, Any]) -> dict[str, Any]:
            task_id = params.get("taskId")
            if not task_id:
                raise JSONRPCErrorException(
                    JSONRPCError(code=INVALID_PARAMS, message="Missing required param: 'taskId'")
                )
            try:
                record = await task_store.cancel(task_id, request_context.owner_id)
            except TaskLookupError as exc:
                raise JSONRPCErrorException(JSONRPCError(code=INVALID_PARAMS, message=str(exc))) from exc
            except TaskStateError as exc:
                raise JSONRPCErrorException(JSONRPCError(code=INVALID_PARAMS, message=str(exc))) from exc
            return record.to_dict()

        router.register("tasks/get", handle_tasks_get)
        router.register("tasks/result", handle_tasks_result)
        router.register("tasks/list", handle_tasks_list)
        router.register("tasks/cancel", handle_tasks_cancel)

    return router


[docs] class MCPController(Controller): """MCP JSON-RPC 2.0 Streamable HTTP controller.""" @get("/", name="mcp_sse", media_type=MediaType.TEXT) async def handle_sse( self, request: Request[Any, Any, Any], config: MCPConfig, registry: Registry, session_manager: MCPSessionManager, ) -> Response[Any]: """Handle GET-based Streamable HTTP SSE streams on the MCP endpoint.""" origin_err = _validate_origin(request, config) if origin_err is not None: return origin_err accept_header = request.headers.get("accept", "") if "text/event-stream" not in accept_header: return _add_protocol_headers( Response( content={"error": "GET /mcp requires Accept: text/event-stream"}, status_code=HTTP_405_METHOD_NOT_ALLOWED, media_type=MediaType.JSON, ) ) _build_request_context(request) session_id = request.headers.get(MCP_SESSION_HEADER) or request.headers.get(MCP_SESSION_HEADER.lower()) if not session_id: return _add_protocol_headers( Response( content={"error": f"Missing required header: {MCP_SESSION_HEADER}"}, status_code=HTTP_400_BAD_REQUEST, media_type=MediaType.JSON, ) ) try: await session_manager.get(session_id) except SessionTerminated: return _add_protocol_headers( Response( content=error_response( None, JSONRPCError(code=SESSION_ERROR, message="Session terminated or unknown") ), status_code=HTTP_404_NOT_FOUND, media_type=MediaType.JSON, ) ) try: stream_id, stream = await registry.sse_manager.open_stream( session_id=session_id, last_event_id=request.headers.get("last-event-id"), ) except StreamLimitExceeded: return _add_protocol_headers( Response( content=error_response(None, JSONRPCError(code=SESSION_ERROR, message="SSE stream limit exceeded")), status_code=HTTP_503_SERVICE_UNAVAILABLE, media_type=MediaType.JSON, ) ) async def event_stream() -> AsyncGenerator[ServerSentEventMessage, None]: try: async for message in stream: yield ServerSentEventMessage(data=message.data, event=message.event, id=message.id) finally: registry.sse_manager.disconnect(stream_id) response = ServerSentEvent(event_stream()) response.headers[MCP_SESSION_HEADER] = session_id return _add_protocol_headers(response) @delete("/", name="mcp_session_delete", status_code=HTTP_200_OK) async def handle_delete( self, request: Request[Any, Any, Any], config: MCPConfig, registry: Registry, session_manager: MCPSessionManager, ) -> Response[Any]: """Terminate an MCP session and close its attached SSE streams.""" origin_err = _validate_origin(request, config) if origin_err is not None: return origin_err session_id = request.headers.get(MCP_SESSION_HEADER) or request.headers.get(MCP_SESSION_HEADER.lower()) if not session_id: return _add_protocol_headers( Response( content={"error": f"Missing required header: {MCP_SESSION_HEADER}"}, status_code=HTTP_400_BAD_REQUEST, media_type=MediaType.JSON, ) ) registry.sse_manager.close_session_streams(session_id) await session_manager.delete(session_id) return _add_protocol_headers(Response(content=None, status_code=HTTP_204_NO_CONTENT)) @post("/", name="mcp_jsonrpc", media_type=MediaType.JSON, status_code=HTTP_200_OK) async def handle_jsonrpc( self, request: Request[Any, Any, Any], config: MCPConfig, discovered_tools: dict[str, Any], discovered_resources: dict[str, Any], registry: Registry, session_manager: MCPSessionManager, task_store: InMemoryTaskStore | None = None, ) -> Response[Any]: """Handle a JSON-RPC 2.0 request over Streamable HTTP.""" origin_err = _validate_origin(request, config) if origin_err is not None: return origin_err try: raw = decode_json(await request.body()) except (SerializationException, ValueError): return _add_protocol_headers( Response( content=error_response(None, JSONRPCError(code=PARSE_ERROR, message="Parse error")), status_code=HTTP_200_OK, media_type=MediaType.JSON, ) ) try: rpc_request = parse_request(raw) except JSONRPCErrorException as exc: return _add_protocol_headers( Response( content=error_response(raw.get("id") if isinstance(raw, dict) else None, exc.error), status_code=HTTP_200_OK, media_type=MediaType.JSON, ) ) incoming_session_id = request.headers.get(MCP_SESSION_HEADER) or request.headers.get(MCP_SESSION_HEADER.lower()) session = None response_session_id: str | None = None if rpc_request.method == "initialize": params = rpc_request.params if isinstance(rpc_request.params, dict) else {} session = await session_manager.create( protocol_version=params.get("protocolVersion", MCP_PROTOCOL_VERSION), client_info=params.get("clientInfo") if isinstance(params.get("clientInfo"), dict) else None, capabilities=params.get("capabilities") if isinstance(params.get("capabilities"), dict) else None, ) response_session_id = session.id elif rpc_request.method in _SESSION_EXEMPT_METHODS: # ping may be issued without a session header if incoming_session_id: try: session = await session_manager.get(incoming_session_id) response_session_id = session.id except SessionTerminated: return _add_protocol_headers( Response( content=error_response( rpc_request.id, JSONRPCError(code=SESSION_ERROR, message="Session terminated or unknown"), ), status_code=HTTP_404_NOT_FOUND, media_type=MediaType.JSON, ) ) else: if not incoming_session_id: return _add_protocol_headers( Response( content=error_response( rpc_request.id, JSONRPCError(code=SESSION_ERROR, message=f"Missing required header: {MCP_SESSION_HEADER}"), ), status_code=HTTP_400_BAD_REQUEST, media_type=MediaType.JSON, ) ) try: session = await session_manager.get(incoming_session_id) except SessionTerminated: return _add_protocol_headers( Response( content=error_response( rpc_request.id, JSONRPCError(code=SESSION_ERROR, message="Session terminated or unknown"), ), status_code=HTTP_404_NOT_FOUND, media_type=MediaType.JSON, ) ) response_session_id = session.id if not session.initialized and rpc_request.method not in _PRE_INIT_ALLOWED_METHODS: return _add_protocol_headers( Response( content=error_response( rpc_request.id, JSONRPCError(code=SESSION_NOT_INITIALIZED, message="Session not initialized"), ), status_code=HTTP_200_OK, media_type=MediaType.JSON, ) ) if rpc_request.method == "notifications/initialized" and incoming_session_id: with contextlib.suppress(SessionTerminated): await session_manager.mark_initialized(incoming_session_id) request_context = _build_request_context(request) router = build_jsonrpc_router( config, discovered_tools, discovered_resources, app_ref=request.app, request_context=request_context, task_store=task_store, registry=registry, ) result = await router.dispatch(rpc_request) response: Response[Any] if result is None: response = Response(content=None, status_code=HTTP_204_NO_CONTENT) else: response = Response(content=result, status_code=HTTP_200_OK, media_type=MediaType.JSON) if response_session_id is not None: response.headers[MCP_SESSION_HEADER] = response_session_id return _add_protocol_headers(response)