# Source code for litestar_mcp.plugin

"""Litestar MCP Plugin implementation."""

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

from litestar import Litestar, Router
from litestar.config.app import AppConfig
from litestar.di import Provide
from litestar.handlers import BaseRouteHandler
from litestar.plugins import CLIPlugin, InitPluginProtocol
from litestar.stores.memory import MemoryStore

from litestar_mcp.config import MCPConfig
from litestar_mcp.manifests import build_agent_card, build_mcp_server_manifest, build_oauth_protected_resource
from litestar_mcp.registry import Registry
from litestar_mcp.routes import MCPController
from litestar_mcp.sessions import MCPSessionManager
from litestar_mcp.sse import SSEManager
from litestar_mcp.tasks import InMemoryTaskStore, TaskRecord
from litestar_mcp.utils import get_handler_function, get_mcp_metadata

if TYPE_CHECKING:
    from click import Group


class LitestarMCP(InitPluginProtocol, CLIPlugin):
    """Litestar plugin for Model Context Protocol integration.

    The plugin owns a :class:`Registry` of discovered tools and resources, an
    :class:`SSEManager`, an :class:`MCPSessionManager` and, when
    ``config.task_config`` is set, an :class:`InMemoryTaskStore`. During
    :meth:`on_app_init` it mounts the MCP router and three ``/.well-known``
    manifest endpoints on the application.
    """

    def __init__(self, config: MCPConfig | None = None) -> None:
        """Initialize the MCP plugin.

        Args:
            config: Plugin configuration. A default :class:`MCPConfig` is
                created when ``None`` is passed.
        """
        # Compare against ``None`` explicitly: a truthiness test would replace
        # a caller-supplied object that merely evaluates as falsy.
        self._config = config if config is not None else MCPConfig()
        self._registry = Registry()
        self._sse_manager = SSEManager(
            max_streams=self._config.sse_max_streams,
            max_idle_seconds=self._config.sse_max_idle_seconds,
        )

        # Fall back to an in-memory store only when no store was configured.
        # A configured store that is falsy (e.g. one defining ``__len__`` and
        # currently empty) must still be honoured.
        session_store = self._config.session_store
        if session_store is None:
            session_store = MemoryStore()
        self._session_manager = MCPSessionManager(
            session_store,
            max_idle_seconds=self._config.session_max_idle_seconds,
        )
        # The session manager is also stored on the config object.
        self._config._session_manager = self._session_manager  # noqa: SLF001

        # The task store stays ``None`` unless a task configuration is present.
        self._task_store: InMemoryTaskStore | None = None
        task_config = self._config.task_config
        if task_config is not None:
            self._task_store = InMemoryTaskStore(
                default_ttl=task_config.default_ttl,
                max_ttl=task_config.max_ttl,
                poll_interval=task_config.poll_interval,
            )

    @property
    def config(self) -> MCPConfig:
        """Get the plugin configuration."""
        return self._config

    @property
    def registry(self) -> Registry:
        """Get the central registry."""
        return self._registry

    @property
    def discovered_tools(self) -> dict[str, BaseRouteHandler]:
        """Get discovered MCP tools (``registry.tools``)."""
        return self._registry.tools

    @property
    def discovered_resources(self) -> dict[str, BaseRouteHandler]:
        """Get discovered MCP resources (``registry.resources``)."""
        return self._registry.resources

    def on_cli_init(self, cli: "Group") -> None:
        """Configure CLI commands for MCP operations.

        Args:
            cli: The click group that receives the ``mcp_group`` command.
        """
        # Imported at call time rather than at module import.
        from litestar_mcp.cli import mcp_group

        cli.add_command(mcp_group)

    def _discover_mcp_routes(self, route_handlers: Sequence[Any]) -> None:
        """Discover routes marked for MCP exposure via opt attribute or decorators.

        Every :class:`BaseRouteHandler` in ``route_handlers`` is registered if
        it carries MCP markers. Any item exposing a non-empty
        ``route_handlers`` attribute is walked recursively.

        Args:
            route_handlers: Handlers or handler containers to inspect.
        """
        for handler in route_handlers:
            if isinstance(handler, BaseRouteHandler):
                self._register_handler(handler)

            nested = getattr(handler, "route_handlers", None)
            if nested:
                self._discover_mcp_routes(nested)  # pyright: ignore[reportAttributeAccessIssue]

    def _register_handler(self, handler: BaseRouteHandler) -> None:
        """Register one handler, preferring MCP metadata over ``opt`` keys."""
        # Metadata is looked up on the handler first, then on the function
        # returned by ``get_handler_function``.
        metadata = get_mcp_metadata(handler) or get_mcp_metadata(get_handler_function(handler))
        if metadata:
            self._register_from_metadata(handler, metadata)
        elif handler.opt:
            self._register_from_opt(handler)

    def _register_from_metadata(self, handler: BaseRouteHandler, metadata: Any) -> None:
        """Register a handler described by MCP metadata (``type`` / ``name``)."""
        kind = metadata["type"]
        if kind == "tool":
            self._registry.register_tool(metadata["name"], handler)
        elif kind == "resource":
            self._registry.register_resource(metadata["name"], handler)
            template = metadata.get("resource_template")
            if template is not None:
                self._registry.register_resource_template(metadata["name"], handler, template)

    def _register_from_opt(self, handler: BaseRouteHandler) -> None:
        """Register a handler marked through the configured ``opt`` keys."""
        opt = handler.opt
        opt_keys = self._config.opt_keys

        if opt_keys.tool in opt:
            self._registry.register_tool(opt[opt_keys.tool], handler)

        if opt_keys.resource in opt:
            resource_name = opt[opt_keys.resource]
            self._registry.register_resource(resource_name, handler)
            # A template is registered only when the opt value is a string.
            opt_template = opt.get(opt_keys.resource_template)
            if isinstance(opt_template, str):
                self._registry.register_resource_template(resource_name, handler, opt_template)

    def on_app_init(self, app_config: AppConfig) -> AppConfig:
        """Initialize the MCP integration when the Litestar app starts.

        Runs discovery over ``app_config.route_handlers``, attaches the SSE
        manager to the registry, hooks task status notifications when a task
        store exists, mounts the MCP router at ``config.base_path``, registers
        :meth:`on_startup`, and adds the ``/.well-known`` manifest handlers.

        Args:
            app_config: The application configuration being assembled.

        Returns:
            The same ``app_config`` instance, mutated in place.
        """
        self._discover_mcp_routes(app_config.route_handlers)
        self._registry.set_sse_manager(self._sse_manager)

        if self._task_store is not None:

            async def publish_task_status(record: TaskRecord) -> None:
                await self._registry.publish_notification(
                    "notifications/tasks/status",
                    record.to_dict(),
                )

            self._task_store.set_status_callback(publish_task_status)

        def provide_mcp_config() -> MCPConfig:
            return self._config

        def provide_registry() -> Registry:
            return self._registry

        def provide_task_store() -> InMemoryTaskStore | None:
            return self._task_store

        def provide_session_manager() -> MCPSessionManager:
            return self._session_manager

        router_kwargs: dict[str, Any] = {
            "path": self._config.base_path,
            "route_handlers": [MCPController],
            "tags": ["mcp"],
            "include_in_schema": self._config.include_in_schema,
            "dependencies": {
                "config": Provide(provide_mcp_config, sync_to_thread=False),
                "registry": Provide(provide_registry, sync_to_thread=False),
                "task_store": Provide(provide_task_store, sync_to_thread=False),
                "session_manager": Provide(provide_session_manager, sync_to_thread=False),
                "discovered_tools": Provide(lambda: self._registry.tools, sync_to_thread=False),
                "discovered_resources": Provide(lambda: self._registry.resources, sync_to_thread=False),
            },
        }
        # Guards are passed to the router only when explicitly configured.
        if self._config.guards is not None:
            router_kwargs["guards"] = self._config.guards

        mcp_router = Router(**router_kwargs)
        app_config.route_handlers.append(mcp_router)
        app_config.on_startup.append(self.on_startup)

        from litestar import Request
        from litestar import get as litestar_get

        # The manifest handlers below are flagged ``exclude_from_auth`` in
        # ``opt``; presumably an auth layer reads this flag to skip them.
        @litestar_get("/.well-known/oauth-protected-resource", sync_to_thread=False, opt={"exclude_from_auth": True})
        def oauth_protected_resource(request: Request[Any, Any, Any]) -> dict[str, Any]:
            return build_oauth_protected_resource(self._config.auth, request.app)

        @litestar_get("/.well-known/agent-card.json", sync_to_thread=False, opt={"exclude_from_auth": True})
        def agent_card(request: Request[Any, Any, Any]) -> dict[str, Any]:
            return build_agent_card(
                base_url=str(request.base_url),
                config=self._config,
                app=request.app,
                discovered_tools=self._registry.tools,
            )

        @litestar_get("/.well-known/mcp-server.json", sync_to_thread=False, opt={"exclude_from_auth": True})
        def mcp_server_manifest(request: Request[Any, Any, Any]) -> dict[str, Any]:
            return build_mcp_server_manifest(
                base_url=str(request.base_url),
                config=self._config,
                app=request.app,
                discovered_tools=self._registry.tools,
                discovered_resources=self._registry.resources,
            )

        app_config.route_handlers.extend([oauth_protected_resource, agent_card, mcp_server_manifest])
        return app_config

    def on_startup(self, app: Litestar) -> None:
        """Perform discovery after app is fully initialized and routes are built.

        Gathers the ``route_handlers`` of every route in ``app.routes`` that
        exposes that attribute and runs discovery over them.

        Args:
            app: The initialized Litestar application.
        """
        all_handlers: list[BaseRouteHandler] = []
        for route in app.routes:
            # Routes without a ``route_handlers`` attribute are skipped.
            if hasattr(route, "route_handlers"):
                all_handlers.extend(route.route_handlers)  # pyright: ignore[reportAttributeAccessIssue]
        self._discover_mcp_routes(all_handlers)