import os
import time
import random
from dotenv import load_dotenv
from typing import List, Dict, Any, TypedDict, Annotated
import operator
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.vectorstores import Chroma
from langchain.tools.retriever import create_retriever_tool
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
# ---- Tool Definitions ----
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    # The docstring above doubles as the tool description that @tool exposes,
    # so it is intentionally left verbatim.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    # Docstring is the @tool description; keep it verbatim.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the difference."""
    # Docstring is the @tool description; keep it verbatim.
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the quotient."""
    # Docstring is the @tool description; keep it verbatim.
    # A zero divisor is reported as ValueError with a readable message instead
    # of Python's built-in ZeroDivisionError.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of the division of the first integer by the second."""
    # Docstring is the @tool description; keep it verbatim.
    # NOTE(review): unlike `divide`, a zero divisor surfaces as Python's own
    # ZeroDivisionError here.
    remainder = a % b
    return remainder
@tool
def optimized_web_search(query: str) -> str:
    """Perform an optimized web search using TavilySearchResults and return concatenated document snippets."""
    # The docstring above is the @tool description shown to the model; keep it
    # verbatim. Returns up to two snippets (each truncated to 500 chars) joined
    # by a "---" separator, or a "Web search failed: ..." string on any error.
    try:
        # Random 1-2 s pause, presumably to spread out calls to the search API.
        time.sleep(random.uniform(1, 2))
        # BaseTool.invoke takes the tool input as its first positional argument.
        # The previous `invoke(query=query)` omitted it, raised TypeError, and
        # so every search silently "failed".
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
        if isinstance(docs, str):
            # The tool may hand back a plain string (e.g. an error message);
            # iterating it would walk characters and break on `.get`.
            return f"Web search failed: {docs}"
        # `content` may be missing or None; treat both as empty text.
        return "\n\n---\n\n".join(
            (d.get("content") or "")[:500]
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"
@tool
def optimized_wiki_search(query: str) -> str:
    """Perform an optimized Wikipedia search and return concatenated document snippets."""
    # Docstring is the @tool description; keep it verbatim. Loads at most one
    # Wikipedia document and returns the first 800 characters of its text, or a
    # "Wikipedia search failed: ..." string if anything raises.
    separator = "\n\n---\n\n"
    try:
        # Short random pause before hitting Wikipedia.
        time.sleep(random.uniform(0.5, 1))
        loader = WikipediaLoader(query=query, load_max_docs=1)
        snippets = [page.page_content[:800] for page in loader.load()]
        return separator.join(snippets)
    except Exception as e:
        return f"Wikipedia search failed: {e}"
# ---- LLM Integrations ----
# Presumably loads variables from a local .env file into the process
# environment so the os.getenv("..._API_KEY") lookups below can find them.
load_dotenv()
from langchain_groq import ChatGroq
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from google import genai
import requests
def baidu_ernie_generate(prompt, api_key=None):
    """Send *prompt* to the ERNIE endpoint and return the ``result`` text.

    Args:
        prompt: Text sent as the ``prompt`` field of the JSON body.
        api_key: Bearer token. When falsy, no Authorization header is sent
            (previously this produced the literal header "Bearer None").

    Returns:
        The response's ``result`` field ("" when absent or null). On any
        failure (network error, non-2xx HTTP status, undecodable or
        unexpected JSON) a string starting with ``"ERNIE API error:"``.
    """
    # NOTE(review): confirm this URL/model pair against Baidu's current API docs.
    url = "https://api.baidu.com/ernie/v1/generate"
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    data = {"model": "ernie-4.5", "prompt": prompt}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        # Surface 4xx/5xx as an error string instead of a silent empty answer.
        resp.raise_for_status()
        return resp.json().get("result") or ""
    except Exception as e:
        return f"ERNIE API error: {e}"
def deepseek_generate(prompt, api_key=None):
    """Send *prompt* as a single user message to DeepSeek chat completions.

    Args:
        prompt: User message content.
        api_key: Bearer token. When falsy, no Authorization header is sent
            (previously this produced the literal header "Bearer None").

    Returns:
        The first choice's message content. Returns "" when there are no
        choices or the content is missing/null, so callers always get a str.
        On any failure (network error, non-2xx HTTP status, undecodable or
        unexpected JSON) returns a string starting with ``"DeepSeek API error:"``.
    """
    url = "https://api.deepseek.com/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    data = {"model": "deepseek-chat", "messages": [{"role": "user", "content": prompt}]}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        # Surface 4xx/5xx as an error string instead of a silent empty answer.
        resp.raise_for_status()
        choices = resp.json().get("choices") or []
        if choices and isinstance(choices[0], dict) and "message" in choices[0]:
            # `content` can be null in the JSON; normalise to "" so downstream
            # string handling (e.g. .split) does not crash on None.
            return choices[0]["message"].get("content") or ""
        return ""
    except Exception as e:
        return f"DeepSeek API error: {e}"
class EnhancedAgentState(TypedDict):
    """Graph state passed between the nodes of the multi-LLM routing graph."""

    # Message history. The `operator.add` annotation is used by LangGraph as a
    # reducer, so lists returned for this key are concatenated onto the
    # existing value rather than replacing it.
    messages: Annotated[List[HumanMessage|AIMessage], operator.add]
    # The raw user query; the router lowercases it to choose a provider.
    query: str
    # Provider key chosen by the router: "groq", "nvidia", "gemini",
    # "deepseek" or "baidu".
    agent_type: str
    # Answer text written by the selected provider node.
    final_answer: str
    # Written by provider nodes as {"time": elapsed seconds, "prov": provider label}.
    perf: Dict[str,Any]
    # NOTE(review): initialised to "" by callers in this file but not read or
    # written by any visible node — confirm whether it is still needed.
    agno_resp: str
class HybridLangGraphMultiLLMSystem:
    """Answer queries by routing each one to a single LLM provider.

    The compiled LangGraph graph (``self.graph``) runs a ``router`` node that
    picks a provider from keywords in the query, then exactly one provider
    node that writes ``final_answer`` and ``perf``, then ends.
    """

    def __init__(self):
        # NOTE(review): these tools are collected but never bound to any of the
        # LLMs below, so no provider node can currently call them.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search,
        ]
        self.graph = self._build_graph()

    def _build_graph(self):
        """Build and compile the routing graph with an in-memory checkpointer.

        Nodes return only the state keys they change. ``messages`` carries an
        ``operator.add`` reducer, so echoing the whole state back (as
        ``{**st, ...}`` did) re-appended the message list on every node.
        """
        groq_llm = ChatGroq(model="llama3-70b-8192", temperature=0, api_key=os.getenv("GROQ_API_KEY"))
        nvidia_llm = ChatNVIDIA(model="meta/llama3-70b-instruct", temperature=0, api_key=os.getenv("NVIDIA_API_KEY"))
        system_prompt = "Answer as an expert."

        # (keywords, provider) checked in order; first keyword hit wins.
        routes = (
            (("groq",), "groq"),
            (("nvidia",), "nvidia"),
            (("gemini", "google"), "gemini"),
            (("deepseek",), "deepseek"),
            (("ernie", "baidu"), "baidu"),
        )

        def router(st: EnhancedAgentState) -> dict:
            """Choose a provider from keywords in the query (default: groq)."""
            q = st["query"].lower()
            chosen = next(
                (name for keywords, name in routes if any(k in q for k in keywords)),
                "groq",
            )
            return {"agent_type": chosen}

        def _result(answer, t0, provider):
            """Build the partial state update shared by every provider node."""
            return {
                # Providers may yield None (e.g. an empty Gemini response).
                "final_answer": answer if answer is not None else "",
                "perf": {"time": time.perf_counter() - t0, "prov": provider},
            }

        def _chat(llm, query):
            """Invoke a LangChain chat model with the shared system prompt."""
            return llm.invoke([SystemMessage(content=system_prompt), HumanMessage(content=query)]).content

        def groq_node(st: EnhancedAgentState) -> dict:
            t0 = time.perf_counter()
            return _result(_chat(groq_llm, st["query"]), t0, "Groq")

        def nvidia_node(st: EnhancedAgentState) -> dict:
            t0 = time.perf_counter()
            return _result(_chat(nvidia_llm, st["query"]), t0, "NVIDIA")

        def gemini_node(st: EnhancedAgentState) -> dict:
            t0 = time.perf_counter()
            # `from google import genai` is the client-based Google Gen AI SDK;
            # it has no `configure`/`GenerativeModel` (those belong to the
            # older `google.generativeai` package), so use a Client instead.
            client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
            res = client.models.generate_content(model="gemini-1.5-pro-latest", contents=st["query"])
            return _result(res.text, t0, "Gemini")

        def deepseek_node(st: EnhancedAgentState) -> dict:
            t0 = time.perf_counter()
            resp = deepseek_generate(st["query"], api_key=os.getenv("DEEPSEEK_API_KEY"))
            return _result(resp, t0, "DeepSeek")

        def baidu_node(st: EnhancedAgentState) -> dict:
            t0 = time.perf_counter()
            resp = baidu_ernie_generate(st["query"], api_key=os.getenv("BAIDU_API_KEY"))
            return _result(resp, t0, "ERNIE")

        def pick(st: EnhancedAgentState) -> str:
            """Conditional-edge selector: the node name the router stored."""
            return st["agent_type"]

        provider_nodes = {
            "groq": groq_node,
            "nvidia": nvidia_node,
            "gemini": gemini_node,
            "deepseek": deepseek_node,
            "baidu": baidu_node,
        }

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        for name, fn in provider_nodes.items():
            g.add_node(name, fn)
            g.add_edge(name, END)
        g.add_edge(START, "router")
        g.add_conditional_edges("router", pick, {name: name for name in provider_nodes})
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        """Run *q* through the graph and return the cleaned-up answer text.

        If the answer contains a blank line, everything up to the first blank
        line is dropped (a heuristic, presumably to strip a preamble) and the
        remainder is returned stripped. Otherwise the whole answer is returned
        stripped.
        """
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": "",
        }
        # Thread id is derived from the query text, so identical queries in
        # this process reuse the same checkpointer thread.
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        out = self.graph.invoke(state, cfg)
        # Guard against None/non-str answers so .split below cannot crash.
        raw_answer = str(out.get("final_answer") or "")
        # NOTE(review): this discards the first paragraph of any multi-paragraph
        # answer; confirm that is intended.
        parts = raw_answer.split("\n\n", 1)
        return parts[1].strip() if len(parts) > 1 else raw_answer.strip()
def build_graph(provider=None):
    """Return the compiled graph of a freshly constructed HybridLangGraphMultiLLMSystem.

    ``provider`` is accepted for interface compatibility but ignored; the
    provider is chosen per query by the graph's router node.
    """
    system = HybridLangGraphMultiLLMSystem()
    return system.graph