DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
"""
MCP Client Manager
This class is responsible for managing MCP clients with support for both SSE and HTTP streamable transports.
This is a Proxy
"""
import asyncio
import json
import uuid
from typing import Any, Dict, List, Optional, cast
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import CallToolResult
from mcp.types import Tool as MCPTool
from litellm._logging import verbose_logger
from litellm.proxy._types import (
LiteLLM_MCPServerTable,
MCPAuthType,
MCPSpecVersion,
MCPSpecVersionType,
MCPTransport,
MCPTransportType,
)
try:
from mcp.client.streamable_http import streamablehttp_client
except ImportError:
streamablehttp_client = None # type: ignore
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
class MCPServerManager:
    """
    Keeps track of the MCP servers known to the proxy and talks to them.

    Servers come from two places: the proxy config
    (``load_servers_from_config``) and runtime registration
    (``add_update_server``). Tools are listed and called over a client
    session that is opened per request, using either the SSE or the
    streamable HTTP transport.
    """

    def __init__(self):
        # Servers registered at runtime through add_update_server(), keyed by
        # server_id.
        self.registry: Dict[str, MCPServer] = {}

        # Servers loaded by load_servers_from_config(), keyed by a generated
        # UUID. Illustrative contents:
        # {
        #     "server-1": {
        #         "name": "zapier_mcp_server",
        #         "url": "https://actions.zapier.com/mcp/<api-key>/sse",
        #         "transport": "sse",
        #         "auth_type": "api_key",
        #         "spec_version": "2025-03-26"
        #     },
        #     "uuid-2": {
        #         "name": "google_drive_mcp_server",
        #         "url": "https://actions.zapier.com/mcp/<api-key>/sse"
        #     }
        # }
        self.config_mcp_servers: Dict[str, MCPServer] = {}

        # Tool name -> name of the MCP server exposing it, e.g.
        # {"gmail_send_email": "zapier_mcp_server"}
        self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {}

        # Strong references to in-flight background tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._background_tasks: set = set()

    def get_registry(self) -> Dict[str, MCPServer]:
        """
        Return a new dict combining the config servers and the runtime registry.

        On a key collision the runtime registry entry wins. The returned dict
        is a copy; mutating it does not affect the manager.
        """
        # Dict unpacking instead of `|` so this also runs on Python < 3.9.
        return {**self.config_mcp_servers, **self.registry}

    def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
        """
        Load the MCP Servers from the config.

        Args:
            mcp_servers_config: Mapping of server name -> server settings.
                ``url`` is required; ``transport``, ``spec_version``,
                ``auth_type``, ``description`` and ``mcp_info`` are optional.

        Raises:
            KeyError: If a server entry has no ``url``.

        Every call generates fresh server IDs, so loading the same config
        twice adds the servers a second time rather than replacing them.
        """
        verbose_logger.debug("Loading MCP Servers from config-----")
        for server_name, server_config in mcp_servers_config.items():
            _mcp_info: dict = server_config.get("mcp_info", None) or {}
            mcp_info = MCPInfo(**_mcp_info)
            # Name and description always come from the top-level config
            # entry, overriding anything supplied inside `mcp_info`.
            mcp_info["server_name"] = server_name
            mcp_info["description"] = server_config.get("description", None)

            server_id = str(uuid.uuid4())
            new_server = MCPServer(
                server_id=server_id,
                name=server_name,
                url=server_config["url"],
                # TODO: utility fn the default values
                transport=server_config.get("transport", MCPTransport.sse),
                spec_version=server_config.get("spec_version", MCPSpecVersion.mar_2025),
                auth_type=server_config.get("auth_type", None),
                mcp_info=mcp_info,
            )
            self.config_mcp_servers[server_id] = new_server

        verbose_logger.debug(
            f"Loaded MCP Servers: {json.dumps(self.config_mcp_servers, indent=4, default=str)}"
        )
        self.initialize_tool_name_to_mcp_server_name_mapping()

    def remove_server(self, mcp_server: LiteLLM_MCPServerTable):
        """
        Remove a server from the runtime registry.

        The alias is tried first, then the server_id. Servers that exist only
        in the config are not removed: a warning is logged instead of raising
        ``KeyError``.
        """
        # Check `self.registry` itself (not the merged view) because that is
        # the dict we delete from.
        for key in (mcp_server.alias, mcp_server.server_id):
            if key is not None and key in self.registry:
                del self.registry[key]
                verbose_logger.debug(f"Removed MCP Server: {key}")
                return

        verbose_logger.warning(
            f"Server ID {mcp_server.server_id} not found in registry"
        )

    def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
        """
        Register ``mcp_server`` in the runtime registry if its ID is unknown.

        The server is stored under its ``server_id``; its display name is the
        alias when set, otherwise the ``server_id``.

        NOTE(review): despite the name, a server whose ``server_id`` is already
        present (in the config or the registry) is left untouched - confirm
        whether existing entries are meant to be overwritten.
        """
        if mcp_server.server_id in self.get_registry():
            return

        display_name = mcp_server.alias or mcp_server.server_id
        new_server = MCPServer(
            server_id=mcp_server.server_id,
            name=display_name,
            url=mcp_server.url,
            transport=cast(MCPTransportType, mcp_server.transport),
            spec_version=cast(MCPSpecVersionType, mcp_server.spec_version),
            auth_type=cast(MCPAuthType, mcp_server.auth_type),
            mcp_info=MCPInfo(
                server_name=display_name,
                description=mcp_server.description,
            ),
        )
        self.registry[mcp_server.server_id] = new_server
        verbose_logger.debug(
            f"Added MCP Server: {mcp_server.alias or mcp_server.server_id}"
        )

    async def list_tools(self) -> List[MCPTool]:
        """
        List all tools available across all MCP Servers.

        Servers are queried one after another. A server that fails is logged
        and skipped so the remaining servers are still listed.

        Returns:
            List[MCPTool]: Combined list of tools from all servers
        """
        list_tools_result: List[MCPTool] = []
        verbose_logger.debug("SERVER MANAGER LISTING TOOLS")
        for server in self.get_registry().values():
            try:
                tools = await self._get_tools_from_server(server)
                list_tools_result.extend(tools)
            except Exception as e:
                verbose_logger.exception(
                    f"Error listing tools from server {server.name}: {str(e)}"
                )
        return list_tools_result

    def _record_tools_for_server(self, server: MCPServer, tools: List[MCPTool]) -> None:
        """Remember that every tool in ``tools`` is served by ``server``."""
        for tool in tools:
            self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

    @staticmethod
    def _require_streamablehttp_client():
        """
        Return the streamable HTTP client factory.

        Raises:
            ValueError: If the installed ``mcp`` package does not provide it.
        """
        if streamablehttp_client is None:
            verbose_logger.error(
                "streamablehttp_client not available - install mcp with HTTP support"
            )
            raise ValueError(
                "streamablehttp_client not available - please run `pip install mcp -U`"
            )
        return streamablehttp_client

    async def _get_tools_from_server(self, server: MCPServer) -> List[MCPTool]:
        """
        Helper method to get tools from a single MCP server.

        Also records each returned tool in the tool-name -> server-name
        mapping.

        Args:
            server (MCPServer): The server to query tools from

        Returns:
            List[MCPTool]: List of tools available on the server. Empty when
            the server's transport is not supported.

        Raises:
            ValueError: If the HTTP transport is requested but unavailable.
        """
        verbose_logger.debug(f"Connecting to url: {server.url}")
        verbose_logger.info("_get_tools_from_server...")

        # A missing transport is treated as SSE.
        if server.transport is None or server.transport == MCPTransport.sse:
            async with sse_client(url=server.url) as (read, write):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    tools_result = await session.list_tools()
                    verbose_logger.debug(f"Tools from {server.name}: {tools_result}")
                    self._record_tools_for_server(server, tools_result.tools)
                    return tools_result.tools

        if server.transport == MCPTransport.http:
            http_client = self._require_streamablehttp_client()
            verbose_logger.debug(f"Using HTTP streamable transport for {server.url}")
            async with http_client(
                url=server.url,
            ) as (read_stream, write_stream, get_session_id):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    if get_session_id is not None:
                        session_id = get_session_id()
                        if session_id:
                            verbose_logger.debug(f"HTTP session ID: {session_id}")
                    tools_result = await session.list_tools()
                    verbose_logger.debug(f"Tools from {server.name}: {tools_result}")
                    self._record_tools_for_server(server, tools_result.tools)
                    return tools_result.tools

        verbose_logger.warning(f"Unsupported transport type: {server.transport}")
        return []

    def initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        On startup, initialize the tool name to MCP server name mapping.

        The work is scheduled as a background task on the running event loop.
        When no loop is running the initialization is skipped and logged.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError as e:  # no running event loop
            verbose_logger.exception(
                f"No running event loop - skipping tool name to MCP server name mapping initialization: {str(e)}"
            )
            return

        task = asyncio.create_task(
            self._initialize_tool_name_to_mcp_server_name_mapping()
        )
        # Hold a strong reference until the task completes; otherwise it may
        # be garbage-collected mid-execution.
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        Call list_tools for each server and update the tool name to MCP server name mapping.

        A failure on one server is logged and does not prevent the remaining
        servers from being processed.
        """
        for server in self.get_registry().values():
            try:
                tools = await self._get_tools_from_server(server)
            except Exception as e:
                verbose_logger.exception(
                    f"Error initializing tool mapping for server {server.name}: {str(e)}"
                )
                continue
            self._record_tools_for_server(server, tools)

    async def call_tool(self, name: str, arguments: Dict[str, Any]) -> CallToolResult:
        """
        Call a tool with the given name and arguments.

        The owning server is looked up through the tool-name mapping, which is
        populated when tools are listed.

        Args:
            name: Name of the tool to call.
            arguments: Arguments forwarded to the tool.

        Returns:
            The result reported by the server, or an empty error result when
            the server's transport is not supported.

        Raises:
            ValueError: If no known server exposes ``name``, or if the HTTP
                transport is requested but unavailable.
        """
        mcp_server = self._get_mcp_server_from_tool_name(name)
        if mcp_server is None:
            raise ValueError(f"Tool {name} not found")

        # A missing transport is treated as SSE.
        if mcp_server.transport is None or mcp_server.transport == MCPTransport.sse:
            async with sse_client(url=mcp_server.url) as (read, write):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    return await session.call_tool(name, arguments)

        if mcp_server.transport == MCPTransport.http:
            http_client = self._require_streamablehttp_client()
            verbose_logger.debug(
                f"Using HTTP streamable transport for tool call: {name}"
            )
            async with http_client(
                url=mcp_server.url,
            ) as (read_stream, write_stream, get_session_id):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    if get_session_id is not None:
                        session_id = get_session_id()
                        if session_id:
                            verbose_logger.debug(
                                f"HTTP session ID for tool call: {session_id}"
                            )
                    return await session.call_tool(name, arguments)

        verbose_logger.warning(f"Unsupported transport type: {mcp_server.transport}")
        return CallToolResult(content=[], isError=True)

    def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]:
        """
        Get the MCP Server from the tool name.

        Returns None when the tool is unknown or when the server it maps to is
        no longer present in the registry.
        """
        if tool_name not in self.tool_name_to_mcp_server_name_mapping:
            return None

        server_name = self.tool_name_to_mcp_server_name_mapping[tool_name]
        return next(
            (
                server
                for server in self.get_registry().values()
                if server.name == server_name
            ),
            None,
        )

    async def _add_mcp_servers_from_db_to_in_memory_registry(self):
        """
        Load every MCP server stored in the database into the runtime registry.

        Servers are passed to ``add_update_server``, so IDs that are already
        known are left unchanged.
        """
        # Imported lazily inside the method rather than at module import time.
        from litellm.proxy._experimental.mcp_server.db import get_all_mcp_servers
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            get_prisma_client_or_throw,
        )

        # Obtain the DB client; the helper receives the message to report when
        # no database is connected.
        prisma_client = get_prisma_client_or_throw(
            "Database not connected. Connect a database to your proxy"
        )
        db_mcp_servers = await get_all_mcp_servers(prisma_client)
        # ensure the global_mcp_server_manager is up to date with the db
        for server in db_mcp_servers:
            self.add_update_server(server)

    def get_mcp_server_by_id(self, server_id: str) -> Optional[MCPServer]:
        """
        Get the MCP Server from the server id.

        Returns None when no config or runtime server has that ID.
        """
        return next(
            (
                server
                for server in self.get_registry().values()
                if server.server_id == server_id
            ),
            None,
        )
# Module-level MCPServerManager instance, created when this module is imported.
global_mcp_server_manager: MCPServerManager = MCPServerManager()