import os
import time
import random
import operator
from typing import Annotated, List, TypedDict

from dotenv import load_dotenv

# Load provider API keys (GROQ_API_KEY, GOOGLE_API_KEY, NVIDIA_API_KEY,
# TAVILY_API_KEY) from a local .env file.
load_dotenv()

from langgraph.graph import StateGraph, END
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver

from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.tools import tool
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_nvidia_ai_endpoints import ChatNVIDIA

from tavily import TavilyClient

class AdvancedRateLimiter:
    """Sliding-window rate limiter: blocks until a request slot is free."""

    def __init__(self, requests_per_minute: int):
        self.requests_per_minute = requests_per_minute
        self.request_times: List[float] = []

    def wait_if_needed(self):
        current_time = time.time()

        # Keep only the requests made within the last 60 seconds.
        self.request_times = [t for t in self.request_times if current_time - t < 60]

        if len(self.request_times) >= self.requests_per_minute:
            # Sleep until the oldest request ages out of the window, plus
            # random jitter so concurrent callers don't retry in lockstep.
            wait_time = 60 - (current_time - self.request_times[0]) + random.uniform(2, 8)
            time.sleep(wait_time)
            # Record the actual post-sleep time, not the stale pre-sleep one.
            current_time = time.time()

        self.request_times.append(current_time)


# Per-provider client-side limiters.
groq_limiter = AdvancedRateLimiter(requests_per_minute=30)
gemini_limiter = AdvancedRateLimiter(requests_per_minute=2)
nvidia_limiter = AdvancedRateLimiter(requests_per_minute=5)
tavily_limiter = AdvancedRateLimiter(requests_per_minute=50)

# Token-bucket limiter shared by the NVIDIA models (~0.083 req/s, about 5 req/min).
nvidia_rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.083,
    check_every_n_seconds=0.1,
    max_bucket_size=5
)

# Groq Llama 3.3 70B (currently not wired into any graph node below).
groq_llm = ChatGroq(
    model="llama-3.3-70b-versatile",
    api_key=os.getenv("GROQ_API_KEY"),
    temperature=0
)

# Gemini backs the research agent.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-thinking-exp",
    api_key=os.getenv("GOOGLE_API_KEY"),
    temperature=0
)

# NVIDIA-hosted models: general coordination, code generation, and math.
nvidia_general_llm = ChatNVIDIA(
    model="meta/llama3-70b-instruct",
    api_key=os.getenv("NVIDIA_API_KEY"),
    temperature=0,
    max_tokens=4000,
    rate_limiter=nvidia_rate_limiter
)

nvidia_code_llm = ChatNVIDIA(
    model="meta/codellama-70b",
    api_key=os.getenv("NVIDIA_API_KEY"),
    temperature=0,
    max_tokens=4000,
    rate_limiter=nvidia_rate_limiter
)

nvidia_math_llm = ChatNVIDIA(
    model="mistralai/mixtral-8x22b-instruct-v0.1",
    api_key=os.getenv("NVIDIA_API_KEY"),
    temperature=0,
    max_tokens=4000,
    rate_limiter=nvidia_rate_limiter
)

tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))

class AgentState(TypedDict):
    """Shared state passed between graph nodes."""
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
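
# Note: the operator.add annotation makes "messages" an append-only channel;
# when a node returns {"messages": [msg]}, LangGraph concatenates it onto the
# existing history rather than overwriting it.
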
@tool
def multiply_tool(a: float, b: float) -> float:
    """Multiply two numbers together"""
    return a * b


@tool
def add_tool(a: float, b: float) -> float:
    """Add two numbers together"""
    return a + b


@tool
def subtract_tool(a: float, b: float) -> float:
    """Subtract two numbers"""
    return a - b


@tool
def divide_tool(a: float, b: float) -> float:
    """Divide two numbers"""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def tavily_search_tool(query: str) -> str:
    """Search the web using Tavily for current information"""
    try:
        tavily_limiter.wait_if_needed()
        response = tavily_client.search(
            query=query,
            max_results=3,
            search_depth="basic",
            include_answer=False
        )

        # Flatten the results into a plain-text block for the model.
        results = []
        for result in response.get('results', []):
            results.append(f"Title: {result.get('title', '')}\nContent: {result.get('content', '')}")

        return "\n\n---\n\n".join(results)
    except Exception as e:
        return f"Tavily search failed: {str(e)}"


@tool
def wiki_search_tool(query: str) -> str:
    """Search Wikipedia for encyclopedic information"""
    try:
        # Small random delay to be polite to the Wikipedia API.
        time.sleep(random.uniform(1, 3))
        # Imported lazily so langchain_community is only needed if this tool runs.
        from langchain_community.document_loaders import WikipediaLoader
        loader = WikipediaLoader(query=query, load_max_docs=1)
        data = loader.load()
        # Truncate each document to keep the model's context small.
        return "\n\n---\n\n".join([doc.page_content[:1000] for doc in data])
    except Exception as e:
        return f"Wikipedia search failed: {str(e)}"


# Tool sets assigned to each specialist agent.
math_tools = [multiply_tool, add_tool, subtract_tool, divide_tool]
research_tools = [tavily_search_tool, wiki_search_tool]
coordinator_tools = [tavily_search_tool, wiki_search_tool]

def router_node(state: AgentState) -> dict:
    """Route queries to the appropriate agent type via keyword matching."""
    query = state["query"].lower()

    if any(word in query for word in ['calculate', 'math', 'multiply', 'add', 'subtract', 'divide', 'compute']):
        agent_type = "math"
    elif any(word in query for word in ['code', 'program', 'python', 'javascript', 'function', 'algorithm']):
        agent_type = "code"
    elif any(word in query for word in ['search', 'find', 'research', 'what is', 'who is', 'when', 'where']):
        agent_type = "research"
    else:
        agent_type = "coordinator"

    # Return only the updated key: echoing "messages" back through **state
    # would make the operator.add reducer duplicate the conversation history.
    return {"agent_type": agent_type}

def math_agent_node(state: AgentState) -> dict:
    """Mathematical specialist agent using NVIDIA Mixtral"""
    nvidia_limiter.wait_if_needed()

    system_message = SystemMessage(content="""You are a mathematical specialist with access to calculation tools.
    Use the appropriate math tools for calculations.
    Show your work step by step.
    Always provide precise numerical answers.
    Finish with: FINAL ANSWER: [numerical result]""")

    math_agent = create_react_agent(nvidia_math_llm, math_tools)

    messages = [system_message, HumanMessage(content=state["query"])]
    config = {"configurable": {"thread_id": "math_thread"}}

    try:
        result = math_agent.invoke({"messages": messages}, config)
        final_message = result["messages"][-1].content

        # Return only the new message; the operator.add reducer appends it
        # (returning state["messages"] + [...] would duplicate the history).
        return {
            "messages": [AIMessage(content=final_message)],
            "final_answer": final_message
        }
    except Exception as e:
        error_msg = f"Math agent error: {str(e)}"
        return {
            "messages": [AIMessage(content=error_msg)],
            "final_answer": error_msg
        }

def code_agent_node(state: AgentState) -> dict:
    """Code generation specialist agent using NVIDIA CodeLlama"""
    nvidia_limiter.wait_if_needed()

    system_message = SystemMessage(content="""You are an expert coding AI specialist.
    Generate clean, efficient, and well-documented code.
    Explain your code solutions clearly.
    Always provide working code examples.
    Finish with: FINAL ANSWER: [your code solution]""")

    # No tools: the code model answers directly from the prompt.
    code_agent = create_react_agent(nvidia_code_llm, [])

    messages = [system_message, HumanMessage(content=state["query"])]
    config = {"configurable": {"thread_id": "code_thread"}}

    try:
        result = code_agent.invoke({"messages": messages}, config)
        final_message = result["messages"][-1].content

        # Return only the delta; the operator.add reducer appends it.
        return {
            "messages": [AIMessage(content=final_message)],
            "final_answer": final_message
        }
    except Exception as e:
        error_msg = f"Code agent error: {str(e)}"
        return {
            "messages": [AIMessage(content=error_msg)],
            "final_answer": error_msg
        }

def research_agent_node(state: AgentState) -> dict:
    """Research specialist agent using Gemini"""
    gemini_limiter.wait_if_needed()

    system_message = SystemMessage(content="""You are a research specialist with access to web search and Wikipedia.
    Use appropriate search tools to gather comprehensive information.
    Always cite sources and provide well-researched answers.
    Synthesize information from multiple sources when possible.
    Finish with: FINAL ANSWER: [your researched answer]""")

    research_agent = create_react_agent(gemini_llm, research_tools)

    messages = [system_message, HumanMessage(content=state["query"])]
    config = {"configurable": {"thread_id": "research_thread"}}

    try:
        result = research_agent.invoke({"messages": messages}, config)
        final_message = result["messages"][-1].content

        # Return only the delta; the operator.add reducer appends it.
        return {
            "messages": [AIMessage(content=final_message)],
            "final_answer": final_message
        }
    except Exception as e:
        error_msg = f"Research agent error: {str(e)}"
        return {
            "messages": [AIMessage(content=error_msg)],
            "final_answer": error_msg
        }

def coordinator_agent_node(state: AgentState) -> dict:
    """Coordinator agent using NVIDIA Llama3"""
    nvidia_limiter.wait_if_needed()

    system_message = SystemMessage(content="""You are the main coordinator agent.
    Analyze queries and provide comprehensive responses.
    Use search tools for factual information when needed.
    Always finish with: FINAL ANSWER: [your final answer]""")

    coordinator_agent = create_react_agent(nvidia_general_llm, coordinator_tools)

    messages = [system_message, HumanMessage(content=state["query"])]
    config = {"configurable": {"thread_id": "coordinator_thread"}}

    try:
        result = coordinator_agent.invoke({"messages": messages}, config)
        final_message = result["messages"][-1].content

        # Return only the delta; the operator.add reducer appends it.
        return {
            "messages": [AIMessage(content=final_message)],
            "final_answer": final_message
        }
    except Exception as e:
        error_msg = f"Coordinator agent error: {str(e)}"
        return {
            "messages": [AIMessage(content=error_msg)],
            "final_answer": error_msg
        }

def route_agent(state: AgentState) -> str:
    """Route to appropriate agent based on agent_type"""
    agent_type = state.get("agent_type", "coordinator")

    if agent_type == "math":
        return "math_agent"
    elif agent_type == "code":
        return "code_agent"
    elif agent_type == "research":
        return "research_agent"
    else:
        return "coordinator_agent"

class LangGraphMultiAgentSystem:
    def __init__(self):
        self.request_count = 0
        self.last_request_time = time.time()
        self.graph = self._create_graph()

    def _create_graph(self):
        """Create and compile the LangGraph workflow."""
        workflow = StateGraph(AgentState)

        # Register the router plus one node per specialist agent.
        workflow.add_node("router", router_node)
        workflow.add_node("math_agent", math_agent_node)
        workflow.add_node("code_agent", code_agent_node)
        workflow.add_node("research_agent", research_agent_node)
        workflow.add_node("coordinator_agent", coordinator_agent_node)

        # The router classifies the query, then route_agent picks the next node.
        workflow.set_entry_point("router")
        workflow.add_conditional_edges(
            "router",
            route_agent,
            {
                "math_agent": "math_agent",
                "code_agent": "code_agent",
                "research_agent": "research_agent",
                "coordinator_agent": "coordinator_agent"
            }
        )

        # Every specialist ends the run.
        workflow.add_edge("math_agent", END)
        workflow.add_edge("code_agent", END)
        workflow.add_edge("research_agent", END)
        workflow.add_edge("coordinator_agent", END)

        memory = MemorySaver()
        return workflow.compile(checkpointer=memory)

    def process_query(self, query: str) -> str:
        """Process query using LangGraph multi-agent system"""
        # Reset the request counter once an hour.
        current_time = time.time()
        if current_time - self.last_request_time > 3600:
            self.request_count = 0
            self.last_request_time = current_time

        self.request_count += 1

        # Space out consecutive requests to stay under free-tier quotas.
        if self.request_count > 1:
            time.sleep(random.uniform(3, 10))

        initial_state = {
            "messages": [HumanMessage(content=query)],
            "query": query,
            "agent_type": "",
            "final_answer": ""
        }

        # Each query gets its own checkpoint thread.
        config = {"configurable": {"thread_id": f"thread_{self.request_count}"}}

        try:
            final_state = self.graph.invoke(initial_state, config)
            return final_state.get("final_answer", "No response generated")
        except Exception as e:
            return f"Error: {str(e)}"

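# Sketch (assuming multi-turn memory is wanted; not exercised by main() below):
# because the graph is compiled with MemorySaver, re-invoking it with the same
# thread_id restores the checkpointed state from the prior turn, e.g.:
#
#   system = LangGraphMultiAgentSystem()
#   cfg = {"configurable": {"thread_id": "conversation-1"}}
#   system.graph.invoke(first_turn_state, cfg)
#   system.graph.invoke(follow_up_state, cfg)  # resumes from the checkpoint
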
def main(query: str) -> str:
    """Main function using LangGraph multi-agent system"""
    langgraph_system = LangGraphMultiAgentSystem()
    return langgraph_system.process_query(query)


def get_final_answer(query: str) -> str:
    """Extract only the FINAL ANSWER from the response"""
    full_response = main(query)

    if "FINAL ANSWER:" in full_response:
        return full_response.split("FINAL ANSWER:")[-1].strip()
    else:
        return full_response.strip()


if __name__ == "__main__":
    result = get_final_answer("What are the names of the US presidents who were assassinated?")
    print(result)
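
    # Hypothetical extra queries (commented out to avoid spending API quota)
    # that would exercise the other routing branches:
    #   print(get_final_answer("Calculate 1234 multiplied by 5678"))               # math agent
    #   print(get_final_answer("Write a python function that reverses a string"))  # code agent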