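"""Hybrid multi-LLM agent system built on LangGraph.

Each incoming query is routed by keyword to one of several providers (Groq,
NVIDIA, Gemini, DeepSeek, Baidu ERNIE); a small set of math and search tools
is also defined for agent use.
"""
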
import os
import time
import random
import operator
import requests
from typing import List, Dict, Any, TypedDict, Annotated

from dotenv import load_dotenv

from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_groq import ChatGroq
from langchain_nvidia_ai_endpoints import ChatNVIDIA
import google.generativeai as genai

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

load_dotenv()
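
# Keys expected in the environment (loaded above from .env): GROQ_API_KEY,
# NVIDIA_API_KEY, GEMINI_API_KEY, DEEPSEEK_API_KEY, BAIDU_API_KEY, and
# TAVILY_API_KEY (read implicitly by TavilySearchResults).
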

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the difference."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the quotient."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of the division of the first integer by the second."""
    if b == 0:
        raise ValueError("Cannot take the modulus by zero.")
    return a % b


@tool
def optimized_web_search(query: str) -> str:
    """Perform an optimized web search using TavilySearchResults and return concatenated document snippets."""
    try:
        # Brief random delay to avoid bursts of requests against the API.
        time.sleep(random.uniform(1, 2))
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url', '')}'>{d.get('content', '')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"


@tool
def optimized_wiki_search(query: str) -> str:
    """Perform an optimized Wikipedia search and return concatenated document snippets."""
    try:
        time.sleep(random.uniform(0.5, 1))
        docs = WikipediaLoader(query=query, load_max_docs=1).load()
        return "\n\n---\n\n".join(
            f"<Doc src='{d.metadata.get('source', '')}'>{d.page_content[:800]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Wikipedia search failed: {e}"

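
# A quick (hypothetical) direct invocation for smoke-testing either tool:
#   print(optimized_wiki_search.invoke({"query": "LangGraph"}))
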

# Providers called over raw HTTP rather than through LangChain chat wrappers.

def baidu_ernie_generate(prompt: str, api_key: str | None = None) -> str:
    """Call the Baidu ERNIE generation endpoint and return its result text."""
    url = "https://api.baidu.com/ernie/v1/generate"
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {"model": "ernie-4.5", "prompt": prompt}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        return resp.json().get("result", "")
    except Exception as e:
        return f"ERNIE API error: {e}"


def deepseek_generate(prompt: str, api_key: str | None = None) -> str:
    """Call the DeepSeek chat-completions endpoint and return the reply text."""
    url = "https://api.deepseek.com/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {"model": "deepseek-chat", "messages": [{"role": "user", "content": prompt}]}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        choices = resp.json().get("choices", [{}])
        if choices and "message" in choices[0]:
            return choices[0]["message"].get("content", "")
        return ""
    except Exception as e:
        return f"DeepSeek API error: {e}"


class EnhancedAgentState(TypedDict):
    """State carried through the graph for a single query."""
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str

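
# The Annotated reducer makes `messages` an append-only channel: a node that
# returns {"messages": [AIMessage(content="hi")]} extends the history rather
# than overwriting it.
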

class HybridLangGraphMultiLLMSystem:
    """Routes each query through a LangGraph graph to one of several LLM providers."""

    def __init__(self):
        # The tools are collected here for future binding to an LLM; the
        # current graph routes queries straight to provider nodes.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()

    def _build_graph(self):
        groq_llm = ChatGroq(model="llama3-70b-8192", temperature=0, api_key=os.getenv("GROQ_API_KEY"))
        nvidia_llm = ChatNVIDIA(model="meta/llama3-70b-instruct", temperature=0, api_key=os.getenv("NVIDIA_API_KEY"))

        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            # Keyword-based routing: a query can name its provider explicitly
            # ("use nvidia to ..."); everything else falls back to Groq.
            q = st["query"].lower()
            if "groq" in q:
                t = "groq"
            elif "nvidia" in q:
                t = "nvidia"
            elif "gemini" in q or "google" in q:
                t = "gemini"
            elif "deepseek" in q:
                t = "deepseek"
            elif "ernie" in q or "baidu" in q:
                t = "baidu"
            else:
                t = "groq"
            return {**st, "agent_type": t}

        # Each provider node times its call and writes the reply into
        # final_answer along with simple performance metadata.
        def groq_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            sys = SystemMessage(content="Answer as an expert.")
            res = groq_llm.invoke([sys, HumanMessage(content=st["query"])])
            return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "Groq"}}

        def nvidia_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            sys = SystemMessage(content="Answer as an expert.")
            res = nvidia_llm.invoke([sys, HumanMessage(content=st["query"])])
            return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "NVIDIA"}}

        def gemini_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
            model = genai.GenerativeModel("gemini-1.5-pro-latest")
            res = model.generate_content(st["query"])
            return {**st, "final_answer": res.text, "perf": {"time": time.time() - t0, "prov": "Gemini"}}

        def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            resp = deepseek_generate(st["query"], api_key=os.getenv("DEEPSEEK_API_KEY"))
            return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "DeepSeek"}}

        def baidu_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            resp = baidu_ernie_generate(st["query"], api_key=os.getenv("BAIDU_API_KEY"))
            return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "ERNIE"}}

        def pick(st: EnhancedAgentState) -> str:
            # Conditional-edge selector: the key the router stored in state.
            return st["agent_type"]

        # Wire the graph: router first, a conditional edge to exactly one
        # provider node, and each provider node terminates the run.
        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("groq", groq_node)
        g.add_node("nvidia", nvidia_node)
        g.add_node("gemini", gemini_node)
        g.add_node("deepseek", deepseek_node)
        g.add_node("baidu", baidu_node)
        g.set_entry_point("router")
        g.add_conditional_edges("router", pick, {
            "groq": "groq",
            "nvidia": "nvidia",
            "gemini": "gemini",
            "deepseek": "deepseek",
            "baidu": "baidu"
        })
        for n in ["groq", "nvidia", "gemini", "deepseek", "baidu"]:
            g.add_edge(n, END)
        # MemorySaver keeps per-thread checkpoints in memory; process_query
        # gives each query its own thread_id.
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        """Run a single query through the graph and return the cleaned answer."""
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": ""
        }
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        out = self.graph.invoke(state, cfg)
        raw_answer = out["final_answer"]
        # Drop a leading preamble paragraph (e.g. "Here is the answer:") if
        # the model produced one; otherwise return the whole response.
        parts = raw_answer.split('\n\n', 1)
        answer_part = parts[1].strip() if len(parts) > 1 else raw_answer.strip()
        return answer_part


def build_graph(provider=None):
    # `provider` is accepted for interface compatibility but is currently
    # unused; routing happens per-query inside the graph.
    return HybridLangGraphMultiLLMSystem().graph
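

# A minimal smoke test, assuming the relevant API keys are present in .env.
# The queries are illustrative; naming a provider (e.g. "nvidia") steers the
# router to that node, anything else defaults to Groq.
if __name__ == "__main__":
    system = HybridLangGraphMultiLLMSystem()
    for q in [
        "What is 15 multiplied by 7?",
        "Use nvidia to explain what LangGraph is in one sentence.",
    ]:
        print(f"Q: {q}")
        print(f"A: {system.process_query(q)}\n")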