Spaces:
Sleeping
Sleeping
# GET /models | |
# GET /tools | |
# POST /chat -> Groq -> response | |
import asyncio | |
import json | |
import traceback | |
from typing import List, Optional | |
from contextlib import AsyncExitStack | |
import uuid | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
from fastapi import FastAPI, Request, HTTPException, Depends | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import FileResponse, JSONResponse | |
from mcp import ClientSession, StdioServerParameters | |
from mcp.client.stdio import stdio_client | |
from groq import Groq, APIConnectionError | |
from dotenv import load_dotenv | |
import os | |
import httpx | |
# Live sessions: session_id (a str(uuid4())) -> the MCPClient serving it.
sessions = dict()
# API keys that currently own a session; init_server allows one session per key
# and logout removes the key again.
unique_apikeys = list()
class MCPClient:
    """Per-session bridge between a Groq chat model and a local MCP tool server.

    Each instance owns one Groq client, one MCP stdio session (spawned with
    ``uv run server.py``) and the chat history for that session. Everything
    opened in :meth:`connect` is registered on ``exit_stack``; release it with
    ``await client.exit_stack.aclose()`` when the session ends.
    """

    def __init__(self):
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        # Model id used for completions; set per request by the caller.
        self.current_model = None
        self.groq = None
        self.api_key = None
        # Full conversation history sent to Groq on every request.
        self.messages = [{
            "role": "system",
            "content": "You are a helpful assistant that have access to different tools via MCP. Make complete answers."
        }]
        # When False, requests are sent without tool definitions.
        self.tool_use = True
        # Sorted model ids visible to the API key; filled by populate_model().
        self.models = None
        # MCP tools converted to the OpenAI-style "function" tool schema.
        self.tools = []

    async def connect(self, api_key: str, verify_ssl: bool = False):
        """Create the Groq client and start the MCP tool server for this session.

        Args:
            api_key: Groq API key used for every request in this session.
            verify_ssl: Passed to ``httpx.Client(verify=...)``. Defaults to False
                to keep the existing behaviour. NOTE(review): with verification
                off the API key can be intercepted by a man-in-the-middle; prefer
                True unless a TLS-intercepting proxy requires otherwise.

        Returns:
            True once the MCP session is initialized and ``self.tools`` is set;
            False if the Groq client could not be created.

        Raises:
            Any exception raised while starting or initializing the MCP server.
            Resources opened so far are closed before it propagates.
        """
        http_client = None
        try:
            http_client = httpx.Client(verify=verify_ssl, timeout=30)
            self.groq = Groq(api_key=api_key, http_client=http_client)
            self.api_key = api_key
        except Exception as e:
            traceback.print_exception(e)
            if http_client is not None:
                http_client.close()
            return False
        # Close the HTTP client together with the MCP resources on teardown.
        self.exit_stack.callback(http_client.close)

        try:
            server_params = StdioServerParameters(command="uv", args=["run", "server.py"])
            self.stdio, self.write = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            self.session = await self.exit_stack.enter_async_context(
                ClientSession(self.stdio, self.write)
            )
            await self.session.initialize()
            response = await self.session.list_tools()
        except Exception:
            # Don't leave the spawned server process or the HTTP client behind.
            await self.exit_stack.aclose()
            raise

        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])
        self.tools = [{"type": "function", "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.inputSchema
        }} for tool in tools]
        return True

    def populate_model(self):
        """Fetch the model ids visible to this API key into ``self.models`` (sorted)."""
        self.models = sorted(m.id for m in self.groq.models.list().data)

    async def _create_completion(self, **kwargs):
        """Run the blocking Groq chat-completion call in a worker thread."""
        return await asyncio.to_thread(self.groq.chat.completions.create, **kwargs)

    @staticmethod
    def _tool_result_text(result) -> str:
        """Flatten an MCP tool result's content items into plain text for the model."""
        parts = []
        for item in result.content or []:
            text = getattr(item, "text", None)
            # Non-text items (no ``text`` attribute) fall back to their str().
            parts.append(text if text is not None else str(item))
        return "\n".join(parts)

    async def _run_tool_calls(self, tool_calls) -> None:
        """Run each requested tool via MCP and append one ``tool`` message per call.

        A failing call (invalid JSON arguments or a tool error) is reported to
        the model as that tool's output instead of aborting. Every tool call id
        therefore gets a reply, and the history stays valid for later requests.
        """
        for call in tool_calls:
            name = call.function.name
            raw_args = call.function.arguments
            print(f"[Calling tool {name} with args {raw_args}]")
            try:
                # An empty argument string is treated as "no arguments".
                args = json.loads(raw_args) if raw_args else {}
                result = await self.session.call_tool(name, args)
                output = self._tool_result_text(result)
            except Exception as e:
                traceback.print_exception(e)
                output = f"Error calling tool {name}: {e}"
            self.messages.append({
                "role": "tool",
                "tool_call_id": call.id,
                "content": output,
            })

    async def process_query(self, query: str) -> str:
        """Send *query* to ``self.current_model`` and return the assistant's reply.

        With ``self.tool_use`` set, the MCP tool definitions are offered to the
        model (temperature 0). Any tool calls it makes are executed and their
        results appended to the history. A single follow-up completion
        (temperature 0.7) then produces the final answer. Without tool use, one
        completion at temperature 0.7 is made.

        User, assistant, tool-call and tool-result messages are all kept in
        ``self.messages``, so later queries see the whole conversation.

        Returns:
            The assistant text produced for this query, pieces joined by newlines.
        """
        self.messages.append({"role": "user", "content": query})

        request = {"model": self.current_model, "messages": self.messages}
        if self.tool_use:
            request.update(tools=self.tools, temperature=0)
        else:
            request.update(temperature=0.7)
        response = await self._create_completion(**request)

        final_text = []
        for choice in response.choices:
            content = choice.message.content
            tool_calls = choice.message.tool_calls
            if content:
                final_text.append(content)
            if not tool_calls:
                if content:
                    self.messages.append({"role": "assistant", "content": content})
                continue

            # The chat API requires the assistant turn that requested the tools
            # to precede the matching "tool" result messages.
            self.messages.append({
                "role": "assistant",
                "content": content,
                "tool_calls": [
                    {
                        "id": call.id,
                        "type": "function",
                        "function": {
                            "name": call.function.name,
                            "arguments": call.function.arguments,
                        },
                    }
                    for call in tool_calls
                ],
            })
            await self._run_tool_calls(tool_calls)

            # One follow-up request, made only after *all* tool results are in.
            follow_up = await self._create_completion(
                model=self.current_model,
                messages=self.messages,
                temperature=0.7,
            )
            answer = follow_up.choices[0].message.content
            if answer:
                final_text.append(answer)
                self.messages.append({"role": "assistant", "content": answer})
        return "\n".join(final_text)
