Spaces:
Configuration error
Configuration error
import uuid | |
from typing import Iterable, List, Optional, Set | |
from litellm.proxy._types import ( | |
LiteLLM_MCPServerTable, | |
LiteLLM_ObjectPermissionTable, | |
LiteLLM_TeamTable, | |
NewMCPServerRequest, | |
SpecialMCPServerName, | |
UpdateMCPServerRequest, | |
UserAPIKeyAuth, | |
) | |
from litellm.proxy.utils import PrismaClient | |
async def get_all_mcp_servers(
    prisma_client: PrismaClient,
) -> List[LiteLLM_MCPServerTable]:
    """
    Fetch every mcp server row stored in the db.

    Runs an unfiltered ``find_many`` on the mcp server table and hands the
    result back unchanged.
    """
    server_table = prisma_client.db.litellm_mcpservertable
    return await server_table.find_many()
async def get_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """
    Look up a single mcp server row by its ``server_id``.

    The result of ``find_unique`` is returned as-is; per the return
    annotation that is the matching record, or None when there is none.
    """
    server_table = prisma_client.db.litellm_mcpservertable
    lookup = {"server_id": server_id}
    record: Optional[LiteLLM_MCPServerTable] = await server_table.find_unique(
        where=lookup
    )
    return record
async def get_mcp_servers(
    prisma_client: PrismaClient, server_ids: Iterable[str]
) -> List[LiteLLM_MCPServerTable]:
    """
    Returns the mcp servers from the db whose server_id is one of server_ids.

    ``server_ids`` may be any iterable (list, set, generator, ...). It is
    copied into a list before it is placed in the query filter. An empty
    ``server_ids`` yields an empty list without querying the db.
    """
    # The signature promises a general Iterable, so materialise it once into a
    # plain list rather than handing an arbitrary (possibly one-shot) iterable
    # to the query layer.
    wanted_ids = list(server_ids)
    if not wanted_ids:
        # An empty "in" filter cannot match any row; skip the db round-trip.
        return []

    mcp_servers: List[
        LiteLLM_MCPServerTable
    ] = await prisma_client.db.litellm_mcpservertable.find_many(
        where={
            "server_id": {"in": wanted_ids},
        }
    )
    return mcp_servers
async def get_mcp_servers_by_verificationtoken(
    prisma_client: PrismaClient, token: str
) -> List[str]:
    """
    Returns the mcp server ids attached to the verification token.

    The token row is loaded together with its ``object_permission`` relation.
    An empty list is returned when the row is missing, when it has no object
    permission, or when the permission's ``mcp_servers`` value is empty/None.
    """
    # This row comes from the verification-token table and is checked for None
    # below, so it is deliberately NOT annotated as LiteLLM_TeamTable.
    token_record = await prisma_client.db.litellm_verificationtoken.find_unique(
        where={
            "token": token,
        },
        include={
            "object_permission": True,
        },
    )
    if token_record is None or token_record.object_permission is None:
        return []
    return token_record.object_permission.mcp_servers or []
async def get_mcp_servers_by_team(
    prisma_client: PrismaClient, team_id: str
) -> List[str]:
    """
    Collect the mcp server ids attached to a team.

    The team row is fetched by ``team_id`` along with its
    ``object_permission`` relation. A missing team, a missing permission, or
    an empty/None ``mcp_servers`` value all result in an empty list.
    """
    team_table = prisma_client.db.litellm_teamtable
    team_row: Optional[LiteLLM_TeamTable] = await team_table.find_unique(
        where={"team_id": team_id},
        include={"object_permission": True},
    )

    permission = None if team_row is None else team_row.object_permission
    if permission is None:
        return []
    return permission.mcp_servers or []
