josondev committed
Commit b102339 · verified · 1 Parent(s): c870dc0

Update veryfinal.py

Files changed (1):
  1. veryfinal.py +96 -75
veryfinal.py CHANGED
@@ -7,7 +7,6 @@ load_dotenv()
 # Imports
 from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
 from langchain_groq import ChatGroq
-from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_nvidia_ai_endpoints import ChatNVIDIA
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
@@ -42,7 +41,57 @@ nvidia_rate_limiter = InMemoryRateLimiter(
     max_bucket_size=10
 )
 
-# Define all tools
+# Initialize individual LLMs
+groq_llm = ChatGroq(
+    model="llama-3.3-70b-versatile",
+    temperature=0,
+    api_key=os.getenv("GROQ_API_KEY"),
+    rate_limiter=groq_rate_limiter,
+    max_retries=2,
+    request_timeout=60
+)
+
+nvidia_llm = ChatNVIDIA(
+    model="meta/llama-3.1-405b-instruct",
+    temperature=0,
+    api_key=os.getenv("NVIDIA_API_KEY"),
+    rate_limiter=nvidia_rate_limiter,
+    max_retries=2
+)
+
+# Create LLM tools that can be selected by the agent
+@tool
+def groq_reasoning_tool(query: str) -> str:
+    """Use Groq's Llama model for fast reasoning, mathematical calculations, and logical problems.
+    Best for: Math problems, logical reasoning, quick calculations, code generation.
+
+    Args:
+        query: The question or problem to solve
+    """
+    try:
+        time.sleep(random.uniform(1, 2))  # Rate limiting
+        response = groq_llm.invoke([HumanMessage(content=query)])
+        return f"Groq Response: {response.content}"
+    except Exception as e:
+        return f"Groq tool failed: {str(e)}"
+
+
+@tool
+def nvidia_specialist_tool(query: str) -> str:
+    """Use NVIDIA's large model for specialized tasks, technical questions, and domain expertise.
+    Best for: Technical questions, specialized domains, scientific problems, detailed analysis.
+
+    Args:
+        query: The specialized question or technical problem
+    """
+    try:
+        time.sleep(random.uniform(2, 4))  # Rate limiting
+        response = nvidia_llm.invoke([HumanMessage(content=query)])
+        return f"NVIDIA Response: {response.content}"
+    except Exception as e:
+        return f"NVIDIA tool failed: {str(e)}"
+
+# Define calculation tools
 @tool
 def multiply(a: int | float, b: int | float) -> int | float:
     """Multiply two numbers.
@@ -94,6 +143,7 @@ def modulus(a: int | float, b: int | float) -> int | float:
     """
     return a % b
 
+# Define search tools
 @tool
 def wiki_search(query: str) -> str:
     """Search Wikipedia for a query and return the first paragraph.
@@ -120,7 +170,6 @@ def web_search(query: str) -> str:
         query: The search query.
     """
     try:
-        # Add delay to prevent rate limiting
         time.sleep(random.uniform(1, 3))
         search_docs = TavilySearchResults(max_results=3).invoke(query)
         formatted_search_docs = "\n\n---\n\n".join(
@@ -179,58 +228,6 @@ json_chunks = text_splitter.split_documents(json_docs)
 # Create vector database
 database = FAISS.from_documents(json_chunks, NVIDIAEmbeddings())
 
-# Initialize LLMs with rate limiting
-def create_rate_limited_llm(provider="groq"):
-    """Create rate-limited LLM based on provider"""
-
-    if provider == "groq":
-        return ChatGroq(
-            model="llama-3.3-70b-versatile",
-            temperature=0,
-            api_key=os.getenv("GROQ_API_KEY"),
-            rate_limiter=groq_rate_limiter,
-            max_retries=2,
-            request_timeout=60
-        )
-    elif provider == "google":
-        return ChatGoogleGenerativeAI(
-            model="gemini-2.0-flash-exp",
-            temperature=0,
-            api_key=os.getenv("GOOGLE_API_KEY"),
-            rate_limiter=google_rate_limiter,
-            max_retries=2,
-            timeout=60
-        )
-    elif provider == "nvidia":
-        return ChatNVIDIA(
-            model="meta/llama-3.1-405b-instruct",
-            temperature=0,
-            api_key=os.getenv("NVIDIA_API_KEY"),
-            rate_limiter=nvidia_rate_limiter,
-            max_retries=2
-        )
-
-# Create fallback chain with exponential backoff
-def create_llm_with_smart_fallbacks():
-    """Create LLM with intelligent fallback and rate limiting"""
-
-    # Primary: Groq (fastest)
-    primary_llm = create_rate_limited_llm("groq")
-
-    # Fallback 1: Google (most capable)
-    fallback_1 = create_rate_limited_llm("google")
-
-    # Fallback 2: NVIDIA (reliable)
-    fallback_2 = create_rate_limited_llm("nvidia")
-
-    # Create fallback chain
-    llm_with_fallbacks = primary_llm.with_fallbacks([fallback_1, fallback_2])
-
-    return llm_with_fallbacks
-
-# Initialize LLM with smart fallbacks
-llm = create_llm_with_smart_fallbacks()
-
 # Create retriever and retriever tool
 retriever = database.as_retriever(search_type="similarity", search_kwargs={"k": 3})
 
@@ -240,47 +237,72 @@ retriever_tool = create_retriever_tool(
     description="Search for similar questions and their solutions from the knowledge base."
 )
 
-# Combine all tools
+# Combine all tools including LLM tools
 tools = [
+    # Math tools
     multiply,
     add,
     subtract,
     divide,
     modulus,
+
+    # Search tools
     wiki_search,
     web_search,
     arxiv_search,
-    retriever_tool
+    retriever_tool,
+
+    # LLM tools - agent can choose which LLM to use
+    groq_reasoning_tool,
+    nvidia_specialist_tool
 ]
 
+# Use a lightweight coordinator LLM (Groq for speed)
+coordinator_llm = ChatGroq(
+    model="llama-3.3-70b-versatile",
+    temperature=0,
+    api_key=os.getenv("GROQ_API_KEY"),
+    rate_limiter=groq_rate_limiter
+)
+
 # Create memory for conversation
 memory = MemorySaver()
 
-# Create the agent
+# Create the agent with coordinator LLM
 agent_executor = create_react_agent(
-    model=llm,
+    model=coordinator_llm,
     tools=tools,
     checkpointer=memory
 )
 
-# Enhanced robust agent run with exponential backoff
+# Enhanced robust agent run
 def robust_agent_run(query, thread_id="robust_conversation", max_retries=3):
-    """Run agent with error handling, rate limiting, and exponential backoff"""
+    """Run agent with error handling, rate limiting, and LLM tool selection"""
 
     for attempt in range(max_retries):
         try:
            config = {"configurable": {"thread_id": f"{thread_id}_{attempt}"}}
 
-            system_msg = SystemMessage(content='''You are a helpful assistant tasked with answering questions using a set of tools.
-Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
-FINAL ANSWER: [YOUR FINAL ANSWER].
-YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
-Your answer should only start with "FINAL ANSWER: ", then follows with the answer.''')
+            system_msg = SystemMessage(content='''You are a helpful assistant with access to multiple specialized LLM tools and other utilities.
+
+AVAILABLE LLM TOOLS:
+- groq_reasoning_tool: fast reasoning, math, calculations, code (use for quick logical problems)
+- nvidia_specialist_tool: technical questions, specialized domains, scientific problems (use for expert-level tasks)
+
+TOOL SELECTION STRATEGY:
+- For math/calculations: use the basic math tools (add, multiply, etc.), or groq_reasoning_tool for complex math
+- For factual questions: use web_search, wiki_search, or arxiv_search first
+- For analysis/reasoning: choose the most appropriate LLM tool based on complexity
+- For technical/scientific questions: use nvidia_specialist_tool
+- For quick logical problems: use groq_reasoning_tool
+
+Always finish with: FINAL ANSWER: [YOUR FINAL ANSWER]
+Your answer should be a number OR a few words OR a comma-separated list, as appropriate.''')
 
             user_msg = HumanMessage(content=query)
             result = []
 
-            print(f"Attempt {attempt + 1}: Processing query...")
+            print(f"Attempt {attempt + 1}: Processing query with multi-LLM agent...")
 
             for step in agent_executor.stream(
                 {"messages": [system_msg, user_msg]},
@@ -296,9 +320,8 @@ def robust_agent_run(query, thread_id="robust_conversation", max_retries=3):
         except Exception as e:
             error_msg = str(e).lower()
 
-            # Check for rate limit errors
             if any(keyword in error_msg for keyword in ['rate limit', 'too many requests', '429', 'quota exceeded']):
-                wait_time = (2 ** attempt) + random.uniform(1, 3)  # Exponential backoff with jitter
+                wait_time = (2 ** attempt) + random.uniform(1, 3)
                 print(f"Rate limit hit on attempt {attempt + 1}. Waiting {wait_time:.2f} seconds...")
                 time.sleep(wait_time)
 
@@ -306,7 +329,6 @@ def robust_agent_run(query, thread_id="robust_conversation", max_retries=3):
                     return f"Rate limit exceeded after {max_retries} attempts: {str(e)}"
                 continue
 
-            # Check for other API errors
             elif any(keyword in error_msg for keyword in ['api', 'connection', 'timeout', 'service unavailable']):
                 wait_time = (2 ** attempt) + random.uniform(0.5, 1.5)
                 print(f"API error on attempt {attempt + 1}. Retrying in {wait_time:.2f} seconds...")
@@ -317,7 +339,6 @@ def robust_agent_run(query, thread_id="robust_conversation", max_retries=3):
                 continue
 
             else:
-                # Non-recoverable error
                 return f"Error occurred: {str(e)}"
 
     return "Maximum retries exceeded"
@@ -327,7 +348,7 @@ request_count = 0
 last_request_time = time.time()
 
 def main(query: str) -> str:
-    """Main function to run the agent with request tracking"""
+    """Main function to run the multi-LLM agent"""
     global request_count, last_request_time
 
     current_time = time.time()
@@ -338,15 +359,15 @@ def main(query: str) -> str:
     last_request_time = current_time
 
     request_count += 1
-    print(f"Processing request #{request_count}")
+    print(f"Processing request #{request_count} with multi-LLM agent")
 
-    # Add small delay between requests to prevent overwhelming APIs
+    # Add delay between requests
     if request_count > 1:
         time.sleep(random.uniform(2, 5))
 
     return robust_agent_run(query)
 
 if __name__ == "__main__":
-    # Test the agent
+    # Test the multi-LLM agent
     result = main("What are the names of the US presidents who were assassinated?")
     print(result)
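Note: the retry loop in robust_agent_run applies exponential backoff with random jitter to rate-limit and transient API errors. A minimal sketch of that pattern in isolation (with_backoff, flaky_call, and the RuntimeError classification are illustrative assumptions, not part of this file):

import random
import time

def with_backoff(flaky_call, max_retries=3):
    # Retry flaky_call, sleeping 2**attempt plus random jitter between attempts.
    for attempt in range(max_retries):
        try:
            return flaky_call()
        except RuntimeError as e:  # stand-in for a rate-limit / transient error
            wait_time = (2 ** attempt) + random.uniform(1, 3)
            print(f"Attempt {attempt + 1} failed ({e}); sleeping {wait_time:.2f}s")
            time.sleep(wait_time)
    return "Maximum retries exceeded"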
 
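Note: this commit replaces the hard-wired fallback chain (primary_llm.with_fallbacks([...])) with agent-driven routing: each provider is wrapped as a @tool and the Groq coordinator decides per step which one to call. A hedged usage sketch, assuming GROQ_API_KEY and NVIDIA_API_KEY are set and veryfinal.py is importable; the driver script below is hypothetical:

from veryfinal import main

# Each query is routed by the coordinator LLM to math, search, or LLM tools.
print(main("What is 17 * 24?"))                      # likely handled by the multiply tool
print(main("Explain CUDA warp divergence briefly"))  # likely nvidia_specialist_tool
# Per the system prompt, replies should end with "FINAL ANSWER: ..."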