import os, time, random
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Imports
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, JSONLoader
from langchain_community.vectorstores import FAISS
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.rate_limiters import InMemoryRateLimiter
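
# Required environment variables (loaded from .env above), inferred from the
# integrations used below: GROQ_API_KEY, GOOGLE_API_KEY, NVIDIA_API_KEY, and
# TAVILY_API_KEY (read implicitly by TavilySearchResults).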
# Rate limiters for different providers
groq_rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.5,   # 30 requests per minute
    check_every_n_seconds=0.1,
    max_bucket_size=10
)

google_rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.33,  # ~20 requests per minute
    check_every_n_seconds=0.1,
    max_bucket_size=10
)

nvidia_rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.25,  # 15 requests per minute
    check_every_n_seconds=0.1,
    max_bucket_size=10
)
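
# Sanity check for the token-bucket settings above (a minimal sketch, not part
# of the pipeline): requests_per_second * 60 is the sustained per-minute
# budget, while max_bucket_size caps the initial burst.
#
#   for name, rps in [("groq", 0.5), ("google", 0.33), ("nvidia", 0.25)]:
#       print(f"{name}: ~{rps * 60:.0f} requests/minute, burst up to 10")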
# Define all tools
@tool
def multiply(a: int | float, b: int | float) -> int | float:
    """Multiply two numbers.

    Args:
        a: first int | float
        b: second int | float
    """
    return a * b

@tool
def add(a: int | float, b: int | float) -> int | float:
    """Add two numbers.

    Args:
        a: first int | float
        b: second int | float
    """
    return a + b

@tool
def subtract(a: int | float, b: int | float) -> int | float:
    """Subtract two numbers.

    Args:
        a: first int | float
        b: second int | float
    """
    return a - b

@tool
def divide(a: int | float, b: int | float) -> int | float:
    """Divide two numbers.

    Args:
        a: first int | float
        b: second int | float
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int | float, b: int | float) -> int | float:
    """Get the modulus of two numbers.

    Args:
        a: first int | float
        b: second int | float
    """
    return a % b
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return the content of the top matching page.

    Args:
        query: the query to search for
    """
    try:
        loader = WikipediaLoader(query=query, load_max_docs=1)
        data = loader.load()
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'\n{doc.page_content}\n'
                for doc in data
            ])
        return formatted_search_docs
    except Exception as e:
        return f"Wikipedia search failed: {str(e)}"
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return a maximum of 3 results.

    Args:
        query: The search query.
    """
    try:
        # Add delay to prevent rate limiting
        time.sleep(random.uniform(1, 3))
        # BaseTool.invoke takes the tool input as its first positional argument
        search_docs = TavilySearchResults(max_results=3).invoke(query)
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'\n{doc.get("content", "")}\n'
                for doc in search_docs
            ])
        return formatted_search_docs
    except Exception as e:
        return f"Web search failed: {str(e)}"
@tool
def arxiv_search(query: str) -> str:
    """Search arXiv for a query and return a maximum of 3 results (first 1000 characters of each).

    Args:
        query: The search query.
    """
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'\n{doc.page_content[:1000]}\n'
                for doc in search_docs
            ])
        return formatted_search_docs
    except Exception as e:
        return f"ArXiv search failed: {str(e)}"
# Load and process your JSONL data
jq_schema = """
{
page_content: .Question,
metadata: {
task_id: .task_id,
Level: .Level,
Final_answer: ."Final answer",
file_name: .file_name,
Steps: .["Annotator Metadata"].Steps,
Number_of_steps: .["Annotator Metadata"]["Number of steps"],
How_long: .["Annotator Metadata"]["How long did this take?"],
Tools: .["Annotator Metadata"].Tools,
Number_of_tools: .["Annotator Metadata"]["Number of tools"]
}
}
"""
# Load documents and create vector database
json_loader = JSONLoader(file_path="metadata.jsonl", jq_schema=jq_schema, json_lines=True, text_content=False)
json_docs = json_loader.load()
# Split documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=200)
json_chunks = text_splitter.split_documents(json_docs)
# Create vector database
database = FAISS.from_documents(json_chunks, NVIDIAEmbeddings())
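# Note: NVIDIAEmbeddings reads NVIDIA_API_KEY from the environment. If
# re-embedding on every run is too slow, FAISS supports
# database.save_local("faiss_index") and FAISS.load_local(...) for reuse
# (optional; not used here).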
# Initialize LLMs with rate limiting
def create_rate_limited_llm(provider="groq"):
    """Create a rate-limited LLM based on provider."""
    if provider == "groq":
        return ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY"),
            rate_limiter=groq_rate_limiter,
            max_retries=2,
            request_timeout=60
        )
    elif provider == "google":
        return ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-exp",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY"),
            rate_limiter=google_rate_limiter,
            max_retries=2,
            timeout=60  # ChatGoogleGenerativeAI uses `timeout`, not `request_timeout`
        )
    elif provider == "nvidia":
        return ChatNVIDIA(
            model="meta/llama-3.1-405b-instruct",
            temperature=0,
            api_key=os.getenv("NVIDIA_API_KEY"),
            rate_limiter=nvidia_rate_limiter,
            max_retries=2
        )
    else:
        raise ValueError(f"Unknown provider: {provider}")
# Create fallback chain with exponential backoff
def create_llm_with_smart_fallbacks():
    """Create LLM with intelligent fallback and rate limiting"""
    # Primary: Groq (fastest)
    primary_llm = create_rate_limited_llm("groq")
    # Fallback 1: Google (most capable)
    fallback_1 = create_rate_limited_llm("google")
    # Fallback 2: NVIDIA (reliable)
    fallback_2 = create_rate_limited_llm("nvidia")
    # Create fallback chain
    llm_with_fallbacks = primary_llm.with_fallbacks([fallback_1, fallback_2])
    return llm_with_fallbacks

# Initialize LLM with smart fallbacks
llm = create_llm_with_smart_fallbacks()
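
# with_fallbacks() tries the primary model first and only falls through to the
# next model in the list when a call raises (e.g. a rate-limit error), so the
# order above encodes the preference: Groq -> Google -> NVIDIA.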
# Create retriever and retriever tool
retriever = database.as_retriever(search_type="similarity", search_kwargs={"k": 3})
retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="question_search",
    description="Search for similar questions and their solutions from the knowledge base."
)
# Combine all tools
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arxiv_search,
    retriever_tool
]
# Create memory for conversation
memory = MemorySaver()
# Create the agent
agent_executor = create_react_agent(
    model=llm,
    tools=tools,
    checkpointer=memory
)
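
# MemorySaver checkpoints conversation state per thread_id, so distinct
# thread_ids (as used in robust_agent_run below) act as independent
# conversations.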
# Enhanced robust agent run with exponential backoff
def robust_agent_run(query, thread_id="robust_conversation", max_retries=3):
    """Run agent with error handling, rate limiting, and exponential backoff"""
    for attempt in range(max_retries):
        try:
            config = {"configurable": {"thread_id": f"{thread_id}_{attempt}"}}
            system_msg = SystemMessage(content='''You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma-separated list of numbers and/or strings. If you are asked for a number, don't use commas to write your number, nor units such as $ or percent signs unless specified otherwise. If you are asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits in plain text unless specified otherwise. If you are asked for a comma-separated list, apply the above rules depending on whether each element of the list is a number or a string.
Your answer should start with "FINAL ANSWER: ", followed by the answer.''')
            user_msg = HumanMessage(content=query)
            result = []

            print(f"Attempt {attempt + 1}: Processing query...")
            for step in agent_executor.stream(
                {"messages": [system_msg, user_msg]},
                config,
                stream_mode="values"
            ):
                result = step["messages"]

            final_response = result[-1].content if result else "No response generated"
            print(f"Query processed successfully on attempt {attempt + 1}")
            return final_response

        except Exception as e:
            error_msg = str(e).lower()

            # Check for rate limit errors
            if any(keyword in error_msg for keyword in ['rate limit', 'too many requests', '429', 'quota exceeded']):
                wait_time = (2 ** attempt) + random.uniform(1, 3)  # Exponential backoff with jitter
                print(f"Rate limit hit on attempt {attempt + 1}. Waiting {wait_time:.2f} seconds...")
                time.sleep(wait_time)
                if attempt == max_retries - 1:
                    return f"Rate limit exceeded after {max_retries} attempts: {str(e)}"
                continue
            # Check for other API errors
            elif any(keyword in error_msg for keyword in ['api', 'connection', 'timeout', 'service unavailable']):
                wait_time = (2 ** attempt) + random.uniform(0.5, 1.5)
                print(f"API error on attempt {attempt + 1}. Retrying in {wait_time:.2f} seconds...")
                time.sleep(wait_time)
                if attempt == max_retries - 1:
                    return f"API error after {max_retries} attempts: {str(e)}"
                continue
            else:
                # Non-recoverable error
                return f"Error occurred: {str(e)}"

    return "Maximum retries exceeded"
# Main function with request tracking
request_count = 0
last_request_time = time.time()

def main(query: str) -> str:
    """Main function to run the agent with request tracking"""
    global request_count, last_request_time

    current_time = time.time()
    # Reset counter every minute
    if current_time - last_request_time > 60:
        request_count = 0
        last_request_time = current_time

    request_count += 1
    print(f"Processing request #{request_count}")

    # Add small delay between requests to prevent overwhelming APIs
    if request_count > 1:
        time.sleep(random.uniform(2, 5))

    return robust_agent_run(query)
if __name__ == "__main__":
    # Test the agent
    result = main("What are the names of the US presidents who were assassinated?")
    print(result)
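
    # A hypothetical follow-up turn (a sketch only; state carries over because
    # robust_agent_run reuses the same default thread_id across calls):
    #   print(main("Which of them served in the 19th century?"))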