File size: 5,957 Bytes
d4557ee f4505e9 e292008 1fa6961 f4505e9 0c69489 e292008 0c69489 e292008 f4505e9 e292008 a35ea13 f4505e9 f0e66e7 f4505e9 f0e66e7 1fa6961 f4505e9 f0e66e7 cc467c2 f4505e9 f0e66e7 f4505e9 f0e66e7 a55679f f4505e9 f0e66e7 25c1140 e292008 a55679f 25c1140 d4557ee 0f81d99 f4505e9 9a3d597 f0e66e7 25c1140 e292008 d4557ee a55679f e292008 a55679f 25c1140 d4557ee b1b6e20 a55679f e292008 f4505e9 e292008 d4557ee f4505e9 0c69489 e292008 a55679f d4557ee a55679f e292008 d4557ee e292008 a55679f e292008 a35ea13 e292008 a55679f e292008 0c69489 e292008 0c69489 e292008 0c69489 e292008 0c69489 e292008 0c69489 e292008 a55679f 0c69489 e292008 0c69489 0ab2059 e292008 d4557ee 0f81d99 0c69489 72c7dbb 0c69489 e292008 a35ea13 e292008 0c69489 e292008 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import os
import time
import random
import operator
from typing import List, Dict, Any, TypedDict, Annotated
from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq
# Load environment variables from a local .env file before anything reads them.
# GROQ_API_KEY is fetched later with os.getenv() when the Groq chat clients are
# constructed.
# NOTE(review): the Tavily search tool presumably needs its own API key in the
# environment as well (TAVILY_API_KEY?) — confirm against the Tavily tool docs.
load_dotenv() # expects GROQ_API_KEY in your .env
@tool
def multiply(a: int, b: int) -> int:
    """multiplies two numbers"""
    # NOTE: the docstring is kept verbatim on purpose — the @tool decorator is
    # expected to read it at runtime as the tool description (confirm against
    # the langchain_core docs), so rewording it would change behaviour.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """adds two numbers"""
    # NOTE: the docstring is kept verbatim on purpose — the @tool decorator is
    # expected to read it at runtime as the tool description (confirm against
    # the langchain_core docs), so rewording it would change behaviour.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """subtracts two numbers"""
    # NOTE: the docstring is kept verbatim on purpose — the @tool decorator is
    # expected to read it at runtime as the tool description (confirm against
    # the langchain_core docs), so rewording it would change behaviour.
    # Order matters: the second argument is taken away from the first.
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """divides two numbers"""
    # NOTE: the docstring is kept verbatim on purpose — the @tool decorator is
    # expected to read it at runtime as the tool description (confirm against
    # the langchain_core docs), so rewording it would change behaviour.
    if b != 0:
        # True division: the result is a float even for two ints.
        return a / b
    # A zero divisor is reported as ValueError with a readable message instead
    # of letting Python's ZeroDivisionError surface.
    raise ValueError("Cannot divide by zero.")
@tool
def modulus(a: int, b: int) -> int:
    """returns the remainder while dividing two numbers"""
    # NOTE: the docstring is kept verbatim on purpose — the @tool decorator is
    # expected to read it at runtime as the tool description (confirm against
    # the langchain_core docs), so rewording it would change behaviour.
    # Python's % takes the sign of the divisor and raises ZeroDivisionError
    # when b is 0; no extra guard is applied here.
    remainder = a % b
    return remainder
