import os
import time
import random
import operator
from typing import List, Dict, Any, TypedDict, Annotated

import requests
from dotenv import load_dotenv

from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

# LLM provider integrations
from langchain_groq import ChatGroq                    # Groq (Llama 3, DeepSeek, etc., via LangChain)
from langchain_nvidia_ai_endpoints import ChatNVIDIA  # NVIDIA NIM (LangChain integration)
import google.generativeai as genai                   # Google Gemini

load_dotenv()

# ---- Tool Definitions ----

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the difference."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the quotient."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of the division of the first integer by the second."""
    return a % b

@tool
def optimized_web_search(query: str) -> str:
    """Perform an optimized web search using TavilySearchResults and return concatenated document snippets."""
    try:
        time.sleep(random.uniform(1, 2))  # light jitter to avoid rate limits
        docs = TavilySearchResults(max_results=2).invoke(query)  # invoke takes the query positionally
        return "\n\n---\n\n".join(f"{d.get('content', '')[:500]}" for d in docs)
    except Exception as e:
        return f"Web search failed: {e}"

@tool
def optimized_wiki_search(query: str) -> str:
    """Perform an optimized Wikipedia search and return concatenated document snippets."""
    try:
        time.sleep(random.uniform(0.5, 1))  # light jitter to avoid rate limits
        docs = WikipediaLoader(query=query, load_max_docs=1).load()
        return "\n\n---\n\n".join(f"{d.page_content[:800]}" for d in docs)
    except Exception as e:
        return f"Wikipedia search failed: {e}"
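# The math and search tools above are plain LangChain tools, so they can be
# smoke-tested before any agent wiring. A minimal sketch (assumes only the
# standard tool `.invoke` API, which takes a dict of the tool's arguments):
#
#   >>> multiply.invoke({"a": 6, "b": 7})
#   42
#   >>> divide.invoke({"a": 1, "b": 0})  # raises ValueError("Cannot divide by zero.")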
# ---- LLM Integrations ----

# Baidu ERNIE (assumed open API; `requests` used as a placeholder client)
def baidu_ernie_generate(prompt, api_key=None):
    """Call the Baidu ERNIE API (pseudo-code; replace with the actual endpoint and params)."""
    # Example endpoint and payload for demonstration purposes only:
    url = "https://api.baidu.com/ernie/v1/generate"
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {"model": "ernie-4.5", "prompt": prompt}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        return resp.json().get("result", "")
    except Exception as e:
        return f"ERNIE API error: {e}"

# DeepSeek (via Ollama or the hosted API)
def deepseek_generate(prompt, api_key=None):
    """Call the DeepSeek chat-completions API (pseudo-code; replace with the actual endpoint and params)."""
    url = "https://api.deepseek.com/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {"model": "deepseek-chat", "messages": [{"role": "user", "content": prompt}]}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        return resp.json().get("choices", [{}])[0].get("message", {}).get("content", "")
    except Exception as e:
        return f"DeepSeek API error: {e}"

# ---- Graph State and System ----

class EnhancedAgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str
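# Note on the reducer: Annotated[..., operator.add] tells LangGraph to merge
# each node's update to `messages` by list concatenation rather than
# replacement, so updates append to the running history:
#
#   update {"messages": [m1]}, then {"messages": [m2]}  ->  state["messages"] == [m1, m2]
#
# The remaining keys use the default last-write-wins merge. (Because the nodes
# below return {**st, ...}, the unchanged messages list is re-sent and thus
# re-appended on every hop; returning only the changed keys would avoid that.)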
class HybridLangGraphMultiLLMSystem:
    def __init__(self):
        # Tool belt; see the binding sketch at the bottom of this file for how
        # these could be attached to a tool-calling model.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()

    def _build_graph(self):
        groq_llm = ChatGroq(model="llama3-70b-8192", temperature=0,
                            api_key=os.getenv("GROQ_API_KEY"))
        nvidia_llm = ChatNVIDIA(model="meta/llama3-70b-instruct", temperature=0,
                                api_key=os.getenv("NVIDIA_API_KEY"))

        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            """Route the query to a provider based on keywords in the query; default to Groq."""
            q = st["query"].lower()
            if "groq" in q:
                t = "groq"
            elif "nvidia" in q:
                t = "nvidia"
            elif "gemini" in q or "google" in q:
                t = "gemini"
            elif "deepseek" in q:
                t = "deepseek"
            elif "ernie" in q or "baidu" in q:
                t = "baidu"
            else:
                t = "groq"  # default
            return {**st, "agent_type": t}

        def groq_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            sys = SystemMessage(content="Answer as an expert.")
            res = groq_llm.invoke([sys, HumanMessage(content=st["query"])])
            return {**st, "final_answer": res.content,
                    "perf": {"time": time.time() - t0, "prov": "Groq"}}

        def nvidia_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            sys = SystemMessage(content="Answer as an expert.")
            res = nvidia_llm.invoke([sys, HumanMessage(content=st["query"])])
            return {**st, "final_answer": res.content,
                    "perf": {"time": time.time() - t0, "prov": "NVIDIA"}}

        def gemini_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
            model = genai.GenerativeModel("gemini-1.5-pro-latest")
            res = model.generate_content(st["query"])
            return {**st, "final_answer": res.text,
                    "perf": {"time": time.time() - t0, "prov": "Gemini"}}

        def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            resp = deepseek_generate(st["query"], api_key=os.getenv("DEEPSEEK_API_KEY"))
            return {**st, "final_answer": resp,
                    "perf": {"time": time.time() - t0, "prov": "DeepSeek"}}

        def baidu_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            resp = baidu_ernie_generate(st["query"], api_key=os.getenv("BAIDU_API_KEY"))
            return {**st, "final_answer": resp,
                    "perf": {"time": time.time() - t0, "prov": "ERNIE"}}

        def pick(st: EnhancedAgentState) -> str:
            return st["agent_type"]

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("groq", groq_node)
        g.add_node("nvidia", nvidia_node)
        g.add_node("gemini", gemini_node)
        g.add_node("deepseek", deepseek_node)
        g.add_node("baidu", baidu_node)
        g.set_entry_point("router")
        g.add_conditional_edges("router", pick, {
            "groq": "groq",
            "nvidia": "nvidia",
            "gemini": "gemini",
            "deepseek": "deepseek",
            "baidu": "baidu"
        })
        for n in ["groq", "nvidia", "gemini", "deepseek", "baidu"]:
            g.add_edge(n, END)
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": ""
        }
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        out = self.graph.invoke(state, cfg)
        raw_answer = out["final_answer"]
        # If the model prefixed a preamble paragraph, keep only what follows it.
        parts = raw_answer.split("\n\n", 1)
        return parts[1].strip() if len(parts) > 1 else raw_answer.strip()


if __name__ == "__main__":
    query = "What are the names of the US presidents who were assassinated? (groq)"
    print("LangGraph Hybrid:", HybridLangGraphMultiLLMSystem().process_query(query))
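# Note: self.tools is assembled but never bound to any model above. A minimal
# sketch of wiring it in (assumes the standard LangChain `bind_tools` API and
# a Groq model that supports tool calling):
#
#   groq_with_tools = groq_llm.bind_tools(self.tools)
#   res = groq_with_tools.invoke([sys, HumanMessage(content=st["query"])])
#   # Any res.tool_calls would then need to be executed, e.g. with
#   # langgraph.prebuilt.ToolNode, before producing the final answer.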