|
"""Enhanced LangGraph + Agno Hybrid Agent System""" |
|
import os |
|
import time |
|
import random |
|
from dotenv import load_dotenv |
|
from typing import List, Dict, Any, TypedDict, Annotated |
|
import operator |
|
|
|
|
|
from langgraph.graph import START, StateGraph, MessagesState |
|
from langgraph.prebuilt import tools_condition, ToolNode |
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
|
|
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage |
|
from langchain_core.tools import tool |
|
from langchain_groq import ChatGroq |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, JSONLoader |
|
from langchain_community.vectorstores import FAISS |
|
from langchain.tools.retriever import create_retriever_tool |
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
|
|
|
from agno.agent import Agent |
|
from agno.models.groq import Groq |
|
from agno.models.google import Gemini |
|
from agno.tools.tavily import TavilyTools |
|
from agno.memory.agent import AgentMemory |
|
from agno.storage.sqlite import SqliteStorage |
|
from agno.memory.v2.db.sqlite import SqliteMemoryDb |
|
|
|
# Populate os.environ from a local .env file; variables already present in
# the process environment are left untouched (override=False).
load_dotenv(override=False)
|
|
|
|
|
class PerformanceRateLimiter:
    """Client-side sliding-window rate limiter with exponential backoff.

    The limiter keeps the timestamps of calls made in the last 60 seconds.
    It blocks the caller when another call would exceed ``rpm`` calls per
    minute. After failures have been recorded, each call is additionally
    delayed by an exponential backoff (capped at 30 s, plus random jitter).
    """

    def __init__(self, rpm: int, name: str):
        """
        Args:
            rpm: Maximum number of calls allowed in any 60-second window.
            name: Human-readable label for the provider being limited.

        Raises:
            ValueError: If ``rpm`` is not positive (the window logic would
                otherwise index into an empty list).
        """
        if rpm <= 0:
            raise ValueError(f"rpm must be positive, got {rpm}")
        self.rpm = rpm
        self.name = name
        self.times: List[float] = []  # timestamps of calls in the last 60 s
        self.failures = 0  # consecutive failures since the last success

    def wait_if_needed(self) -> None:
        """Block until a call is allowed, then record the call in the window."""
        now = time.time()
        self.times = [t for t in self.times if now - t < 60]
        if len(self.times) >= self.rpm:
            # Sleep until the oldest call falls out of the window, plus jitter.
            wait = 60 - (now - self.times[0]) + random.uniform(1, 3)
            time.sleep(wait)
        if self.failures:
            backoff = min(2 ** self.failures, 30) + random.uniform(0.5, 1.5)
            time.sleep(backoff)
        # Record when the call actually proceeds (after any sleeping), not
        # when this method was entered; a stale timestamp would let the next
        # window fill up early and under-count recent calls.
        self.times.append(time.time())

    def record_success(self) -> None:
        """Reset the failure counter so no backoff is applied to the next call."""
        self.failures = 0

    def record_failure(self) -> None:
        """Increment the failure counter, growing the backoff for the next call."""
        self.failures += 1
|
|
|
|
|
# Module-level per-provider request budgets (calls per minute), shared by
# every graph node that talks to the corresponding provider.
gemini_limiter = PerformanceRateLimiter(rpm=28, name="Gemini")
groq_limiter = PerformanceRateLimiter(rpm=28, name="Groq")
nvidia_limiter = PerformanceRateLimiter(rpm=4, name="NVIDIA")
|
|
|
|
|
def create_agno_agents():
    """Build the Agno specialist agents used by the hybrid graph.

    Both agents share a single SQLite session store and a single memory
    object, and both have tool-call display and markdown output disabled.

    Returns:
        A dict with keys ``"math"`` (Groq-backed, no tools) and
        ``"research"`` (Gemini-backed, with Tavily search).
    """
    # Settings common to every agent; passing the same objects to both agents
    # means they share one storage and one memory instance.
    shared_settings = dict(
        storage=SqliteStorage(
            table_name="agent_sessions",
            db_file="tmp/agent_sessions.db",
            auto_upgrade_schema=True,
        ),
        # NOTE(review): AgentMemory is imported from agno.memory.agent while
        # SqliteMemoryDb comes from agno.memory.v2 — confirm these two are
        # compatible with the installed agno version.
        memory=AgentMemory(
            db=SqliteMemoryDb(db_file="tmp/agent_memory.db"),
            create_user_memories=True,
            create_session_summary=True,
        ),
        show_tool_calls=False,
        markdown=False,
    )

    math_model = Groq(
        model="llama-3.3-70b-versatile",
        api_key=os.getenv("GROQ_API_KEY"),
        temperature=0,
    )
    math_agent = Agent(
        name="MathSpecialist",
        model=math_model,
        description="Expert mathematical problem solver",
        instructions=[
            "Solve math problems with precision",
            "Show step-by-step calculations",
            "Finish with: FINAL ANSWER: [result]",
        ],
        **shared_settings,
    )

    research_model = Gemini(
        model="gemini-2.0-flash-lite",
        api_key=os.getenv("GOOGLE_API_KEY"),
        temperature=0,
    )
    search_toolkit = TavilyTools(
        api_key=os.getenv("TAVILY_API_KEY"),
        search=True,
        max_tokens=6000,
        search_depth="advanced",
        format="markdown",
    )
    research_agent = Agent(
        name="ResearchSpecialist",
        model=research_model,
        description="Expert research and information gathering specialist",
        instructions=[
            "Conduct thorough research using available tools",
            "Synthesize information from multiple sources",
            "Finish with: FINAL ANSWER: [answer]",
        ],
        tools=[search_toolkit],
        **shared_settings,
    )

    return {"math": math_agent, "research": research_agent}