@tool
def optimized_web_search(query: str) -> str:
    """searches the web using tavily"""
    # NOTE: the docstring is kept verbatim on purpose — the @tool decorator is
    # expected to read it at runtime as the tool description (confirm against
    # the langchain_core docs), so rewording it would change behaviour.
    #
    # Returns up to two results rendered as "<Doc url='...'>...</Doc>" chunks
    # (content truncated to 500 characters) joined by "---" separators. Any
    # failure is reported as a "Web search failed: ..." string rather than
    # raised, so the calling agent always receives text.
    try:
        # Random 0.7–1.5 s pause before each request (presumably to space out
        # calls and avoid rate limiting).
        time.sleep(random.uniform(0.7, 1.5))
        # Fix: invoke() takes the tool input as its first positional argument.
        # Calling it as invoke(query=query) leaves that argument unset, which
        # raises TypeError on every call and makes the search always "fail".
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
        if not isinstance(docs, list):
            # NOTE(review): some versions of the Tavily tool appear to return
            # an error string instead of a list of hits — surface it directly
            # rather than failing on `.get` below.
            return f"Web search failed: {docs}"
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url', '')}'>{d.get('content', '')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        # Deliberate best-effort: the tool must hand text back to the agent.
        return f"Web search failed: {e}"
@tool
def optimized_wiki_search(query: str) -> str:
    """searches wikipedia"""
    # NOTE: the docstring is kept verbatim on purpose — the @tool decorator is
    # expected to read it at runtime as the tool description (confirm against
    # the langchain_core docs), so rewording it would change behaviour.
    #
    # Loads at most one Wikipedia document and renders it as a
    # "<Doc src='...'>...</Doc>" chunk with the page text cut to 800
    # characters. Failures come back as a "Wikipedia search failed: ..." string.
    try:
        # Random 0.3–1 s pause before hitting Wikipedia.
        time.sleep(random.uniform(0.3, 1))
        loader = WikipediaLoader(query=query, load_max_docs=1)
        pages = loader.load()
        rendered = []
        for page in pages:
            source = page.metadata.get('source', 'Wikipedia')
            snippet = page.page_content[:800]
            rendered.append(f"<Doc src='{source}'>{snippet}</Doc>")
        return "\n\n---\n\n".join(rendered)
    except Exception as err:
        return f"Wikipedia search failed: {err}"
class EnhancedAgentState(TypedDict):
    """State dictionary passed between the nodes of the routing graph."""

    # Conversation messages. The operator.add metadata is presumably picked up
    # by LangGraph as the channel reducer, i.e. node updates are concatenated
    # onto the existing list rather than replacing it — confirm in the
    # LangGraph docs.
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    # Raw user query text; the router inspects a lower-cased copy of it.
    query: str
    # Routing decision written by the router node: "llama8", "llama70" or
    # "deepseek".
    agent_type: str
    # Text content of the selected model's response.
    final_answer: str
    # Per-call metadata written by the model nodes: {"time": <elapsed seconds>,
    # "prov": <provider label>}.
    perf: Dict[str, Any]
    # Initialised to "" by process_query; not otherwise read or written in the
    # visible code.
    agno_resp: str
class HybridLangGraphMultiLLMSystem:
    """
    Router that picks between Groq-hosted Llama-3 8B, Llama-3 70B (default),
    and Groq-hosted DeepSeek-Chat according to the query content.

    The compiled LangGraph is exposed as ``self.graph``; ``process_query`` is
    the convenience entry point that pushes a single query through it and
    returns the answer text.
    """

    # agent_type -> (Groq model id, provider label reported in ``perf``).
    # The keys double as graph node names and as the router's return values.
    # NOTE(review): "deepseek-chat" may not be a model id that Groq serves —
    # confirm against the Groq model catalogue before relying on that route.
    _MODELS = {
        "llama8": ("llama3-8b-8192", "Groq-Llama3-8B"),
        "llama70": ("llama3-70b-8192", "Groq-Llama3-70B"),
        "deepseek": ("deepseek-chat", "Groq-DeepSeek"),
    }
    _SYSTEM_PROMPT = "You are a helpful AI assistant."

    def __init__(self):
        # NOTE(review): the tools are collected here but never bound to any of
        # the LLMs in this class, so the models cannot call them as written.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()

    def _llm(self, model_name: str):
        """Return a deterministic (temperature 0) Groq chat client for *model_name*."""
        return ChatGroq(
            model=model_name,
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY")
        )

    @staticmethod
    def _route(query: str) -> str:
        """Map a query to an agent type by case-insensitive substring match.

        "llama-8" selects ``llama8``, otherwise "deepseek" selects
        ``deepseek``; everything else falls back to ``llama70``.
        """
        lowered = query.lower()
        if "llama-8" in lowered:
            return "llama8"
        if "deepseek" in lowered:
            return "deepseek"
        return "llama70"

    def _make_node(self, model_name: str, provider: str):
        """Build a graph node that answers ``state["query"]`` with one model.

        The node returns only the keys it changes. Returning the whole state
        (``{**st, ...}``) would hand the existing ``messages`` list back to a
        channel whose reducer is ``operator.add``, appending the list to
        itself and doubling it at every node.
        """
        llm = self._llm(model_name)
        system_prompt = self._SYSTEM_PROMPT

        def node(st: EnhancedAgentState) -> Dict[str, Any]:
            # perf_counter is monotonic, so the elapsed time cannot go
            # negative if the wall clock is adjusted mid-call.
            started = time.perf_counter()
            res = llm.invoke([
                SystemMessage(content=system_prompt),
                HumanMessage(content=st["query"]),
            ])
            return {
                "final_answer": res.content,
                "perf": {"time": time.perf_counter() - started, "prov": provider},
            }

        return node

    def _build_graph(self):
        """Assemble and compile the graph: router -> one model node -> END."""

        def router(st: EnhancedAgentState) -> Dict[str, Any]:
            # Partial update only — see _make_node for why the full state must
            # not be echoed back.
            return {"agent_type": self._route(st["query"])}

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        for agent_type, (model_name, provider) in self._MODELS.items():
            g.add_node(agent_type, self._make_node(model_name, provider))
            g.add_edge(agent_type, END)
        g.set_entry_point("router")
        # The router's decision is stored in state["agent_type"]; node names
        # equal agent types, so the edge mapping is the identity.
        g.add_conditional_edges("router", lambda s: s["agent_type"],
                                {name: name for name in self._MODELS})
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        """Run *q* through the graph and return the stripped answer text."""
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": ""
        }
        # Identical queries share a checkpoint thread within one process.
        # str hashes are randomised per interpreter run, which is harmless
        # here because MemorySaver only lives in memory.
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        out = self.graph.invoke(state, cfg)
        return out.get("final_answer", "").strip()
def build_graph(provider: str | None = None):
    """Create a new HybridLangGraphMultiLLMSystem and return its compiled graph.

    ``provider`` is accepted but ignored by this implementation; every call
    builds a fresh system (and therefore fresh LLM clients and checkpointer).
    """
    system = HybridLangGraphMultiLLMSystem()
    return system.graph
if __name__ == "__main__":
    qa_system = HybridLangGraphMultiLLMSystem()
    # Smoke test: one query aimed at each model. The "llama-70" prefix has no
    # dedicated route and lands on the default branch.
    demo_queries = (
        "llama-8: What is the capital of France?",
        "llama-70: Tell me about quantum mechanics.",
        "deepseek: What is the Riemann Hypothesis?",
    )
    for demo_query in demo_queries:
        print(qa_system.process_query(demo_query))
|