import os
import time
import random
import operator
from typing import List, Dict, Any, TypedDict, Annotated
from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq
# Populate os.environ from a local .env file at import time, before any
# ChatGroq client is built (the clients read GROQ_API_KEY via os.getenv).
# NOTE(review): the Tavily search tool presumably also needs its API key in
# the environment (e.g. TAVILY_API_KEY) — confirm it is set in .env as well.
load_dotenv() # expects GROQ_API_KEY in your .env
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers a and b and return the product a * b."""
    # @tool turns this docstring into the description shown to the LLM, so it
    # names the integer types the generated argument schema actually accepts.
    return a * b
@tool
def add(a: int, b: int) -> int:
    """Add two integers a and b and return the sum a + b."""
    # @tool turns this docstring into the description shown to the LLM, so it
    # names the integer types the generated argument schema actually accepts.
    return a + b
@tool
def subtract(a: int, b: int) -> int:
    """Subtract integer b from integer a and return the difference a - b."""
    # The description spells out the operand order: subtraction is not
    # commutative, and the LLM only sees this text when filling in a and b.
    return a - b
@tool
def divide(a: int, b: int) -> float:
    """Divide integer a by integer b and return the quotient a / b as a float. b must not be 0."""
    # The description spells out the operand order and the zero-divisor
    # precondition, since the LLM only sees this text when choosing arguments.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
@tool
def modulus(a: int, b: int) -> int:
    """Return a % b, the remainder of dividing integer a by integer b. b must not be 0."""
    if b == 0:
        # Same exception type a bare ``a % 0`` raises, so existing handlers
        # still match, but with an explicit message mirroring ``divide``.
        raise ZeroDivisionError("Cannot take modulus by zero.")
    return a % b
@tool
def optimized_web_search(query: str) -> str:
    '''searches the web using tavily'''
    # Returns up to two result snippets (each truncated to 500 chars) joined by
    # a "---" separator, or a "Web search failed: ..." message on any error.
    try:
        # Randomized pause before each request — presumably to soften API
        # rate limits when the agent fires several searches in a row.
        time.sleep(random.uniform(0.7, 1.5))
        # BaseTool.invoke takes the tool input as its first positional
        # argument (a dict keyed by the tool's argument names).
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
        # The Tavily tool may report API errors by returning a string instead
        # of a list of result dicts; don't iterate that character by character.
        if isinstance(docs, str):
            return f"Web search failed: {docs}"
        return "\n\n---\n\n".join(
            (d.get("content") or "")[:500]
            for d in docs
        )
    except Exception as e:
        # Best-effort tool: surface the failure to the LLM as text rather
        # than aborting the whole agent run.
        return f"Web search failed: {e}"
@tool
def optimized_wiki_search(query: str) -> str:
    """searches wikipedia"""
    # Fetches at most one Wikipedia page for ``query`` and returns the first
    # 800 characters of its text; any failure becomes a readable message.
    try:
        # Short randomized pause before the lookup (presumably to spread out
        # bursts of tool calls — not stated in the original).
        time.sleep(random.uniform(0.3, 1))
        pages = WikipediaLoader(query=query, load_max_docs=1).load()
        excerpts = [page.page_content[:800] for page in pages]
        separator = "\n\n---\n\n"
        return separator.join(excerpts)
    except Exception as exc:
        return f"Wikipedia search failed: {exc}"
class EnhancedAgentState(TypedDict):
    """State schema for the LangGraph routing graph built by HybridLangGraphMultiLLMSystem."""
    # Conversation messages. The ``operator.add`` annotation is a LangGraph
    # reducer: a list returned by a node is concatenated onto the existing
    # list instead of replacing it, so nodes should return only new messages.
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    # The raw user query; the router inspects it and the model nodes send it
    # to the LLM.
    query: str
    # Route key chosen by the router: "llama8", "llama70" or "deepseek".
    agent_type: str
    # Text content of the selected model's reply.
    final_answer: str
    # Per-run metrics written by the model node: "time" (seconds elapsed)
    # and "prov" (provider label).
    perf: Dict[str, Any]
    # Initialised to "" by process_query; nothing visible here reads or
    # writes it afterwards — NOTE(review): confirm whether it is still needed.
    agno_resp: str
class HybridLangGraphMultiLLMSystem:
    """
    Router that picks between Groq-hosted Llama-3 8B, Llama-3 70B (default),
    and Groq-hosted DeepSeek-Chat according to the query content.

    Routing is a substring test on the lower-cased query (see ``_route``):
    "llama-8" selects the 8B model, otherwise "deepseek" selects the DeepSeek
    model, and everything else goes to the 70B model.

    NOTE(review): Groq's model catalogue changes over time. The default
    "llama3-8b-8192" / "llama3-70b-8192" ids appear to have been deprecated
    by Groq, and "deepseek-chat" does not look like a Groq model id — confirm
    against the Groq models list and override via the constructor arguments.
    """

    # System prompt shared by every route.
    _SYSTEM_PROMPT = "You are a helpful AI assistant."

    def __init__(
        self,
        llama8_model: str = "llama3-8b-8192",
        llama70_model: str = "llama3-70b-8192",
        deepseek_model: str = "deepseek-chat",
    ):
        """Collect the tools and compile the routing graph.

        Args:
            llama8_model: Groq model id used for the "llama8" route.
            llama70_model: Groq model id used for the default "llama70" route.
            deepseek_model: Model id used for the "deepseek" route.
        """
        self.llama8_model = llama8_model
        self.llama70_model = llama70_model
        self.deepseek_model = deepseek_model
        # NOTE(review): these tools are never bound to any of the LLMs below,
        # so the models cannot call them; presumably kept for external code.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search,
        ]
        self.graph = self._build_graph()

    def _llm(self, model_name: str):
        """Return a temperature-0 ChatGroq client for ``model_name``."""
        return ChatGroq(
            model=model_name,
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY"),
        )

    @staticmethod
    def _route(query: str) -> str:
        """Map a query to a route key: "llama8", "deepseek" or "llama70" (default).

        Matching is case-insensitive. "llama-8" is checked first, so a query
        mentioning both markers goes to the 8B model.
        """
        q = query.lower()
        if "llama-8" in q:
            return "llama8"
        if "deepseek" in q:
            return "deepseek"
        return "llama70"

    @classmethod
    def _make_node(cls, llm, provider: str):
        """Return a graph node that answers ``state["query"]`` with ``llm``.

        The node returns only the keys it updates. Returning the whole state
        would re-emit ``messages``, and the ``operator.add`` reducer on that
        channel would then append the existing history to itself.
        """
        def node(st: EnhancedAgentState) -> Dict[str, Any]:
            t0 = time.time()
            res = llm.invoke([
                SystemMessage(content=cls._SYSTEM_PROMPT),
                HumanMessage(content=st["query"]),
            ])
            return {
                # Appended (not replaced) by the reducer.
                "messages": [res],
                "final_answer": res.content,
                "perf": {"time": time.time() - t0, "prov": provider},
            }
        return node

    def _build_graph(self):
        """Compile router -> one model node -> END, with an in-memory checkpointer."""
        def router(st: EnhancedAgentState) -> Dict[str, Any]:
            # Partial update: only the route key changes here.
            return {"agent_type": self._route(st["query"])}

        # route key -> (model id, provider label recorded in perf["prov"])
        routes = {
            "llama8": (self.llama8_model, "Groq-Llama3-8B"),
            "llama70": (self.llama70_model, "Groq-Llama3-70B"),
            "deepseek": (self.deepseek_model, "Groq-DeepSeek"),
        }

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        for name, (model_name, provider) in routes.items():
            g.add_node(name, self._make_node(self._llm(model_name), provider))
            g.add_edge(name, END)
        g.set_entry_point("router")
        g.add_conditional_edges(
            "router",
            lambda s: s["agent_type"],
            {name: name for name in routes},
        )
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        """Run ``q`` through the graph and return the stripped model answer.

        The checkpointer thread id is derived from ``hash(q)``. String hashes
        are salted per process by default, so the id is only stable within
        one run. Repeating the same query reuses that thread, and its
        ``messages`` history grows with each repeat.
        """
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": "",
        }
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        out = self.graph.invoke(state, cfg)
        # Guard against a None answer as well as a missing key.
        return (out.get("final_answer") or "").strip()
def build_graph(provider: str | None = None):
    """Return the compiled graph of a freshly constructed HybridLangGraphMultiLLMSystem.

    ``provider`` is accepted for call-site compatibility but is ignored:
    model selection happens per query inside the graph's router.
    """
    system = HybridLangGraphMultiLLMSystem()
    return system.graph
if __name__ == "__main__":
    qa_system = HybridLangGraphMultiLLMSystem()
    # Smoke test: one query per route (llama-8, default llama-70, deepseek).
    demo_queries = (
        "llama-8: What is the capital of France?",
        "llama-70: Tell me about quantum mechanics.",
        "deepseek: What is the Riemann Hypothesis?",
    )
    for demo_query in demo_queries:
        print(qa_system.process_query(demo_query))