import os
import time
import random
import operator
from typing import List, Dict, Any, TypedDict, Annotated

from dotenv import load_dotenv

from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq

load_dotenv()
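
# Expected environment variables (loaded from .env): GROQ_API_KEY for the
# Groq-hosted models, and TAVILY_API_KEY, which TavilySearchResults reads
# for web search.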


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide a by b, rejecting division by zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of dividing a by b."""
    return a % b


@tool
def optimized_web_search(query: str) -> str:
    """Search the web via Tavily and return the top results."""
    try:
        # Jittered delay to stay under the rate limit.
        time.sleep(random.uniform(0.7, 1.5))
        docs = TavilySearchResults(max_results=2).invoke(query)
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url', '')}'>{d.get('content', '')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"


@tool
def optimized_wiki_search(query: str) -> str:
    """Search Wikipedia and return the top article extract."""
    try:
        # Jittered delay to stay polite to the Wikipedia API.
        time.sleep(random.uniform(0.3, 1))
        docs = WikipediaLoader(query=query, load_max_docs=1).load()
        return "\n\n---\n\n".join(
            f"<Doc src='{d.metadata.get('source', 'Wikipedia')}'>{d.page_content[:800]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Wikipedia search failed: {e}"


class EnhancedAgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str
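
# The operator.add reducer on `messages` tells LangGraph to concatenate the
# message lists returned by successive nodes rather than overwrite them; the
# remaining keys are plain last-write-wins fields.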


class HybridLangGraphMultiLLMSystem:
    """
    Router that picks between Groq-hosted Llama-3 8B, Llama-3 70B (the
    default), and a Groq-hosted DeepSeek model according to the query
    content.
    """

    def __init__(self):
        # Collected for reference; the graph below never binds these to the
        # LLMs, so no tool calling happens yet (see the sketch after
        # build_graph()).
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search,
        ]
        self.graph = self._build_graph()

    def _llm(self, model_name: str):
        """Build a deterministic (temperature=0) ChatGroq client."""
        return ChatGroq(
            model=model_name,
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY"),
        )

    def _build_graph(self):
        llama8_llm = self._llm("llama3-8b-8192")
        llama70_llm = self._llm("llama3-70b-8192")
        # Groq does not serve "deepseek-chat" (that id belongs to DeepSeek's
        # own API); its DeepSeek offering is the R1 distill of Llama 70B.
        deepseek_llm = self._llm("deepseek-r1-distill-llama-70b")

        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            """Pick a model from keywords in the query; 70B is the default."""
            q = st["query"].lower()
            if "llama-8" in q:
                t = "llama8"
            elif "deepseek" in q:
                t = "deepseek"
            else:
                t = "llama70"
            return {**st, "agent_type": t}

        def make_node(llm, provider: str):
            """Build a node that answers st["query"] with the given LLM."""
            def node(st: EnhancedAgentState) -> EnhancedAgentState:
                t0 = time.time()
                sys_msg = SystemMessage(content="You are a helpful AI assistant.")
                res = llm.invoke([sys_msg, HumanMessage(content=st["query"])])
                return {
                    **st,
                    "final_answer": res.content,
                    "perf": {"time": time.time() - t0, "prov": provider},
                }
            return node

        # The three per-model nodes differ only in LLM and provenance label.
        llama8_node = make_node(llama8_llm, "Groq-Llama3-8B")
        llama70_node = make_node(llama70_llm, "Groq-Llama3-70B")
        deepseek_node = make_node(deepseek_llm, "Groq-DeepSeek")

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("llama8", llama8_node)
        g.add_node("llama70", llama70_node)
        g.add_node("deepseek", deepseek_node)
        g.set_entry_point("router")
        # The "agent_type" set by the router selects which model node runs.
        g.add_conditional_edges(
            "router",
            lambda s: s["agent_type"],
            {"llama8": "llama8", "llama70": "llama70", "deepseek": "deepseek"},
        )
        g.add_edge("llama8", END)
        g.add_edge("llama70", END)
        g.add_edge("deepseek", END)
        # MemorySaver checkpoints state per thread_id across invocations.
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        """Run a query through the graph and return the final answer."""
        state: EnhancedAgentState = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": "",
        }
        # hash() is salted per process, so these thread ids are only stable
        # within a single run.
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        out = self.graph.invoke(state, cfg)
        return out.get("final_answer", "").strip()


def build_graph(provider: str | None = None):
    """Compatibility wrapper; the provider argument is currently ignored."""
    return HybridLangGraphMultiLLMSystem().graph
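

# A minimal sketch of how the collected tools could be bound for tool
# calling; this helper is illustrative only and is not wired into the graph.
def example_tool_calling(query: str) -> str:
    """Ask a Groq model to answer with the math tools bound."""
    llm = ChatGroq(
        model="llama3-70b-8192",
        temperature=0,
        api_key=os.getenv("GROQ_API_KEY"),
    )
    llm_with_tools = llm.bind_tools([multiply, add, subtract, divide, modulus])
    res = llm_with_tools.invoke([HumanMessage(content=query)])
    # If the model chose a tool, the call spec (name + args) is on the
    # AIMessage; otherwise fall back to its text content.
    return str(res.tool_calls) if res.tool_calls else res.content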


if __name__ == "__main__":
    qa_system = HybridLangGraphMultiLLMSystem()

    print(qa_system.process_query("llama-8: What is the capital of France?"))
    print(qa_system.process_query("llama-70: Tell me about quantum mechanics."))
    print(qa_system.process_query("deepseek: What is the Riemann Hypothesis?"))