|
|
|
|
|
|
|
|
|
import asyncio |
|
import json |
|
import traceback |
|
from typing import List, Optional |
|
from contextlib import AsyncExitStack |
|
import uuid |
|
|
|
from fastapi.staticfiles import StaticFiles |
|
from pydantic import BaseModel |
|
from fastapi import FastAPI, Request, HTTPException, Depends |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import FileResponse, JSONResponse |
|
|
|
|
|
from mcp import ClientSession, StdioServerParameters |
|
from mcp.client.stdio import stdio_client |
|
|
|
from groq import Groq, APIConnectionError |
|
from dotenv import load_dotenv |
|
import os |
|
import httpx |
|
|
|
# Live chat sessions, keyed by the session_id string handed out by /init.
sessions: dict[str, "MCPClient"] = dict()

# API keys that currently own a session; /init refuses a second session per key.
unique_apikeys: list[str] = list()
|
|
|
class MCPClient:
    """Per-session bridge between a Groq chat model and an MCP tool server.

    Each instance owns one Groq client, one stdio MCP session (opened by
    :meth:`connect`) and the running chat history sent to the model.
    """

    def __init__(self):
        self.session: Optional[ClientSession] = None
        # Everything opened in connect() is registered here, so a single
        # ``await exit_stack.aclose()`` releases the MCP transport/session and
        # the HTTP client.
        self.exit_stack = AsyncExitStack()
        self.current_model = None
        self.groq = None
        self.api_key = None
        self.messages = [{
            "role": "system",
            "content": "You are a helpful assistant that have access to different tools via MCP. Make complete answers."
        }]
        # When False, completions are requested without the MCP tool definitions.
        self.tool_use = True
        self.models = None
        self.tools = []

    async def connect(self, api_key: str, *, verify_ssl: bool = True) -> bool:
        """Create the Groq client, start the MCP server and load its tools.

        Args:
            api_key: Groq API key used for every request of this session.
            verify_ssl: Verify the TLS certificate of the Groq endpoint. Only
                disable it for local debugging behind an intercepting proxy:
                without verification the API key can be captured by a
                man-in-the-middle.

        Returns:
            True once the MCP session is initialised and ``self.tools`` is
            populated; False if the Groq client could not be created.

        Raises:
            Errors from launching or initialising the MCP server propagate.
        """
        try:
            http_client = httpx.Client(verify=verify_ssl, timeout=30)
            # Register immediately so the connection pool is closed together
            # with the rest of the session, even if Groq() below fails.
            self.exit_stack.callback(http_client.close)
            self.groq = Groq(api_key=api_key, http_client=http_client)
            self.api_key = api_key
        except Exception as e:  # also covers groq.APIConnectionError
            traceback.print_exception(e)
            return False

        server_params = StdioServerParameters(command="uv", args=["run", "server.py"])
        self.stdio, self.write = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
        await self.session.initialize()

        response = await self.session.list_tools()
        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])
        # Re-shape the MCP tool descriptors into the chat-completions `tools` format.
        self.tools = [{"type": "function", "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.inputSchema
        }} for tool in tools]
        return True

    def populate_model(self):
        """Store the sorted ids returned by the Groq models endpoint in ``self.models``."""
        self.models = sorted([m.id for m in self.groq.models.list().data])

    async def _complete(self, **kwargs):
        """Request a chat completion for the current model and history.

        The Groq client is synchronous, so the call runs in a worker thread and
        does not block the event loop for the other requests. A snapshot of the
        history is sent so later appends cannot race the request serialisation.
        """
        return await asyncio.to_thread(
            self.groq.chat.completions.create,
            model=self.current_model,
            messages=list(self.messages),
            **kwargs,
        )

    @staticmethod
    def _tool_result_text(result) -> str:
        """Flatten an MCP tool result's content blocks into plain text for the model."""
        parts = []
        for block in result.content:
            text = getattr(block, "text", None)
            # Non-text blocks (no ``text`` attribute) fall back to their repr.
            parts.append(text if isinstance(text, str) else str(block))
        return "\n".join(parts)

    async def _run_tool_call(self, call) -> dict:
        """Execute one model-requested tool call via MCP and build its "tool" message.

        Failures are reported back to the model as the tool's content rather
        than raised, so every ``tool_call_id`` in the history gets an answer.
        Without that answer, every later completion for the session would be
        rejected.
        """
        name = call.function.name
        raw_args = call.function.arguments
        print(f"[Calling tool {name} with args {raw_args}]")
        try:
            # An empty argument string means "no arguments".
            args = json.loads(raw_args) if raw_args else {}
            result = await self.session.call_tool(name, args)
            content = self._tool_result_text(result)
        except Exception as e:
            traceback.print_exception(e)
            content = f"Error calling tool {name}: {e}"
        print(content)
        return {"role": "tool", "tool_call_id": call.id, "content": content}

    async def process_query(self, query: str) -> str:
        """Process a query using Groq and available tools.

        The user turn, any assistant tool-call turn, the tool results and the
        assistant replies are all appended to ``self.messages``. That keeps the
        history well-formed and gives follow-up questions their context.

        Returns:
            The assistant text, including text emitted alongside tool calls and
            the follow-up answer produced after the tools ran, joined by newlines.
        """
        self.messages.append({"role": "user", "content": query})

        if self.tool_use:
            response = await self._complete(tools=self.tools, temperature=0)
        else:
            response = await self._complete(temperature=0.7)

        final_text = []
        for choice in response.choices:
            content = choice.message.content
            tool_calls = choice.message.tool_calls
            if content:
                final_text.append(content)

            if not tool_calls:
                if content:
                    self.messages.append({"role": "assistant", "content": content})
                continue

            print(tool_calls)
            # The assistant turn that *issued* the calls must precede the
            # matching "tool" messages, otherwise the next request is invalid.
            self.messages.append({
                "role": "assistant",
                "content": content,
                "tool_calls": [
                    {
                        "id": call.id,
                        "type": "function",
                        "function": {
                            "name": call.function.name,
                            "arguments": call.function.arguments,
                        },
                    }
                    for call in tool_calls
                ],
            })
            for call in tool_calls:
                self.messages.append(await self._run_tool_call(call))

            follow_up = await self._complete(temperature=0.7)
            answer = follow_up.choices[0].message.content
            if answer:
                final_text.append(answer)
                self.messages.append({"role": "assistant", "content": answer})

        return "\n".join(final_text)
