import asyncio
import json
from typing import Any, Dict, List, Optional

from huggingface_hub import AsyncInferenceClient

from .mcp_client import MCPClient

|
class ToolRegistry:
    """Manages connections to all required MCP servers and their tools.

    One ``MCPClient`` is created per server URL. After ``discover_tools`` has
    run, ``tools`` maps each tool name to the client that advertised it and to
    the tool's description.
    """

    def __init__(self, server_urls: List[str]):
        """Create one client per URL; no discovery happens here."""
        self.servers: Dict[str, MCPClient] = {url: MCPClient(url) for url in server_urls}
        self.tools: Dict[str, Dict[str, Any]] = {}
        # URL -> repr of whatever list_tools() produced instead of a list.
        self.discovery_errors: Dict[str, str] = {}

    async def discover_tools(self):
        """Discovers all available tools from all connected MCP servers.

        All servers are queried concurrently. A server whose ``list_tools()``
        raises (or returns something other than a list) is skipped and noted
        in ``discovery_errors`` instead of being dropped without a trace.
        Tool entries that are not dicts or carry no name are ignored, and a
        missing description defaults to an empty string, so one malformed
        entry cannot abort discovery for every server.
        """
        # Snapshot once so results are paired with the exact clients queried.
        clients = list(self.servers.items())
        results = await asyncio.gather(
            *(client.list_tools() for _, client in clients),
            return_exceptions=True,
        )

        self.discovery_errors = {}
        for (url, client), server_tools in zip(clients, results):
            if not isinstance(server_tools, list):
                self.discovery_errors[url] = repr(server_tools)
                continue
            for tool in server_tools:
                name = tool.get("name") if isinstance(tool, dict) else None
                if not name:
                    continue
                # A later server advertising the same name replaces the earlier one.
                self.tools[name] = {
                    "client": client,
                    "description": tool.get("description", ""),
                }

    async def execute(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Finds the correct MCP server and executes the tool.

        Raises:
            ValueError: if ``tool_name`` was never registered by discovery.
        """
        try:
            tool_info = self.tools[tool_name]
        except KeyError:
            raise ValueError(f"Tool '{tool_name}' not found in registry.") from None
        return await tool_info["client"].execute_tool(tool_name, params)

    async def close_all(self):
        """Closes all client connections.

        Every ``close()`` is awaited to completion before any failure is
        reported; the first failure (in server order) is then re-raised.
        """
        outcomes = await asyncio.gather(
            *(client.close() for client in self.servers.values()),
            return_exceptions=True,
        )
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                raise outcome
|
class HuggingFaceAgent:
    """An AI agent that uses a Hugging Face model to generate plans.

    A plan is a list of step dicts. The prompt asks the model for the keys
    'step', 'thought', 'tool' and 'params' in each step.
    """

    def __init__(self, hf_token: str, model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"):
        """Store the model name and create the async inference client."""
        self.model = model_name
        self.client = AsyncInferenceClient(model=model_name, token=hf_token)

    def _construct_prompt(
        self,
        goal: str,
        available_tools: List[Dict[str, Any]],
        previous_steps: Optional[List] = None,
        error: Optional[str] = None,
    ) -> str:
        """Constructs the detailed prompt for the LLM.

        Args:
            goal: The user's goal, quoted verbatim in the prompt.
            available_tools: Tool descriptors, rendered as indented JSON.
            previous_steps: Steps already executed; included when non-empty.
            error: Message of the last failure; when given, the model is asked
                for a corrected plan.

        Returns:
            The full prompt text, ending with the request to emit the plan.
        """
        tools_json_string = json.dumps(available_tools, indent=2)

        prompt = (
            "You are Forge, an autonomous AI agent. Your task is to create a step-by-step plan to achieve a goal.\n"
            "You must respond with a valid JSON array of objects, where each object represents a step in the plan.\n"
            "Each step must have 'step', 'thought', 'tool', and 'params' keys.\n"
            "The final step must always use the 'report_success' tool.\n"
            "\n"
            "Available Tools:\n"
            f"{tools_json_string}\n"
            "\n"
            f'Goal: "{goal}"\n'
        )

        if previous_steps:
            # default=str is deliberate: tool observations may hold values that
            # are not JSON-serializable, and this text is only read by the LLM.
            steps_json = json.dumps(previous_steps, indent=2, default=str)
            prompt += f"\nYou have already completed these steps:\n{steps_json}\n"
        if error:
            prompt += f"\nAn error occurred during the last step: {error}\nAnalyze the error and create a new, corrected plan to achieve the original goal. Start the new plan from the current state."

        prompt += "\nGenerate the JSON plan now:"
        return prompt

    @staticmethod
    def _parse_plan(response: str) -> List[Dict[str, Any]]:
        """Extract and validate the JSON plan contained in a raw LLM reply.

        The contents of the last ```json fence (or the whole reply when there
        is none) are parsed first. If that is not valid JSON, the text between
        the first '[' and the last ']' is tried instead, which covers untagged
        fences and arrays surrounded by prose.

        Raises:
            json.JSONDecodeError: no parsable JSON was found.
            ValueError: the JSON is not a list of objects.
        """
        text = response.strip()
        fenced = text.split("```json")[-1].split("```")[0].strip()
        try:
            plan = json.loads(fenced)
        except json.JSONDecodeError:
            start, end = text.find("["), text.rfind("]")
            if start == -1 or end < start:
                raise
            plan = json.loads(text[start:end + 1])

        if not isinstance(plan, list):
            raise ValueError("LLM did not return a JSON list.")
        # The executor reads steps with dict access, so reject other elements here.
        if not all(isinstance(step, dict) for step in plan):
            raise ValueError("LLM plan steps must be JSON objects.")
        return plan

    async def _invoke_llm(self, prompt: str) -> List[Dict[str, Any]]:
        """Invokes the LLM and parses the JSON response.

        Returns:
            The parsed plan, or a single-step 'report_failure' plan when the
            reply cannot be parsed into a list of step objects.
        """
        # Bound before the try so the handler can always print it.
        response = ""
        try:
            response = await self.client.text_generation(prompt, max_new_tokens=1024)
            return self._parse_plan(response)
        except (json.JSONDecodeError, ValueError, IndexError) as e:
            print(f"Error parsing LLM response: {e}\nRaw response:\n{response}")
            return [
                {
                    "step": 1,
                    "thought": "Failed to generate a plan due to a parsing error.",
                    "tool": "report_failure",
                    "params": {"message": f"LLM response parsing failed: {e}"},
                }
            ]

    async def generate_plan(self, goal: str, available_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Generates a step-by-step plan."""
        prompt = self._construct_prompt(goal, available_tools)
        return await self._invoke_llm(prompt)

    async def regenerate_plan_on_error(
        self,
        goal: str,
        available_tools: List[Dict[str, Any]],
        completed_steps: List,
        error_message: str,
    ) -> List[Dict[str, Any]]:
        """Generates a new plan after an error occurred."""
        prompt = self._construct_prompt(
            goal, available_tools, previous_steps=completed_steps, error=error_message
        )
        return await self._invoke_llm(prompt)
|
class ForgeApp:
    """The main orchestrator for the Forge application.

    Wires a planner (``HuggingFaceAgent``) to a ``ToolRegistry`` and drives
    the plan/execute/re-plan loop for a single goal.
    """

    def __init__(self, goal: str, mcp_server_urls: List[str], hf_token: str):
        """Store the goal and build the planner and the tool registry."""
        self.goal = goal
        self.planner = HuggingFaceAgent(hf_token=hf_token)
        self.tool_registry = ToolRegistry(server_urls=mcp_server_urls)

    async def run(self, max_replans: int = 3):
        """Runs the agent and yields status updates as a generator.

        Discovers tools, asks the planner for a plan, then executes the steps
        in order. A step whose tool is 'report_success' or 'report_failure'
        ends the run. When a tool result has status 'error', the planner is
        asked for a corrected plan; an exception while executing a step halts
        the run.

        Args:
            max_replans: Upper bound on corrective re-plans, so a tool that
                keeps failing cannot loop (and call the LLM) forever.

        Yields:
            Markdown-formatted status strings.

        The registry's connections are closed even if discovery or planning
        raises, or the consumer stops iterating early.
        """
        # NOTE(review): the status emoji were mis-decoded in the source this
        # was recovered from; the glyphs below are the best reconstruction.
        try:
            yield "🚀 **Starting Forge... Initializing systems.**"
            await self.tool_registry.discover_tools()
            yield f"✅ **Tool Discovery Complete.** Found {len(self.tool_registry.tools)} tools."

            available_tools_details = [
                {"name": name, "description": data["description"]}
                for name, data in self.tool_registry.tools.items()
            ]

            yield f"🧠 **Generating a plan for your goal:** '{self.goal}'"
            plan = await self.planner.generate_plan(self.goal, available_tools_details)
            yield "📝 **Plan Generated!** Starting execution..."

            completed_steps = []
            replans = 0
            while plan:
                task = plan.pop(0)
                if not isinstance(task, dict):
                    yield f"❌ **Critical Error:** malformed plan step: {task!r}"
                    yield "🛑 **Execution Halted due to critical error.**"
                    break

                yield f"\n**[Step {task.get('step', '?')}]** 🤔 **Thought:** {task.get('thought', 'N/A')}"

                tool_name = task.get("tool")
                # The model may emit "params": null; treat that as no params.
                params = task.get("params") or {}

                if tool_name in ("report_success", "report_failure"):
                    emoji = "🎉" if tool_name == "report_success" else "🛑"
                    yield f"{emoji} **Final Result:** {params.get('message', 'N/A')}"
                    break

                try:
                    yield f"🛠️ **Action:** Executing tool `{tool_name}` with params: `{params}`"
                    result = await self.tool_registry.execute(tool_name, params)

                    if result.get("status") == "error":
                        error_message = result.get("result", "Unknown error")
                        yield f"❌ **Error:** {error_message}"
                        completed_steps.append(
                            {"step": task, "outcome": "error", "details": error_message}
                        )

                        if replans >= max_replans:
                            yield f"🛑 **Execution Halted: re-planning limit reached ({max_replans}).**"
                            break
                        replans += 1

                        yield "🧠 **Agent is re-evaluating the plan based on the error...**"
                        plan = await self.planner.regenerate_plan_on_error(
                            self.goal, available_tools_details, completed_steps, error_message
                        )
                        yield "📝 **New Plan Generated!** Resuming execution..."
                    else:
                        observation = result.get("result", "Tool executed successfully.")
                        yield f"✅ **Observation:** {observation}"
                        completed_steps.append(
                            {"step": task, "outcome": "success", "details": observation}
                        )
                except Exception as e:
                    # Top-level boundary of the step loop: report and stop.
                    yield f"❌ **Critical Error executing step {task.get('step', '?')}:** {e}"
                    yield "🛑 **Execution Halted due to critical error.**"
                    break
        finally:
            # No yield here: a generator being closed may await but not yield.
            await self.tool_registry.close_all()

        yield "\n🏁 **Forge execution finished.**"