import os
import time
import random
import operator
from typing import List, Dict, Any, TypedDict, Annotated

from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq

load_dotenv()  # expects GROQ_API_KEY in your .env


@tool
def multiply(a: int, b: int) -> int:
    '''Multiplies two numbers.'''
    return a * b


@tool
def add(a: int, b: int) -> int:
    '''Adds two numbers.'''
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    '''Subtracts two numbers.'''
    return a - b


@tool
def divide(a: int, b: int) -> float:
    '''Divides two numbers.'''
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    '''Returns the remainder of dividing two numbers.'''
    return a % b


@tool
def optimized_web_search(query: str) -> str:
    '''Searches the web using Tavily.'''
    try:
        # Random delay to spread out request bursts.
        time.sleep(random.uniform(0.7, 1.5))
        # BaseTool.invoke takes the input as its first positional argument,
        # not as a keyword, so pass the query dict positionally.
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
        return "\n\n---\n\n".join(d.get("content", "")[:500] for d in docs)
    except Exception as e:
        return f"Web search failed: {e}"


@tool
def optimized_wiki_search(query: str) -> str:
    '''Searches Wikipedia.'''
    try:
        time.sleep(random.uniform(0.3, 1))
        docs = WikipediaLoader(query=query, load_max_docs=1).load()
        return "\n\n---\n\n".join(d.page_content[:800] for d in docs)
    except Exception as e:
        return f"Wikipedia search failed: {e}"


class EnhancedAgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str


class HybridLangGraphMultiLLMSystem:
    """
    Router that picks between Groq-hosted Llama-3 8B, Llama-3 70B (default),
    and a Groq-hosted DeepSeek model according to the query content.
""" def __init__(self): self.tools = [ multiply, add, subtract, divide, modulus, optimized_web_search, optimized_wiki_search ] self.graph = self._build_graph() def _llm(self, model_name: str): return ChatGroq( model=model_name, temperature=0, api_key=os.getenv("GROQ_API_KEY") ) def _build_graph(self): llama8_llm = self._llm("llama3-8b-8192") llama70_llm = self._llm("llama3-70b-8192") deepseek_llm = self._llm("deepseek-chat") def router(st: EnhancedAgentState) -> EnhancedAgentState: q = st["query"].lower() if "llama-8" in q: t = "llama8" elif "deepseek" in q: t = "deepseek" else: t = "llama70" return {**st, "agent_type": t} def llama8_node(st: EnhancedAgentState) -> EnhancedAgentState: t0 = time.time() sys = SystemMessage(content="You are a helpful AI assistant.") res = llama8_llm.invoke([sys, HumanMessage(content=st["query"])]) return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "Groq-Llama3-8B"}} def llama70_node(st: EnhancedAgentState) -> EnhancedAgentState: t0 = time.time() sys = SystemMessage(content="You are a helpful AI assistant.") res = llama70_llm.invoke([sys, HumanMessage(content=st["query"])]) return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "Groq-Llama3-70B"}} def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState: t0 = time.time() sys = SystemMessage(content="You are a helpful AI assistant.") res = deepseek_llm.invoke([sys, HumanMessage(content=st["query"])]) return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "Groq-DeepSeek"}} g = StateGraph(EnhancedAgentState) g.add_node("router", router) g.add_node("llama8", llama8_node) g.add_node("llama70", llama70_node) g.add_node("deepseek", deepseek_node) g.set_entry_point("router") g.add_conditional_edges("router", lambda s: s["agent_type"], {"llama8": "llama8", "llama70": "llama70", "deepseek": "deepseek"}) g.add_edge("llama8", END) g.add_edge("llama70", END) g.add_edge("deepseek", END) return g.compile(checkpointer=MemorySaver()) def process_query(self, q: str) -> str: state = { "messages": [HumanMessage(content=q)], "query": q, "agent_type": "", "final_answer": "", "perf": {}, "agno_resp": "" } cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}} out = self.graph.invoke(state, cfg) return out.get("final_answer", "").strip() def build_graph(provider: str | None = None): return HybridLangGraphMultiLLMSystem().graph if __name__ == "__main__": qa_system = HybridLangGraphMultiLLMSystem() # Test each model print(qa_system.process_query("llama-8: What is the capital of France?")) print(qa_system.process_query("llama-70: Tell me about quantum mechanics.")) print(qa_system.process_query("deepseek: What is the Riemann Hypothesis?"))