"""
MCP Client Manager

This class is responsible for managing MCP clients with support for both SSE
and HTTP streamable transports.

It acts as a proxy: tool listings and tool calls are forwarded to the
registered upstream MCP servers.
"""

import asyncio
import json
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, List, Optional, Set, cast

from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import CallToolResult
from mcp.types import Tool as MCPTool

from litellm._logging import verbose_logger
from litellm.proxy._types import (
    LiteLLM_MCPServerTable,
    MCPAuthType,
    MCPSpecVersion,
    MCPSpecVersionType,
    MCPTransport,
    MCPTransportType,
)

try:
    from mcp.client.streamable_http import streamablehttp_client
except ImportError:
    streamablehttp_client = None  # type: ignore

from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer


class MCPServerManager:
    def __init__(self):
        # Servers added at runtime (e.g. synced from the DB), keyed by server_id.
        self.registry: Dict[str, MCPServer] = {}
        # Servers loaded from the proxy config, keyed by a generated uuid4
        # server_id. Shape of each entry, e.g.:
        # {
        #     "<uuid-1>": {
        #         "name": "zapier_mcp_server",
        #         "url": "https://actions.zapier.com/mcp/<api-key>/sse",
        #         "transport": "sse",
        #         "auth_type": "api_key",
        #         "spec_version": "2025-03-26",
        #     },
        #     "<uuid-2>": {
        #         "name": "google_drive_mcp_server",
        #         "url": "https://actions.zapier.com/mcp/<api-key>/sse",
        #     },
        # }
        self.config_mcp_servers: Dict[str, MCPServer] = {}
        # Tool name -> name of the MCP server exposing it, e.g.
        # {"gmail_send_email": "zapier_mcp_server"}
        self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {}
        # Strong references to fire-and-forget background tasks. The event loop
        # only keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it completes.
        self._background_tasks: Set[asyncio.Task] = set()

    def get_registry(self) -> Dict[str, MCPServer]:
        """
        Return config servers unioned with runtime-registered servers.

        On a server_id collision the runtime `registry` entry wins.
        """
        return self.config_mcp_servers | self.registry

    def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
        """
        Load the MCP Servers from the config.

        Args:
            mcp_servers_config: mapping of server name -> server config. Each
                config must contain "url"; "transport", "spec_version",
                "auth_type", "description" and "mcp_info" are optional.

        Every server gets a fresh uuid4 server_id. Afterwards a background
        refresh of the tool -> server mapping is scheduled when an event loop
        is running.
        """
        verbose_logger.debug("Loading MCP Servers from config-----")
        for server_name, server_config in mcp_servers_config.items():
            _mcp_info: dict = server_config.get("mcp_info", None) or {}
            mcp_info = MCPInfo(**_mcp_info)
            mcp_info["server_name"] = server_name
            # Prefer a top-level description, but keep one supplied inside
            # `mcp_info` instead of overwriting it with None.
            mcp_info["description"] = server_config.get(
                "description", mcp_info.get("description")
            )

            server_id = str(uuid.uuid4())
            new_server = MCPServer(
                server_id=server_id,
                name=server_name,
                url=server_config["url"],
                # TODO: utility fn the default values
                transport=server_config.get("transport", MCPTransport.sse),
                spec_version=server_config.get(
                    "spec_version", MCPSpecVersion.mar_2025
                ),
                auth_type=server_config.get("auth_type", None),
                mcp_info=mcp_info,
            )
            self.config_mcp_servers[server_id] = new_server
        verbose_logger.debug(
            f"Loaded MCP Servers: {json.dumps(self.config_mcp_servers, indent=4, default=str)}"
        )

        self.initialize_tool_name_to_mcp_server_name_mapping()

    def remove_server(self, mcp_server: LiteLLM_MCPServerTable):
        """
        Remove a runtime-registered server from the registry.

        The alias is tried first, then the server_id. Only `self.registry` is
        checked, because that is the dict being deleted from. Config-loaded
        servers are not removable here, and a miss logs a warning instead of
        raising.
        """
        for key in (mcp_server.alias, mcp_server.server_id):
            if key is not None and key in self.registry:
                del self.registry[key]
                verbose_logger.debug(f"Removed MCP Server: {key}")
                return
        verbose_logger.warning(
            f"Server ID {mcp_server.server_id} not found in registry"
        )

    def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
        """
        Insert a DB-backed server into the registry, or replace the entry with
        the same server_id so that edits (url, transport, auth, ...) take effect.
        """
        server_name = mcp_server.alias or mcp_server.server_id
        is_update = mcp_server.server_id in self.registry
        self.registry[mcp_server.server_id] = MCPServer(
            server_id=mcp_server.server_id,
            name=server_name,
            url=mcp_server.url,
            transport=cast(MCPTransportType, mcp_server.transport),
            spec_version=cast(MCPSpecVersionType, mcp_server.spec_version),
            auth_type=cast(MCPAuthType, mcp_server.auth_type),
            mcp_info=MCPInfo(
                server_name=server_name,
                description=mcp_server.description,
            ),
        )
        action = "Updated" if is_update else "Added"
        verbose_logger.debug(f"{action} MCP Server: {server_name}")

    async def list_tools(self) -> List[MCPTool]:
        """
        List all tools available across all MCP Servers.

        A server that fails is logged and skipped, so one bad server does not
        hide the tools of the others.

        Returns:
            List[MCPTool]: Combined list of tools from all servers
        """
        list_tools_result: List[MCPTool] = []
        verbose_logger.debug("SERVER MANAGER LISTING TOOLS")

        for server in self.get_registry().values():
            try:
                tools = await self._get_tools_from_server(server)
            except Exception as e:
                verbose_logger.exception(
                    f"Error listing tools from server {server.name}: {str(e)}"
                )
                continue
            list_tools_result.extend(tools)

        return list_tools_result

    @staticmethod
    def _is_supported_transport(transport: Optional[MCPTransportType]) -> bool:
        """Return True if `transport` is SSE (the default when None) or HTTP."""
        return (
            transport is None
            or transport == MCPTransport.sse
            or transport == MCPTransport.http
        )

    @asynccontextmanager
    async def _create_mcp_client_session(
        self, server: MCPServer
    ) -> AsyncIterator[ClientSession]:
        """
        Open a transport to `server`, yield an initialized ClientSession, and
        close both on exit.

        Raises:
            ValueError: if HTTP transport is requested but the installed `mcp`
                package lacks `streamablehttp_client`, or if the transport is
                unsupported (callers should check `_is_supported_transport`).
        """
        if server.transport is None or server.transport == MCPTransport.sse:
            async with sse_client(url=server.url) as (read, write):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    yield session
        elif server.transport == MCPTransport.http:
            if streamablehttp_client is None:
                verbose_logger.error(
                    "streamablehttp_client not available - install mcp with HTTP support"
                )
                raise ValueError(
                    "streamablehttp_client not available - please run `pip install mcp -U`"
                )
            verbose_logger.debug(f"Using HTTP streamable transport for {server.url}")
            async with streamablehttp_client(url=server.url) as (
                read_stream,
                write_stream,
                get_session_id,
            ):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    if get_session_id is not None:
                        session_id = get_session_id()
                        if session_id:
                            verbose_logger.debug(f"HTTP session ID: {session_id}")
                    yield session
        else:
            raise ValueError(f"Unsupported transport type: {server.transport}")

    async def _get_tools_from_server(self, server: MCPServer) -> List[MCPTool]:
        """
        Helper method to get tools from a single MCP server.

        As a side effect, records each returned tool in
        `tool_name_to_mcp_server_name_mapping`.

        Args:
            server (MCPServer): The server to query tools from

        Returns:
            List[MCPTool]: List of tools available on the server ([] for an
            unsupported transport)
        """
        verbose_logger.debug(f"Connecting to url: {server.url}")
        verbose_logger.info("_get_tools_from_server...")

        if not self._is_supported_transport(server.transport):
            verbose_logger.warning(f"Unsupported transport type: {server.transport}")
            return []

        async with self._create_mcp_client_session(server) as session:
            tools_result = await session.list_tools()

        verbose_logger.debug(f"Tools from {server.name}: {tools_result}")
        for tool in tools_result.tools:
            self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name
        return tools_result.tools

    def initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        On startup, schedule a background refresh of the tool name -> MCP
        server name mapping.

        This only happens when called from inside a running event loop. Otherwise
        it is skipped with a warning. The mapping is still populated lazily by
        `list_tools`.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError as e:  # no running event loop
            verbose_logger.warning(
                f"No running event loop - skipping tool name to MCP server name mapping initialization: {str(e)}"
            )
            return

        task = asyncio.create_task(
            self._initialize_tool_name_to_mcp_server_name_mapping()
        )
        # Keep a strong reference until the task finishes (see __init__).
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        Call list_tools for each server, which updates the tool name to MCP
        server name mapping. Failures are logged per server so one unreachable
        server does not abort the rest.
        """
        for server in self.get_registry().values():
            try:
                # _get_tools_from_server records the tool -> server mapping.
                await self._get_tools_from_server(server)
            except Exception as e:
                verbose_logger.exception(
                    f"Error listing tools from server {server.name}: {str(e)}"
                )

    async def call_tool(self, name: str, arguments: Dict[str, Any]) -> CallToolResult:
        """
        Call a tool with the given name and arguments on the server that
        exposes it.

        Raises:
            ValueError: if no known server exposes `name`, or HTTP transport is
                requested without `streamablehttp_client` available.

        Returns:
            The server's CallToolResult, or an empty error result if the
            server's transport is unsupported.
        """
        mcp_server = self._get_mcp_server_from_tool_name(name)
        if mcp_server is None:
            raise ValueError(f"Tool {name} not found")
        if not self._is_supported_transport(mcp_server.transport):
            verbose_logger.warning(
                f"Unsupported transport type: {mcp_server.transport}"
            )
            return CallToolResult(content=[], isError=True)

        verbose_logger.debug(f"Calling tool {name} on MCP server {mcp_server.name}")
        async with self._create_mcp_client_session(mcp_server) as session:
            return await session.call_tool(name, arguments)

    def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]:
        """
        Get the MCP Server that exposes `tool_name`, or None if the tool is
        unknown or its server is no longer registered.
        """
        server_name = self.tool_name_to_mcp_server_name_mapping.get(tool_name)
        if server_name is None:
            return None
        return next(
            (
                server
                for server in self.get_registry().values()
                if server.name == server_name
            ),
            None,
        )

    async def _add_mcp_servers_from_db_to_in_memory_registry(self):
        """Sync every MCP server stored in the DB into the in-memory registry."""
        from litellm.proxy._experimental.mcp_server.db import get_all_mcp_servers
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            get_prisma_client_or_throw,
        )

        # NOTE(review): no per-user authz filtering happens here; all DB
        # servers are loaded. Confirm callers perform access checks.
        prisma_client = get_prisma_client_or_throw(
            "Database not connected. Connect a database to your proxy"
        )
        db_mcp_servers = await get_all_mcp_servers(prisma_client)
        # ensure the global_mcp_server_manager is up to date with the db
        for server in db_mcp_servers:
            self.add_update_server(server)

    def get_mcp_server_by_id(self, server_id: str) -> Optional[MCPServer]:
        """
        Get the MCP Server from the server id, or None if not registered.
        """
        return next(
            (
                server
                for server in self.get_registry().values()
                if server.server_id == server_id
            ),
            None,
        )


global_mcp_server_manager: MCPServerManager = MCPServerManager()