app = FastAPI()
# Permissive CORS: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Serve front-end assets from the ./static directory under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE(review): the handlers in this file shadow this name with their own
# per-session client; this module-level instance appears unused — confirm.
mcp = MCPClient()
class InitRequest(BaseModel):
    """Request body for creating a new session."""

    api_key: str  # Groq API key the session's client will use
class InitResponse(BaseModel):
    """Outcome of a session-creation request."""

    success: bool
    session_id: str  # empty string when creation failed
    models: Optional[list] = None  # sorted Groq model ids on success
    error: Optional[str] = None  # failure message, if any
class LogoutRequest(BaseModel):
    """Request body for closing a session."""

    session_id: str  # id previously returned on session creation
def get_mcp_client(session_id: str) -> MCPClient:
    """Return the MCPClient registered for *session_id*.

    Never returns None: an unknown id raises instead.

    Raises:
        HTTPException: status 404 if no session exists for *session_id*.
    """
    client = sessions.get(session_id)
    if client is None:
        raise HTTPException(status_code=404, detail="Invalid session_id. Please re-initialize.")
    return client
def root():
    """Serve the front-end page ``index.html`` (path relative to the working directory).

    NOTE(review): no route decorator is visible in this copy of the file —
    confirm this handler is registered on ``app``.
    """
    index_page = "index.html"
    return FileResponse(index_page)
