josondev commited on
Commit
76738d8
·
verified ·
1 Parent(s): 0c69489

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +204 -0
agent.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import random
4
+ from dotenv import load_dotenv
5
+ from typing import List, Dict, Any, TypedDict, Annotated
6
+ import operator
7
+
8
+ from langchain_core.tools import tool
9
+ from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_community.document_loaders import WikipediaLoader
11
+ from langchain_community.vectorstores import Chroma
12
+ from langchain.tools.retriever import create_retriever_tool
13
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
14
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
15
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
16
+
17
+ from langgraph.graph import StateGraph, START, END
18
+ from langgraph.checkpoint.memory import MemorySaver
19
+
20
+ # ---- Tool Definitions ----
21
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of the two given integers."""
    product = a * b
    return product
25
+
26
@tool
def add(a: int, b: int) -> int:
    """Return the sum of the two given integers."""
    total = a + b
    return total
30
+
31
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference obtained by subtracting the second integer from the first."""
    difference = a - b
    return difference
35
+
36
@tool
def divide(a: int, b: int) -> float:
    """Return the quotient of the first integer divided by the second.

    Raises:
        ValueError: If the divisor is zero.
    """
    if not b:
        raise ValueError("Cannot divide by zero.")
    return a / b
42
+
43
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder when the first integer is divided by the second."""
    remainder = a % b
    return remainder
47
+
48
@tool
def optimized_web_search(query: str) -> str:
    """Run a Tavily web search and return result snippets wrapped in <Doc> tags.

    Args:
        query: Free-text search query.

    Returns:
        Up to two result snippets (content truncated to 500 chars each)
        joined by a "---" separator, or a "Web search failed: ..." string
        if the search raises.
    """
    try:
        # Random jitter to avoid hammering the rate-limited search API.
        time.sleep(random.uniform(1, 2))
        # BUG FIX: BaseTool.invoke takes the tool input as its first
        # positional argument (a string or an input dict). The original
        # invoke(query=query) raised TypeError on every call, which the
        # except clause silently converted into "Web search failed".
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url','')}'>{d.get('content','')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"
60
+
61
@tool
def optimized_wiki_search(query: str) -> str:
    """Fetch the top Wikipedia article for the query and return it as a <Doc> snippet.

    Returns a "Wikipedia search failed: ..." string if the lookup raises.
    """
    try:
        time.sleep(random.uniform(0.5, 1))  # light throttle between lookups
        pages = WikipediaLoader(query=query, load_max_docs=1).load()
        snippets = [
            f"<Doc src='{page.metadata['source']}'>{page.page_content[:800]}</Doc>"
            for page in pages
        ]
        return "\n\n---\n\n".join(snippets)
    except Exception as err:
        return f"Wikipedia search failed: {err}"
73
+
74
# ---- LLM Integrations ----
# Load provider credentials (GROQ_API_KEY, NVIDIA_API_KEY, GEMINI_API_KEY,
# DEEPSEEK_API_KEY, BAIDU_API_KEY) from a local .env file before the
# provider clients below are constructed.
load_dotenv()

from langchain_groq import ChatGroq
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from google import genai

import requests
82
+
83
def baidu_ernie_generate(prompt, api_key=None):
    """Call the Baidu ERNIE generation endpoint and return the generated text.

    Args:
        prompt: The user prompt to send to the model.
        api_key: Bearer token for the Baidu API.

    Returns:
        The "result" field of the JSON response, or an
        "ERNIE API error: ..." string on any failure.
    """
    url = "https://api.baidu.com/ernie/v1/generate"
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {"model": "ernie-4.5", "prompt": prompt}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        # Surface HTTP errors (401/429/5xx) as the error string instead of
        # silently returning "" when the error body lacks a "result" key.
        resp.raise_for_status()
        return resp.json().get("result", "")
    except Exception as e:
        return f"ERNIE API error: {e}"
92
+
93
def deepseek_generate(prompt, api_key=None):
    """Call the DeepSeek chat-completions API and return the reply text.

    Args:
        prompt: The user prompt, sent as a single "user" message.
        api_key: Bearer token for the DeepSeek API.

    Returns:
        The first choice's message content, "" if the response has an
        unexpected shape, or a "DeepSeek API error: ..." string on failure.
    """
    url = "https://api.deepseek.com/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {"model": "deepseek-chat", "messages": [{"role": "user", "content": prompt}]}
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        # Surface HTTP errors (401/429/5xx) as the error string instead of
        # silently returning "" from a malformed error body.
        resp.raise_for_status()
        choices = resp.json().get("choices", [{}])
        if choices and "message" in choices[0]:
            return choices[0]["message"].get("content", "")
        return ""
    except Exception as e:
        return f"DeepSeek API error: {e}"
105
+
106
class EnhancedAgentState(TypedDict):
    """Shared state flowing through the LangGraph multi-LLM graph."""
    # Conversation history; Annotated with operator.add so LangGraph
    # concatenates message lists instead of replacing them.
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    # The raw user query text.
    query: str
    # Provider key chosen by the router ("groq", "nvidia", "gemini",
    # "deepseek", or "baidu").
    agent_type: str
    # Answer text produced by the selected provider node.
    final_answer: str
    # Timing/provenance metadata, e.g. {"time": <seconds>, "prov": "Groq"}.
    perf: Dict[str, Any]
    # NOTE(review): never written with a non-empty value anywhere in this
    # file — looks like a leftover field; confirm before removing.
    agno_resp: str
113
+
114
class HybridLangGraphMultiLLMSystem:
    """Route a query to one of several LLM providers via a compiled LangGraph.

    A router node scans the query text for a provider keyword ("groq",
    "nvidia", "gemini"/"google", "deepseek", "ernie"/"baidu") and dispatches
    to the matching provider node; unmatched queries default to Groq.
    """

    def __init__(self):
        # Registered tool set. NOTE(review): the provider nodes below do not
        # bind these tools to any LLM yet — kept for future tool-calling.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()

    def _build_graph(self):
        """Build and compile the router + provider-node graph with an in-memory checkpointer."""
        groq_llm = ChatGroq(model="llama3-70b-8192", temperature=0, api_key=os.getenv("GROQ_API_KEY"))
        nvidia_llm = ChatNVIDIA(model="meta/llama3-70b-instruct", temperature=0, api_key=os.getenv("NVIDIA_API_KEY"))

        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            # Keyword-based provider selection; first match wins.
            q = st["query"].lower()
            if "groq" in q:
                t = "groq"
            elif "nvidia" in q:
                t = "nvidia"
            elif "gemini" in q or "google" in q:
                t = "gemini"
            elif "deepseek" in q:
                t = "deepseek"
            elif "ernie" in q or "baidu" in q:
                t = "baidu"
            else:
                t = "groq"  # default provider
            return {**st, "agent_type": t}

        def groq_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            sys = SystemMessage(content="Answer as an expert.")
            res = groq_llm.invoke([sys, HumanMessage(content=st["query"])])
            return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "Groq"}}

        def nvidia_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            sys = SystemMessage(content="Answer as an expert.")
            res = nvidia_llm.invoke([sys, HumanMessage(content=st["query"])])
            return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "NVIDIA"}}

        def gemini_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            # BUG FIX: this file imports the google-genai SDK
            # (`from google import genai`), which has no `genai.configure`
            # or `genai.GenerativeModel` — those belong to the legacy
            # `google.generativeai` package, so the original raised
            # AttributeError. Use the google-genai Client API instead.
            client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
            res = client.models.generate_content(
                model="gemini-1.5-pro-latest",
                contents=st["query"],
            )
            return {**st, "final_answer": res.text, "perf": {"time": time.time() - t0, "prov": "Gemini"}}

        def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            resp = deepseek_generate(st["query"], api_key=os.getenv("DEEPSEEK_API_KEY"))
            return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "DeepSeek"}}

        def baidu_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            resp = baidu_ernie_generate(st["query"], api_key=os.getenv("BAIDU_API_KEY"))
            return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "ERNIE"}}

        def pick(st: EnhancedAgentState) -> str:
            # Conditional-edge selector: route to the node named in agent_type.
            return st["agent_type"]

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("groq", groq_node)
        g.add_node("nvidia", nvidia_node)
        g.add_node("gemini", gemini_node)
        g.add_node("deepseek", deepseek_node)
        g.add_node("baidu", baidu_node)
        g.set_entry_point("router")
        g.add_conditional_edges("router", pick, {
            "groq": "groq",
            "nvidia": "nvidia",
            "gemini": "gemini",
            "deepseek": "deepseek",
            "baidu": "baidu"
        })
        for n in ["groq", "nvidia", "gemini", "deepseek", "baidu"]:
            g.add_edge(n, END)
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        """Run *q* through the graph and return the selected provider's answer text."""
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": ""
        }
        # NOTE(review): hash(q) varies across interpreter runs
        # (PYTHONHASHSEED), so thread ids are only stable within one process.
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        out = self.graph.invoke(state, cfg)
        raw_answer = out["final_answer"]
        # Some providers prepend a preamble paragraph; keep only the text
        # after the first blank line when one exists.
        parts = raw_answer.split('\n\n', 1)
        answer_part = parts[1].strip() if len(parts) > 1 else raw_answer.strip()
        return answer_part
202
+
203
def build_graph():
    """Construct a fresh multi-LLM system and return its compiled LangGraph."""
    system = HybridLangGraphMultiLLMSystem()
    return system.graph