async def get_all_mcp_servers_for_user(
    prisma_client: PrismaClient,
    user: UserAPIKeyAuth,
) -> List[LiteLLM_MCPServerTable]:
    """
    Return only the mcp servers the given user is allowed to see.

    Least-privilege: ids are gathered from the user's key first; the team's
    ids are added only when the key's ids contain the special
    "all team servers" marker and the user belongs to a team. The collected
    ids are then resolved to server rows.
    """
    allowed_ids: Set[str] = set()

    # Start from the servers attached to the caller's key, when one is present.
    if user.api_key:
        allowed_ids.update(
            await get_mcp_servers_by_verificationtoken(prisma_client, user.api_key)
        )

    # The team's servers are opt-in via the special marker among the key's ids.
    wants_team_servers = SpecialMCPServerName.all_team_servers in allowed_ids
    if wants_team_servers and user.team_id is not None:
        allowed_ids.update(
            await get_mcp_servers_by_team(prisma_client, user.team_id)
        )

    # No ids collected means there is nothing to look up.
    if not allowed_ids:
        return []
    return await get_mcp_servers(prisma_client, allowed_ids)
async def get_objectpermissions_for_mcp_server(
    prisma_client: PrismaClient, mcp_server_id: str
) -> List[LiteLLM_ObjectPermissionTable]:
    """
    Find the object permission records that grant access to the mcp server.

    Matches records whose ``mcp_servers`` contains ``mcp_server_id`` and loads
    each record's ``teams`` and ``verification_tokens`` relations with it.
    """
    permission_table = prisma_client.db.litellm_objectpermissiontable
    server_filter = {"mcp_servers": {"has": mcp_server_id}}
    relations = {"teams": True, "verification_tokens": True}
    return await permission_table.find_many(where=server_filter, include=relations)
async def get_virtualkeys_for_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> List:
    """
    List the virtual keys that have access to the mcp server.

    Queries the verification-token table for rows whose ``mcp_servers``
    contains ``server_id``. A None query result is normalised to an empty list.
    """
    token_table = prisma_client.db.litellm_verificationtoken
    matching_keys = await token_table.find_many(
        where={"mcp_servers": {"has": server_id}},
    )
    return [] if matching_keys is None else matching_keys
async def delete_mcp_server_from_team(prisma_client: PrismaClient, server_id: str):
    """
    Remove the mcp server from the team.

    Placeholder: currently performs no work and returns None.
    """
    return None
async def delete_mcp_server_from_virtualkey():
    """
    Remove the mcp server from the virtual key.

    Placeholder: currently performs no work and returns None.
    """
    return None
async def delete_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """
    Delete the mcp server row identified by ``server_id``.

    The value produced by the ``delete`` call is returned unchanged; per the
    return annotation that is the deleted record, or None if there was none.
    """
    server_table = prisma_client.db.litellm_mcpservertable
    target = {"server_id": server_id}
    return await server_table.delete(where=target)
async def create_mcp_server(
    prisma_client: PrismaClient, data: NewMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """
    Insert a new mcp server row built from ``data``.

    When the request has no ``server_id``, a random UUID4 string is generated
    and written back onto ``data`` before the insert. ``created_by`` and
    ``updated_by`` are both set to ``touched_by``.
    """
    # Fill in a missing id; note that this mutates the caller's request object.
    if data.server_id is None:
        data.server_id = str(uuid.uuid4())

    server_table = prisma_client.db.litellm_mcpservertable
    # Audit fields are applied last so they win over any same-named keys
    # coming from the dumped request.
    row_values = dict(
        data.model_dump(),
        created_by=touched_by,
        updated_by=touched_by,
    )
    return await server_table.create(data=row_values)
async def update_mcp_server(
    prisma_client: PrismaClient, data: UpdateMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """
    Update an existing mcp server record in the db.

    The row is selected by ``data.server_id`` and overwritten with the fields
    dumped from ``data``. ``updated_by`` is set to ``touched_by``;
    ``created_by`` is left untouched so the original creator is preserved.
    """
    row_values = data.model_dump()
    # An update must not rewrite who created the row: drop any created_by
    # carried by the request instead of overwriting it with the updater.
    row_values.pop("created_by", None)
    row_values["updated_by"] = touched_by

    mcp_server_record = await prisma_client.db.litellm_mcpservertable.update(
        where={
            "server_id": data.server_id,
        },
        data=row_values,
    )
    return mcp_server_record