File size: 7,695 Bytes
915668d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import asyncio
import json
from typing import List, Dict, Any
from huggingface_hub import AsyncInferenceClient
from .mcp_client import MCPClient

class ToolRegistry:
    """Manages connections to all required MCP servers and their tools.

    One ``MCPClient`` is created per server URL. After ``discover_tools`` has
    run, ``tools`` maps each tool name to the client that advertised it and
    ``discovery_errors`` maps each server URL whose discovery failed to the
    failure that was returned for it.
    """
    def __init__(self, server_urls: List[str]):
        """Creates one client per URL; no tools are known until discovery runs."""
        # Keyed by URL, so a URL listed twice yields a single client.
        self.servers: Dict[str, MCPClient] = {url: MCPClient(url) for url in server_urls}
        # tool name -> {"client": <owning client>, "description": <str>}
        self.tools: Dict[str, Dict[str, Any]] = {}
        # server URL -> exception (or unexpected non-list payload) from discovery
        self.discovery_errors: Dict[str, Any] = {}

    async def discover_tools(self):
        """Discovers all available tools from all connected MCP servers.

        All servers are queried concurrently. A server whose ``list_tools``
        call raises, or returns something other than a list, does not abort
        discovery: the failure is stored in ``discovery_errors`` under that
        server's URL and the remaining servers are still registered.
        If two servers advertise the same tool name, the later one wins.
        """
        # Snapshot the mapping once so results are paired with the exact
        # clients that were queried.
        endpoints = list(self.servers.items())
        results = await asyncio.gather(
            *(client.list_tools() for _, client in endpoints),
            return_exceptions=True,
        )

        self.discovery_errors = {}
        for (url, client), server_tools in zip(endpoints, results):
            if not isinstance(server_tools, list):
                # Either an exception captured by gather() or a malformed payload.
                self.discovery_errors[url] = server_tools
                continue
            for tool in server_tools:
                # Skip malformed entries instead of crashing the whole discovery.
                if not isinstance(tool, dict) or "name" not in tool:
                    continue
                self.tools[tool["name"]] = {
                    "client": client,
                    "description": tool.get("description", ""),
                }

    async def execute(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Finds the correct MCP server and executes the tool.

        Raises:
            ValueError: if ``tool_name`` was never registered by discovery.
        """
        if tool_name not in self.tools:
            raise ValueError(f"Tool '{tool_name}' not found in registry.")
        tool_info = self.tools[tool_name]
        return await tool_info["client"].execute_tool(tool_name, params)

    async def close_all(self):
        """Closes all client connections.

        Every client's ``close`` is awaited to completion even if some of them
        fail; afterwards the first failure (if any) is re-raised so callers
        still see the error.
        """
        outcomes = await asyncio.gather(
            *(client.close() for client in self.servers.values()),
            return_exceptions=True,
        )
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                raise outcome

class HuggingFaceAgent:
    """An AI agent that uses a Hugging Face model to generate plans.

    A plan is a list of step dictionaries decoded from the model's JSON reply.
    When the reply cannot be decoded, a single-step ``report_failure`` plan is
    returned instead of raising.
    """
    def __init__(self, hf_token: str, model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"):
        """Creates the async inference client for ``model_name``."""
        self.model = model_name
        self.client = AsyncInferenceClient(model=model_name, token=hf_token)

    def _construct_prompt(self, goal: str, available_tools: List[Dict[str, Any]], previous_steps: List = None, error: str = None) -> str:
        """Constructs the detailed prompt for the LLM.

        Args:
            goal: The user's goal, embedded verbatim in the prompt.
            available_tools: JSON-serialisable tool descriptions shown to the model.
            previous_steps: Optional history of already executed steps; only
                included when non-empty.
            error: Optional error text from the last step; when given, the
                model is asked for a corrected plan.

        Returns:
            The full prompt string.
        """
        tools_json_string = json.dumps(available_tools, indent=2)

        prompt = f"""You are Forge, an autonomous AI agent. Your task is to create a step-by-step plan to achieve a goal.
You must respond with a valid JSON array of objects, where each object represents a step in the plan.
Each step must have 'step', 'thought', 'tool', and 'params' keys.
The final step must always use the 'report_success' tool.

Available Tools:
{tools_json_string}

Goal: "{goal}"
"""
        if previous_steps:
            prompt += f"\nYou have already completed these steps:\n{json.dumps(previous_steps, indent=2)}\n"
        if error:
            prompt += f"\nAn error occurred during the last step: {error}\nAnalyze the error and create a new, corrected plan to achieve the original goal. Start the new plan from the current state."

        prompt += "\nGenerate the JSON plan now:"
        return prompt

    @staticmethod
    def _parse_plan(response: str) -> List[Dict[str, Any]]:
        """Extracts the JSON plan list from raw model output.

        Tries, in order: the contents of a fenced code block (a ```json fence
        is preferred, otherwise the first bare fence), or the whole reply when
        there is no fence; then the span from the first ``[`` to the last
        ``]``, which covers JSON surrounded by prose.

        Raises:
            ValueError: if no candidate decodes to a JSON list
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        if not isinstance(response, str):
            raise ValueError("LLM returned a non-text response.")

        text = response.strip()
        candidates = []
        if "```" in text:
            if "```json" in text:
                after_fence = text.split("```json")[-1]
            else:
                after_fence = text.split("```", 1)[1]
            candidates.append(after_fence.split("```")[0].strip())
        else:
            candidates.append(text)

        first, last = text.find("["), text.rfind("]")
        if first != -1 and last > first:
            candidates.append(text[first:last + 1])

        failure = ValueError("LLM did not return a JSON list.")
        for candidate in candidates:
            try:
                plan = json.loads(candidate)
            except json.JSONDecodeError as exc:
                failure = exc
                continue
            if isinstance(plan, list):
                return plan
            failure = ValueError("LLM did not return a JSON list.")
        raise failure

    async def _invoke_llm(self, prompt: str) -> List[Dict[str, Any]]:
        """Invokes the LLM and parses the JSON response.

        Returns:
            The decoded plan, or a one-step ``report_failure`` plan when the
            reply cannot be parsed. Exceptions other than ``ValueError`` /
            ``IndexError`` raised by the client propagate to the caller.
        """
        # Bound before the call so the handler below can always print it,
        # even when text_generation() itself raises.
        response = None
        try:
            response = await self.client.text_generation(prompt, max_new_tokens=1024)
            return self._parse_plan(response)
        except (json.JSONDecodeError, ValueError, IndexError) as e:
            print(f"Error parsing LLM response: {e}\nRaw response:\n{response}")
            # Fallback or re-try logic could be added here
            return [{"step": 1, "thought": "Failed to generate a plan due to a parsing error.", "tool": "report_failure", "params": {"message": f"LLM response parsing failed: {e}"}}]

    async def generate_plan(self, goal: str, available_tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Generates a step-by-step plan.
        """
        prompt = self._construct_prompt(goal, available_tools)
        return await self._invoke_llm(prompt)

    async def regenerate_plan_on_error(self, goal: str, available_tools: List[Dict[str, Any]], completed_steps: List, error_message: str) -> List[Dict[str, Any]]:
        """Generates a new plan after an error occurred.

        The prompt includes the steps executed so far and the error text.
        """
        prompt = self._construct_prompt(goal, available_tools, previous_steps=completed_steps, error=error_message)
        return await self._invoke_llm(prompt)

class ForgeApp:
    """The main orchestrator for the Forge application.

    Wires a planner (``HuggingFaceAgent``) to a ``ToolRegistry`` and executes
    the generated plan step by step, replanning when a tool reports an error.
    """
    def __init__(self, goal: str, mcp_server_urls: List[str], hf_token: str, max_replans: int = 3):
        """Builds the planner and the tool registry.

        Args:
            goal: The goal the agent should achieve.
            mcp_server_urls: URLs of the MCP servers to discover tools from.
            hf_token: Token handed to the Hugging Face planner.
            max_replans: Upper bound on how many times a failing tool may
                trigger a new plan, so a persistently failing tool cannot
                keep the loop running forever.
        """
        self.goal = goal
        self.planner = HuggingFaceAgent(hf_token=hf_token)
        self.tool_registry = ToolRegistry(server_urls=mcp_server_urls)
        self.max_replans = max_replans

    async def run(self):
        """
        Runs the agent and yields status updates as a generator.

        Yields Markdown-formatted status strings. Tool connections are closed
        in a ``finally`` block, so they are released when the run completes,
        when discovery or planning raises, and when the consumer closes the
        generator early.
        """
        try:
            yield "🚀 **Starting Forge... Initializing systems.**"
            await self.tool_registry.discover_tools()
            yield f"✅ **Tool Discovery Complete.** Found {len(self.tool_registry.tools)} tools."

            # Provide the LLM with full tool details, not just names
            available_tools_details = [{"name": name, "description": data["description"]} for name, data in self.tool_registry.tools.items()]

            yield f"🧠 **Generating a plan for your goal:** '{self.goal}'"
            plan = await self.planner.generate_plan(self.goal, available_tools_details)
            yield "📝 **Plan Generated!** Starting execution..."

            completed_steps = []
            replans = 0
            while plan:
                task = plan.pop(0)
                if not isinstance(task, dict):
                    # The plan comes from an LLM; a non-object step cannot be executed.
                    yield f"❌ **Critical Error executing step ?:** plan step is not an object: {task!r}"
                    yield "🛑 **Execution Halted due to critical error.**"
                    break

                yield f"\n**[Step {task.get('step', '?')}]** 🤔 **Thought:** {task.get('thought', 'N/A')}"

                tool_name = task.get("tool")
                # "params": null in the model output is treated like a missing key.
                params = task.get("params") or {}

                if tool_name in ["report_success", "report_failure"]:
                    emoji = "🎉" if tool_name == "report_success" else "🛑"
                    message = params.get('message', 'N/A') if isinstance(params, dict) else 'N/A'
                    yield f"{emoji} **Final Result:** {message}"
                    break  # Terminal step: end execution

                try:
                    yield f"🛠️ **Action:** Executing tool `{tool_name}` with params: `{params}`"
                    # The result is read as a mapping with "status" and "result" keys.
                    result = await self.tool_registry.execute(tool_name, params)

                    if result.get("status") == "error":
                        error_message = result.get('result', 'Unknown error')
                        yield f"❌ **Error:** {error_message}"
                        completed_steps.append({"step": task, "outcome": "error", "details": error_message})
                        if replans >= self.max_replans:
                            yield f"🛑 **Execution Halted:** replanning limit ({self.max_replans}) reached."
                            break
                        replans += 1
                        yield "🧠 **Agent is re-evaluating the plan based on the error...**"
                        plan = await self.planner.regenerate_plan_on_error(self.goal, available_tools_details, completed_steps, error_message)
                        yield "📝 **New Plan Generated!** Resuming execution..."
                    else:
                        observation = result.get('result', 'Tool executed successfully.')
                        yield f"✅ **Observation:** {observation}"
                        completed_steps.append({"step": task, "outcome": "success", "details": observation})

                except Exception as e:
                    yield f"❌ **Critical Error executing step {task.get('step', '?')}:** {e}"
                    yield "🛑 **Execution Halted due to critical error.**"
                    break  # End execution
        finally:
            # Always release the MCP connections, whatever ended the run.
            await self.tool_registry.close_all()

        yield "\n🏁 **Forge execution finished.**"