# src/forge_agent.py — Forge agent orchestration: MCP tool registry,
# Hugging Face planning agent, and the ForgeApp run loop.
import asyncio
import json
from typing import List, Dict, Any
from huggingface_hub import AsyncInferenceClient
from .mcp_client import MCPClient
class ToolRegistry:
    """Manages connections to all required MCP servers and their tools.

    Maps each discovered tool name to the MCP client that serves it, so the
    agent can route execution requests without knowing which server owns
    which tool.
    """

    def __init__(self, server_urls: List[str]):
        # One MCPClient per server, keyed by its URL for debuggability.
        self.servers: Dict[str, MCPClient] = {url: MCPClient(url) for url in server_urls}
        # tool name -> {"client": <MCPClient>, "description": <str>}
        self.tools: Dict[str, Dict[str, Any]] = {}

    async def discover_tools(self):
        """Discovers all available tools from all connected MCP servers.

        Discovery runs concurrently across servers. A server whose listing
        fails (its gather result is an exception rather than a list) is
        skipped so one bad server does not abort discovery for the rest.
        If two servers expose the same tool name, the last one wins.
        """
        discovery_tasks = [client.list_tools() for client in self.servers.values()]
        results = await asyncio.gather(*discovery_tasks, return_exceptions=True)
        # gather preserves input order, so zip keeps each client aligned
        # with its own result (clearer than enumerate + parallel indexing).
        for client, server_tools in zip(self.servers.values(), results):
            if isinstance(server_tools, list):
                for tool in server_tools:
                    self.tools[tool["name"]] = {"client": client, "description": tool["description"]}

    async def execute(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Finds the correct MCP server and executes the tool.

        Raises:
            ValueError: if `tool_name` was never discovered.
        """
        if tool_name not in self.tools:
            raise ValueError(f"Tool '{tool_name}' not found in registry.")
        tool_info = self.tools[tool_name]
        return await tool_info["client"].execute_tool(tool_name, params)

    async def close_all(self):
        """Closes all client connections.

        return_exceptions=True ensures every close() is awaited even when
        one of them fails — a single bad connection must not leak the rest.
        """
        await asyncio.gather(*(client.close() for client in self.servers.values()), return_exceptions=True)
class HuggingFaceAgent:
    """An AI agent that uses a Hugging Face model to generate plans."""

    def __init__(self, hf_token: str, model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"):
        self.model = model_name
        self.client = AsyncInferenceClient(model=model_name, token=hf_token)

    def _construct_prompt(self, goal: str, available_tools: List[Dict[str, Any]], previous_steps: List = None, error: str = None) -> str:
        """Constructs the detailed prompt for the LLM.

        Args:
            goal: The user's natural-language objective.
            available_tools: Tool specs (name + description) the LLM may use.
            previous_steps: Steps already executed (used when re-planning).
            error: Error message from the last failed step, if re-planning.

        Returns:
            The full prompt string ending with the plan-generation cue.
        """
        tools_json_string = json.dumps(available_tools, indent=2)
        prompt = f"""You are Forge, an autonomous AI agent. Your task is to create a step-by-step plan to achieve a goal.
You must respond with a valid JSON array of objects, where each object represents a step in the plan.
Each step must have 'step', 'thought', 'tool', and 'params' keys.
The final step must always use the 'report_success' tool.
Available Tools:
{tools_json_string}
Goal: "{goal}"
"""
        if previous_steps:
            prompt += f"\nYou have already completed these steps:\n{json.dumps(previous_steps, indent=2)}\n"
        if error:
            prompt += f"\nAn error occurred during the last step: {error}\nAnalyze the error and create a new, corrected plan to achieve the original goal. Start the new plan from the current state."
        prompt += "\nGenerate the JSON plan now:"
        return prompt

    async def _invoke_llm(self, prompt: str) -> List[Dict[str, Any]]:
        """Invokes the LLM and parses the JSON response.

        Returns the parsed plan (a list of step dicts). On any parsing
        failure, returns a single-step 'report_failure' plan instead of
        raising, so the orchestrator can surface the error gracefully.
        """
        # BUG FIX: bind `response` before the try — if text_generation itself
        # raises ValueError, the except handler below logs `response` and
        # previously crashed with NameError (unbound local).
        response = ""
        try:
            response = await self.client.text_generation(prompt, max_new_tokens=1024)
            # The response might contain the JSON within backticks or other text.
            json_response_str = response.strip().split('```json')[-1].split('```')[0].strip()
            plan = json.loads(json_response_str)
            if isinstance(plan, list):
                return plan
            else:
                raise ValueError("LLM did not return a JSON list.")
        except (json.JSONDecodeError, ValueError, IndexError) as e:
            print(f"Error parsing LLM response: {e}\nRaw response:\n{response}")
            # Fallback or re-try logic could be added here
            return [{"step": 1, "thought": "Failed to generate a plan due to a parsing error.", "tool": "report_failure", "params": {"message": f"LLM response parsing failed: {e}"}}]

    async def generate_plan(self, goal: str, available_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Generates a step-by-step plan.
        """
        prompt = self._construct_prompt(goal, available_tools)
        return await self._invoke_llm(prompt)

    async def regenerate_plan_on_error(self, goal: str, available_tools: List[Dict[str, Any]], completed_steps: List, error_message: str) -> List[Dict[str, Any]]:
        """Generates a new plan after an error occurred."""
        prompt = self._construct_prompt(goal, available_tools, previous_steps=completed_steps, error=error_message)
        return await self._invoke_llm(prompt)
class ForgeApp:
    """The main orchestrator for the Forge application."""

    def __init__(self, goal: str, mcp_server_urls: List[str], hf_token: str):
        self.goal = goal
        self.planner = HuggingFaceAgent(hf_token=hf_token)
        self.tool_registry = ToolRegistry(server_urls=mcp_server_urls)

    async def run(self):
        """
        Runs the agent and yields status updates as a generator.

        Yields human-readable markdown status strings. MCP connections are
        always closed on exit via try/finally, even if discovery, planning,
        or the consumer of this generator raises.
        """
        yield "πŸš€ **Starting Forge... Initializing systems.**"
        try:
            await self.tool_registry.discover_tools()
            yield f"βœ… **Tool Discovery Complete.** Found {len(self.tool_registry.tools)} tools."
            # Provide the LLM with full tool details, not just names
            available_tools_details = [{"name": name, "description": data["description"]} for name, data in self.tool_registry.tools.items()]
            yield f"🧠 **Generating a plan for your goal:** '{self.goal}'"
            plan = await self.planner.generate_plan(self.goal, available_tools_details)
            yield "πŸ“ **Plan Generated!** Starting execution..."
            completed_steps = []
            while plan:
                task = plan.pop(0)
                yield f"\n**[Step {task.get('step', '?')}]** πŸ€” **Thought:** {task.get('thought', 'N/A')}"
                tool_name = task.get("tool")
                # Terminal pseudo-tools end the run without hitting a server.
                if tool_name in ["report_success", "report_failure"]:
                    emoji = "πŸŽ‰" if tool_name == "report_success" else "πŸ›‘"
                    yield f"{emoji} **Final Result:** {task.get('params', {}).get('message', 'N/A')}"
                    plan = []  # End execution
                    continue
                try:
                    yield f"πŸ› οΈ **Action:** Executing tool `{tool_name}` with params: `{task.get('params', {})}`"
                    result = await self.tool_registry.execute(tool_name, task.get("params", {}))
                    if result.get("status") == "error":
                        # Tool reported a recoverable error: ask the planner
                        # for a corrected plan and keep going.
                        error_message = result.get('result', 'Unknown error')
                        yield f"❌ **Error:** {error_message}"
                        yield "🧠 **Agent is re-evaluating the plan based on the error...**"
                        completed_steps.append({"step": task, "outcome": "error", "details": error_message})
                        plan = await self.planner.regenerate_plan_on_error(self.goal, available_tools_details, completed_steps, error_message)
                        yield "πŸ“ **New Plan Generated!** Resuming execution..."
                    else:
                        observation = result.get('result', 'Tool executed successfully.')
                        yield f"βœ… **Observation:** {observation}"
                        completed_steps.append({"step": task, "outcome": "success", "details": observation})
                except Exception as e:
                    # Unexpected failure (e.g. unknown tool, transport error):
                    # halt instead of re-planning.
                    yield f"❌ **Critical Error executing step {task.get('step', '?')}:** {e}"
                    yield "πŸ›‘ **Execution Halted due to critical error.**"
                    plan = []  # End execution
        finally:
            # BUG FIX: close_all() was previously unreachable when discovery,
            # planning, or execution raised, leaking MCP client connections.
            await self.tool_registry.close_all()
        yield "\n🏁 **Forge execution finished.**"