|
|
|
|
|
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    # The docstring above is the tool description sent to the LLM; keep it stable.
    return operator.mul(a, b)
|
|
|
@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    # The docstring above is the tool description sent to the LLM; keep it stable.
    return operator.add(a, b)
|
|
|
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a and return the difference (a - b)."""
    # The docstring is the tool description the LLM sees; stating the operand
    # order explicitly keeps the model from passing arguments reversed.
    return a - b
|
|
|
@tool
def divide(a: float, b: float) -> float:
    """Divide a by b and return the quotient (a / b). Raises ValueError if b is zero."""
    # Parameters are floats (previously int) so the model can divide
    # fractional values; integer arguments still produce the same true-division
    # result. The docstring is the tool description the LLM sees.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
|
|
|
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of a divided by b."""
    # The docstring above is the tool description sent to the LLM; keep it stable.
    return operator.mod(a, b)
|
|
|
@tool
def optimized_web_search(query: str) -> str:
    """Search the web with Tavily and return the top two results (URL and content snippet)."""
    try:
        # Small jittered pause to spread out requests to the search API.
        time.sleep(random.uniform(1, 2))
        # BaseTool.invoke takes its input positionally; the previous
        # ``invoke(query=query)`` raised TypeError on every call.
        docs = TavilySearchResults(max_results=2).invoke(query)
    except Exception as e:
        return f"Web search failed: {e}"

    if isinstance(docs, str):
        # The Tavily tool may hand back a plain string (e.g. an error repr)
        # instead of a list of result dicts.
        return f"Web search failed: {docs}"
    if not docs:
        return "No web results found."

    try:
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url','')}'>{d.get('content','')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        # Unexpected result shape (e.g. items that are not dicts).
        return f"Web search failed: {e}"
|
|
|
@tool
def optimized_wiki_search(query: str) -> str:
    """Optimized Wikipedia search."""
    # The docstring above is the tool description sent to the LLM; keep it stable.
    # Returns the first matching page (truncated to 800 chars) wrapped in a
    # <Doc> tag, or an error string if loading/formatting fails.
    try:
        time.sleep(random.uniform(0.5, 1))
        pages = WikipediaLoader(query=query, load_max_docs=1).load()
        snippets = [
            f"<Doc src='{page.metadata['source']}'>{page.page_content[:800]}</Doc>"
            for page in pages
        ]
        return "\n\n---\n\n".join(snippets)
    except Exception as e:
        return f"Wikipedia search failed: {e}"
|
|
|
|
|
def setup_faiss():
    """Build a FAISS vector store over ``metadata.jsonl``.

    Each JSON line is projected by a jq program onto its question, task id
    and final answer. The loaded documents are split into 256-character
    chunks (50-character overlap) and embedded with NVIDIA embeddings.

    Returns:
        The FAISS store, or ``None`` if any step fails (the error is printed).
    """
    jq_program = """
{ page_content: .Question, metadata: { task_id: .task_id, Final_answer: ."Final answer" } }
"""
    try:
        records = JSONLoader(
            file_path="metadata.jsonl",
            jq_schema=jq_program,
            json_lines=True,
            text_content=False,
        ).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=50)
        pieces = splitter.split_documents(records)
        embedder = NVIDIAEmbeddings(
            model="nvidia/nv-embedqa-e5-v5",
            api_key=os.getenv("NVIDIA_API_KEY"),
        )
        return FAISS.from_documents(pieces, embedder)
    except Exception as e:
        # Best-effort: the caller treats None as "no retriever available".
        print(f"FAISS setup failed: {e}")
        return None