async def init_server(req: InitRequest):
    """Create a new MCP client session for a Groq API key.

    Only one live session is allowed per API key. A key that already owns a
    session is rejected before any client or server process is started.

    NOTE(review): no route decorator is visible in this copy of the file —
    confirm this handler is registered on ``app``.

    Returns:
        On success: ``InitResponse`` with ``success=True``, the new
        ``session_id`` and the sorted model ids.
        On any failure: ``success=False``, an empty ``session_id`` and the
        error text.
    """
    api_key = req.api_key
    session_id = str(uuid.uuid4())
    client = MCPClient()
    key_reserved = False
    try:
        # Check and reserve the key before the first ``await`` so that two
        # concurrent requests with the same key cannot both pass this check.
        if api_key in unique_apikeys:
            raise Exception("Session with this API key already exists. We won't re-return you the session ID. Bye-bye Hacker !!")
        unique_apikeys.append(api_key)
        key_reserved = True

        ok = await client.connect(api_key)
        if ok is False:
            raise RuntimeError("Failed to connect to MCP or Groq with API key.")
        client.populate_model()
        # Register the session only once it is fully usable.
        sessions[session_id] = client
        return InitResponse(
            session_id=session_id,
            models=client.models,
            error=None,
            success=True
        )
    except Exception as e:
        traceback.print_exception(e)
        if key_reserved:
            unique_apikeys.remove(api_key)
        try:
            # Release whatever connect() registered (e.g. the MCP stdio server).
            await client.exit_stack.aclose()
        except Exception as close_error:
            traceback.print_exception(close_error)
        return InitResponse(
            session_id="",
            models=None,
            error=str(e),
            success=False
        )
class ChatRequest(BaseModel):
    """Request body for one chat turn in an existing session."""

    session_id: str  # id returned when the session was created
    query: str  # the user's message
    tool_use: Optional[bool] = True  # whether MCP tools are offered to the model
    model: Optional[str] = "llama-3.3-70b-versatile"  # must be one of the session's models
class ChatResponse(BaseModel):
    """Reply to one chat turn."""

    output: str  # assistant text; empty string when the turn failed
    error: Optional[str] = None  # failure message, if any
async def chat(req: ChatRequest):
    """Run one chat turn for an existing session.

    Looks up the session's client and applies the requested tool-use flag.
    It then checks that the requested model is available to the session and
    returns the assistant output.

    Every failure is caught and reported via ``ChatResponse.error`` with an
    empty ``output``. This includes the 404 HTTPException raised by
    get_mcp_client for an unknown session.

    NOTE(review): the file header lists POST /chat, but no route decorator is
    visible in this copy — confirm this handler is registered on ``app``.
    """
    try:
        client = get_mcp_client(req.session_id)
        client.tool_use = req.tool_use
        if req.model not in client.models:
            raise ValueError(f"Model not recognized: Not in the model list: {client.models}")
        client.current_model = req.model
        answer = await client.process_query(req.query)
        return ChatResponse(output=answer)
    except Exception as exc:
        traceback.print_exception(exc)
        return ChatResponse(output="", error=str(exc))
async def logout(logout_req: LogoutRequest):
    """Close a session and release its resources.

    Removes the session from ``sessions``, frees its API key for a new session
    and closes the client's exit stack (the MCP stdio server and anything else
    ``connect`` registered on it).

    NOTE(review): no route decorator is visible in this copy of the file —
    confirm this handler is registered on ``app``.

    Returns:
        ``{"success": True}`` when a session was closed, or
        ``{"success": False}`` when *session_id* is unknown (for example,
        already logged out).
    """
    client = sessions.pop(logout_req.session_id, None)
    if client is None:
        # Previously this dereferenced None and surfaced as a 500 error.
        return {"success": False}
    # Guard so a key that is already gone cannot turn logout into a ValueError.
    if client.api_key in unique_apikeys:
        unique_apikeys.remove(client.api_key)
    await client.exit_stack.aclose()
    return {"success": True}