|
|
|
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Static assets are served from the "static" directory (relative to the CWD).
app.mount("/static", StaticFiles(directory="static"), name="static")

# Module-level client instance. The request handlers in this module create and
# look up their own per-session MCPClient objects instead of using this one.
mcp = MCPClient()
|
|
|
class InitRequest(BaseModel):
    # Body of POST /init: the Groq API key for the new session.
    api_key: str


class InitResponse(BaseModel):
    # Result of POST /init. On failure ``success`` is False, ``session_id`` is
    # the empty string and ``error`` carries the message.
    success: bool
    session_id: str
    models: list | None = None  # sorted model ids when successful
    error: str | None = None


class LogoutRequest(BaseModel):
    # Body of POST /logout: the session to close.
    session_id: str
|
|
|
def get_mcp_client(session_id: str) -> "MCPClient":
    """Return the MCPClient registered for *session_id*.

    The function never returns None, so the annotation is the plain class.

    Raises:
        HTTPException: 404 when *session_id* is not in ``sessions``
            (never initialised, or already logged out).
    """
    client = sessions.get(session_id)
    if client is None:
        raise HTTPException(status_code=404, detail="Invalid session_id. Please re-initialize.")
    return client
|
|
|
@app.get("/")
def root():
    """Serve the front-end page, ``index.html``, resolved against the working directory."""
    page = "index.html"
    return FileResponse(page)
|
|
|
@app.post("/init", response_model=InitResponse)
async def init_server(req: InitRequest):
    """
    Initializes a new MCP client session. Returns a session_id.

    Only one live session per API key is allowed. Failures (duplicate key,
    Groq/MCP connection errors, model listing errors) are reported in
    ``InitResponse.error`` with ``success=False``. Any resources the half-built
    client opened (MCP subprocess, HTTP client) are released in that case.
    """
    api_key = req.api_key
    session_id = str(uuid.uuid4())
    client = MCPClient()
    key_reserved = False

    try:
        # Check before doing any work: the old flow stored the session first,
        # leaking an unreachable session and subprocess on every duplicate.
        if api_key in unique_apikeys:
            raise Exception("Session with this API key already exists. We won't re-return you the session ID. Bye-bye Hacker !!")
        # Reserve the key before the first await so concurrent /init calls
        # with the same key cannot both get past the check.
        unique_apikeys.append(api_key)
        key_reserved = True

        ok = await client.connect(api_key)
        if ok is False:
            raise RuntimeError("Failed to connect to MCP or Groq with API key.")
        client.populate_model()

        sessions[session_id] = client
        return InitResponse(
            session_id=session_id,
            models=client.models,
            error=None,
            success=True
        )

    except Exception as e:
        traceback.print_exception(e)
        if key_reserved:
            unique_apikeys.remove(api_key)
        # Tear down whatever connect() managed to open.
        try:
            await client.exit_stack.aclose()
        except Exception as close_err:
            traceback.print_exception(close_err)
        return InitResponse(
            session_id="",
            models=None,
            error=str(e),
            success=False
        )
|
|
|
class ChatRequest(BaseModel):
    # Body of POST /chat.
    session_id: str
    query: str
    tool_use: bool | None = True  # falsy: the model is called without MCP tools
    model: str | None = "llama-3.3-70b-versatile"


class ChatResponse(BaseModel):
    # Result of POST /chat; on failure ``output`` is "" and ``error`` is set.
    output: str
    error: str | None = None
|
|
|
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """
    Answer one chat turn for an existing session.

    Every failure (unknown session, unrecognised model, LLM or tool errors) is
    logged and returned as ``ChatResponse(output="", error=...)`` rather than
    as an HTTP error status.
    """
    try:
        client = get_mcp_client(req.session_id)
        client.tool_use = req.tool_use
        if req.model not in client.models:
            raise ValueError(f"Model not recognized: Not in the model list: {client.models}")
        client.current_model = req.model
        answer = await client.process_query(req.query)
        return ChatResponse(output=answer)
    except Exception as err:
        traceback.print_exception(err)
        return ChatResponse(output="", error=str(err))
|
|
|
@app.post("/logout")
async def logout(logout_req: LogoutRequest):
    """Clean up session resources.

    The session is dropped from ``sessions``, its API key is released for a
    new /init, and its exit stack is closed. Unknown or already-closed session
    ids are a no-op, so repeating a logout is harmless.
    """
    client = sessions.pop(logout_req.session_id, None)
    if client is None:
        # Nothing to clean up (``None.api_key`` used to raise here).
        return {"success": True}

    if client.api_key in unique_apikeys:
        unique_apikeys.remove(client.api_key)

    try:
        await client.exit_stack.aclose()
    except Exception as e:
        # Best-effort: the session is already unregistered, so log instead of 500.
        # NOTE(review): the MCP stdio transport was entered while handling
        # /init; closing it from this different request may raise task-affinity
        # errors from the async runtime — confirm.
        traceback.print_exception(e)
    return {"success": True}