import os
import time
import random
from dotenv import load_dotenv
from typing import List, Dict, Any, TypedDict, Annotated
import operator
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.vectorstores import Chroma
from langchain.tools.retriever import create_retriever_tool
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
# ---- Tool Definitions ----
@tool
def multiply(a: int, b: int) -> int:
"""Multiply two integers and return the product."""
return a * b
@tool
def add(a: int, b: int) -> int:
"""Add two integers and return the sum."""
return a + b
@tool
def subtract(a: int, b: int) -> int:
"""Subtract the second integer from the first and return the difference."""
return a - b
@tool
def divide(a: int, b: int) -> float:
"""Divide the first integer by the second and return the quotient."""
if b == 0:
raise ValueError("Cannot divide by zero.")
return a / b
@tool
def modulus(a: int, b: int) -> int:
"""Return the remainder of the division of the first integer by the second."""
return a % b
@tool
def optimized_web_search(query: str) -> str:
"""Perform an optimized web search using TavilySearchResults and return concatenated document snippets."""
try:
        # Small randomized delay to stay under the search API's rate limits.
        time.sleep(random.uniform(1, 2))
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
return "\n\n---\n\n".join(
f"<Doc url='{d.get('url','')}'>{d.get('content','')[:500]}</Doc>"
for d in docs
)
except Exception as e:
return f"Web search failed: {e}"
@tool
def optimized_wiki_search(query: str) -> str:
"""Perform an optimized Wikipedia search and return concatenated document snippets."""
try:
time.sleep(random.uniform(0.5, 1))
docs = WikipediaLoader(query=query, load_max_docs=1).load()
return "\n\n---\n\n".join(
f"<Doc src='{d.metadata['source']}'>{d.page_content[:800]}</Doc>"
for d in docs
)
except Exception as e:
return f"Wikipedia search failed: {e}"
# ---- LLM Integrations ----
load_dotenv()
from langchain_groq import ChatGroq
from langchain_nvidia_ai_endpoints import ChatNVIDIA
import google.generativeai as genai  # matches the genai.configure(...) API used below
import requests
def baidu_ernie_generate(prompt, api_key=None):
    """Call the Baidu ERNIE generation endpoint and return the result text."""
    url = "https://api.baidu.com/ernie/v1/generate"
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {"model": "ernie-4.5", "prompt": prompt}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        resp.raise_for_status()
        return resp.json().get("result", "")
    except Exception as e:
        return f"ERNIE API error: {e}"
def deepseek_generate(prompt, api_key=None):
    """Call the DeepSeek chat-completions endpoint and return the reply text."""
    url = "https://api.deepseek.com/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {"model": "deepseek-chat", "messages": [{"role": "user", "content": prompt}]}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        resp.raise_for_status()
        choices = resp.json().get("choices", [{}])
        if choices and "message" in choices[0]:
            return choices[0]["message"].get("content", "")
        return ""
    except Exception as e:
        return f"DeepSeek API error: {e}"
class EnhancedAgentState(TypedDict):
    """Shared state threaded through every LangGraph node."""
    # operator.add tells LangGraph to append (rather than overwrite) messages
    # when successive nodes return this key.
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str
class HybridLangGraphMultiLLMSystem:
    """Routes each query through a LangGraph graph to one of several LLM providers."""
    def __init__(self):
        # Tools are collected here; the graph nodes below call the provider
        # LLMs directly and do not bind these tools.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()
def _build_graph(self):
groq_llm = ChatGroq(model="llama3-70b-8192", temperature=0, api_key=os.getenv("GROQ_API_KEY"))
nvidia_llm = ChatNVIDIA(model="meta/llama3-70b-instruct", temperature=0, api_key=os.getenv("NVIDIA_API_KEY"))
def router(st: EnhancedAgentState) -> EnhancedAgentState:
q = st["query"].lower()
if "groq" in q: t = "groq"
elif "nvidia" in q: t = "nvidia"
elif "gemini" in q or "google" in q: t = "gemini"
elif "deepseek" in q: t = "deepseek"
elif "ernie" in q or "baidu" in q: t = "baidu"
else: t = "groq" # default
return {**st, "agent_type": t}
def groq_node(st: EnhancedAgentState) -> EnhancedAgentState:
t0 = time.time()
sys = SystemMessage(content="Answer as an expert.")
res = groq_llm.invoke([sys, HumanMessage(content=st["query"])])
return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "Groq"}}
def nvidia_node(st: EnhancedAgentState) -> EnhancedAgentState:
t0 = time.time()
sys = SystemMessage(content="Answer as an expert.")
res = nvidia_llm.invoke([sys, HumanMessage(content=st["query"])])
return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "NVIDIA"}}
def gemini_node(st: EnhancedAgentState) -> EnhancedAgentState:
t0 = time.time()
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
model = genai.GenerativeModel("gemini-1.5-pro-latest")
res = model.generate_content(st["query"])
return {**st, "final_answer": res.text, "perf": {"time": time.time() - t0, "prov": "Gemini"}}
def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState:
t0 = time.time()
resp = deepseek_generate(st["query"], api_key=os.getenv("DEEPSEEK_API_KEY"))
return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "DeepSeek"}}
def baidu_node(st: EnhancedAgentState) -> EnhancedAgentState:
t0 = time.time()
resp = baidu_ernie_generate(st["query"], api_key=os.getenv("BAIDU_API_KEY"))
return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "ERNIE"}}
def pick(st: EnhancedAgentState) -> str:
return st["agent_type"]
g = StateGraph(EnhancedAgentState)
g.add_node("router", router)
g.add_node("groq", groq_node)
g.add_node("nvidia", nvidia_node)
g.add_node("gemini", gemini_node)
g.add_node("deepseek", deepseek_node)
g.add_node("baidu", baidu_node)
g.set_entry_point("router")
g.add_conditional_edges("router", pick, {
"groq": "groq",
"nvidia": "nvidia",
"gemini": "gemini",
"deepseek": "deepseek",
"baidu": "baidu"
})
for n in ["groq", "nvidia", "gemini", "deepseek", "baidu"]:
g.add_edge(n, END)
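        # MemorySaver checkpoints graph state per thread_id (see the
        # "configurable" dict passed in process_query below).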
return g.compile(checkpointer=MemorySaver())
def process_query(self, q: str) -> str:
state = {
"messages": [HumanMessage(content=q)],
"query": q,
"agent_type": "",
"final_answer": "",
"perf": {},
"agno_resp": ""
}
cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
out = self.graph.invoke(state, cfg)
raw_answer = out["final_answer"]
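        # Heuristic: if the model prefixed a preamble paragraph, keep only the
        # text after the first blank line; otherwise return the full answer.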
parts = raw_answer.split('\n\n', 1)
answer_part = parts[1].strip() if len(parts) > 1 else raw_answer.strip()
return answer_part
def build_graph(provider=None):
    # `provider` is accepted for interface compatibility; routing between
    # providers happens inside the compiled graph itself.
    return HybridLangGraphMultiLLMSystem().graph
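if __name__ == "__main__":
    # Minimal usage sketch (assumes the relevant API keys, e.g. GROQ_API_KEY
    # for the default route, are available via .env or the environment).
    system = HybridLangGraphMultiLLMSystem()
    print(system.process_query("What is 15 multiplied by 7?"))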