|
|
|
class EnhancedAgentState(TypedDict):
    """State schema shared by the nodes of the hybrid routing graph."""

    # Conversation messages. operator.add is the reducer: updates written by a
    # node are concatenated onto the existing list rather than replacing it.
    messages: Annotated[List[HumanMessage|AIMessage], operator.add]
    # Raw user query; the router lower-cases it for keyword matching.
    query: str
    # Name of the node chosen by the router (e.g. "lg_math", "agno_general").
    agent_type: str
    # Answer produced by the selected specialist node.
    final_answer: str
    # Per-run metrics; specialist nodes write {"time": ..., "prov": ...}.
    perf: Dict[str,Any]
    # NOTE(review): initialised by process_query but not written by any node
    # visible in this file — confirm whether it is still needed.
    agno_resp: str
|
|
|
class HybridLangGraphAgnoSystem:
    """Route queries between LangGraph tool-calling nodes and Agno agents.

    The compiled graph is ``router -> <one specialist> -> END``. The router
    picks a specialist by keyword matching on the lower-cased query. Each
    specialist returns only the keys it changes:

    * ``final_answer`` — the answer as plain text;
    * ``perf`` — elapsed wall-clock seconds for the node (including any
      rate-limit waits) and a provider label;
    * ``messages`` — a single ``AIMessage`` holding the answer.

    Returning the whole state (``{**st, ...}``) would re-append the existing
    messages on every step, because ``messages`` has an ``operator.add``
    reducer.
    """

    # Maximum rounds of "model requests tools -> tools run -> model again"
    # inside a single LangGraph node.
    MAX_TOOL_ROUNDS = 5

    def __init__(self):
        self.agno = create_agno_agents()
        self.store = setup_faiss()
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search,
        ]
        if self.store:
            retr = self.store.as_retriever(search_type="similarity", search_kwargs={"k": 2})
            self.tools.append(create_retriever_tool(
                retriever=retr,
                name="Question_Search",
                description="Retrieve similar questions",
            ))
        self.graph = self._build_graph()

    @staticmethod
    def _call_limited(limiter, fn):
        """Run ``fn()`` under ``limiter`` and record its success or failure.

        The limiter's failure count drives its backoff, so every provider call
        reports its outcome here. Exceptions from ``fn`` are re-raised.
        """
        limiter.wait_if_needed()
        try:
            result = fn()
        except Exception:
            limiter.record_failure()
            raise
        limiter.record_success()
        return result

    @staticmethod
    def _as_text(resp) -> str:
        """Return ``resp.content`` (or ``resp`` itself) as a string; ``None`` -> ``""``."""
        content = getattr(resp, "content", resp)
        if content is None:
            return ""
        return content if isinstance(content, str) else str(content)

    @classmethod
    def _invoke_with_tools(cls, llm, tools, messages, limiter):
        """Invoke a tool-bound chat model and execute the tool calls it makes.

        Tool results are fed back as ``ToolMessage``s until the model answers
        without requesting tools, or ``MAX_TOOL_ROUNDS`` rounds have run.
        Tool exceptions are reported to the model as text, not raised.
        Returns the last model response.
        """
        by_name = {t.name: t for t in tools}
        convo = list(messages)
        res = cls._call_limited(limiter, lambda: llm.invoke(convo))
        for _ in range(cls.MAX_TOOL_ROUNDS):
            calls = getattr(res, "tool_calls", None) or []
            if not calls:
                break
            from langchain_core.messages import ToolMessage

            convo.append(res)
            for call in calls:
                target = by_name.get(call["name"])
                if target is None:
                    output = f"Unknown tool: {call['name']}"
                else:
                    try:
                        output = target.invoke(call.get("args", {}))
                    except Exception as exc:
                        # e.g. divide-by-zero: let the model see and recover.
                        output = f"Tool error: {exc}"
                convo.append(ToolMessage(content=str(output), tool_call_id=call.get("id") or ""))
            res = cls._call_limited(limiter, lambda: llm.invoke(convo))
        return res

    def _build_graph(self):
        """Compile the routing graph described in the class docstring."""
        from langgraph.graph import END

        groq_llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
        math_tools = [multiply, add, subtract, divide, modulus]
        # Bind tools once at build time instead of on every node call.
        math_llm = groq_llm.bind_tools(math_tools)
        retrieval_llm = groq_llm.bind_tools(self.tools)

        def finish(text: str, t0: float, prov: str) -> Dict[str, Any]:
            """Build the partial state update shared by all specialist nodes."""
            return {
                "final_answer": text,
                "perf": {"time": time.time() - t0, "prov": prov},
                "messages": [AIMessage(content=text)],
            }

        def router(st: EnhancedAgentState) -> Dict[str, Any]:
            query = st.get("query") or ""
            if not query:
                # Support message-only input, e.g. graph.invoke({"messages": [...]}).
                humans = [m for m in st.get("messages", []) if isinstance(m, HumanMessage)]
                query = self._as_text(humans[-1]) if humans else ""
            q = query.lower()
            if any(k in q for k in ["calculate", "math"]):
                t = "lg_math"
            elif any(k in q for k in ["research", "analyze"]):
                t = "agno_research"
            elif any(k in q for k in ["what is", "who is"]):
                t = "lg_retrieval"
            else:
                t = "agno_general"
            return {"query": query, "agent_type": t}

        def lg_math(st: EnhancedAgentState) -> Dict[str, Any]:
            t0 = time.time()
            sys_msg = SystemMessage(content="Fast calculator. FINAL ANSWER: [result]")
            res = self._invoke_with_tools(
                math_llm, math_tools,
                [sys_msg, HumanMessage(content=st["query"])], groq_limiter,
            )
            return finish(self._as_text(res), t0, "LG-Groq")

        def agno_research(st: EnhancedAgentState) -> Dict[str, Any]:
            t0 = time.time()
            resp = self._call_limited(
                gemini_limiter,
                lambda: self.agno["research"].run(st["query"], stream=False),
            )
            # Agent.run returns a response object; keep only its text content.
            return finish(self._as_text(resp), t0, "Agno-Gemini")

        def lg_retrieval(st: EnhancedAgentState) -> Dict[str, Any]:
            t0 = time.time()
            sys_msg = SystemMessage(content="Retrieve. FINAL ANSWER: [answer]")
            res = self._invoke_with_tools(
                retrieval_llm, self.tools,
                [sys_msg, HumanMessage(content=st["query"])], groq_limiter,
            )
            return finish(self._as_text(res), t0, "LG-Retrieval")

        def agno_general(st: EnhancedAgentState) -> Dict[str, Any]:
            t0 = time.time()
            # Throttle with the limiter of the provider actually being called
            # (math agent -> Groq, research agent -> Gemini).
            if any(k in st["query"].lower() for k in ["calculate", "compute"]):
                agent, limiter = self.agno["math"], groq_limiter
            else:
                agent, limiter = self.agno["research"], gemini_limiter
            resp = self._call_limited(limiter, lambda: agent.run(st["query"], stream=False))
            return finish(self._as_text(resp), t0, "Agno-General")

        def pick(st: EnhancedAgentState) -> str:
            return st["agent_type"]

        specialists = {
            "lg_math": lg_math,
            "agno_research": agno_research,
            "lg_retrieval": lg_retrieval,
            "agno_general": agno_general,
        }
        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        for name, node in specialists.items():
            g.add_node(name, node)
        g.set_entry_point("router")
        g.add_conditional_edges("router", pick, {name: name for name in specialists})
        for name in specialists:
            # END is the graph's terminal sentinel; the string "END" would be
            # an unknown node and make compile() fail.
            g.add_edge(name, END)
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> Dict[str, Any]:
        """Run ``q`` through the graph and summarise the outcome.

        Returns:
            ``{"answer", "performance_metrics", "provider_used"}``. If the
            graph raises, ``answer`` holds the error text and
            ``provider_used`` is ``"Error"``.
        """
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q, "agent_type": "", "final_answer": "",
            "perf": {}, "agno_resp": "",
        }
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        try:
            out = self.graph.invoke(state, cfg)
        except Exception as e:
            return {"answer": f"Error: {e}", "performance_metrics": {}, "provider_used": "Error"}
        perf = out.get("perf") or {}
        return {
            "answer": out.get("final_answer", ""),
            "performance_metrics": perf,
            "provider_used": perf.get("prov"),
        }
|
|
|
def build_graph(provider: str = "hybrid"):
    """Return the compiled hybrid routing graph.

    Every supported provider name ("hybrid", "groq", "google", "nvidia")
    yields the same graph from a fresh ``HybridLangGraphAgnoSystem``.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    supported = ("hybrid", "groq", "google", "nvidia")
    if provider not in supported:
        raise ValueError(f"Unsupported provider: '{provider}'. Please use 'hybrid', 'groq', 'google', or 'nvidia'.")
    return HybridLangGraphAgnoSystem().graph
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the default (hybrid) graph and run one question.
    question = "What are the names of the US presidents who were assassinated?"
    graph = build_graph()
    msgs = [HumanMessage(content=question)]
    # The router reads ``query`` from the state, so supply it alongside the
    # messages; invoking with messages alone raised KeyError in the router.
    res = graph.invoke(
        {"messages": msgs, "query": question},
        {"configurable": {"thread_id": "test"}},
    )
    for m in res["messages"]:
        m.pretty_print()
    print("FINAL:", res.get("final_answer"))
|
|