|
""" |
|
Enhanced Multi-LLM Agent System with Question-Answering Capabilities

Routes queries across Groq-hosted models (Llama-3 8B/70B, DeepSeek) with tool
support. Google Gemini and NVIDIA NIM integrations are detected at import time;
the routing graph currently uses Groq-hosted models only.
|
""" |
|
|
|
import os |
|
import time |
|
import random |
|
import operator |
|
from typing import List, Dict, Any, TypedDict, Annotated
|
from dotenv import load_dotenv |
|
|
|
from langchain_core.tools import tool |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_community.document_loaders import WikipediaLoader |
|
from langgraph.graph import StateGraph, END |
|
from langgraph.checkpoint.memory import MemorySaver |
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage |
|
from langchain_groq import ChatGroq |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
ENHANCED_SYSTEM_PROMPT = (
    "You are a helpful assistant tasked with answering questions using a set of tools. "
    "You must provide accurate, comprehensive answers based on available information. "
    "When answering questions, follow these guidelines:\n"
    "1) Report your thoughts, then finish your reply with the following template: "
    "FINAL ANSWER: [YOUR FINAL ANSWER].\n"
    "2) YOUR FINAL ANSWER should be a number OR as few words as possible OR a "
    "comma-separated list of numbers and/or strings. If you are asked for a "
    "number, do not use commas or units such as $ or percent signs unless "
    "specified otherwise. If you are asked for a string, do not use articles "
    "or abbreviations (e.g. for cities), and write digits in plain text unless "
    "specified otherwise. If you are asked for a comma-separated list, apply "
    "the above rules to each element depending on whether it is a number or a "
    "string.\n"
    '3) Your answer must start with "FINAL ANSWER: ", followed by the answer.'
)
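# Illustrative answers that satisfy the template above (not from a live run):
#   "FINAL ANSWER: 425"
#   "FINAL ANSWER: paris"
#   "FINAL ANSWER: 3, 7, tokyo"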
|
|
|
# ---- Tool Definitions with Enhanced Docstrings ---- |
|
@tool |
|
def multiply(a: int | float, b: int | float) -> int | float: |
|
"""Multiply two numbers. |
|
|
|
Args: |
|
a: first int | float |
|
b: second int | float |
|
""" |
|
return a * b |
|
|
|
@tool |
|
def add(a: int | float, b: int | float) -> int | float: |
|
""" |
|
    Adds two numbers and returns the sum.

    Args:
        a (int | float): First addend
        b (int | float): Second addend

    Returns:
        int | float: Sum of a and b
|
""" |
|
return a + b |
|
|
|
@tool |
|
def subtract(a: int | float, b: int | float) -> int | float: |
|
""" |
|
    Subtracts the second number from the first and returns the difference.

    Args:
        a (int | float): Minuend
        b (int | float): Subtrahend

    Returns:
        int | float: Difference of a and b
|
""" |
|
return a - b |
|
|
|
@tool |
|
def divide(a: int | float, b: int | float) -> int | float: |
|
""" |
|
    Divides the first number by the second and returns the quotient.

    Args:
        a (int | float): Dividend
        b (int | float): Divisor

    Returns:
        float: Quotient of a divided by b
|
|
|
Raises: |
|
ValueError: If b is zero |
|
""" |
|
    if b == 0:  # in Python, 0 == 0.0, so a single check covers both
|
raise ValueError("Cannot divide by zero.") |
|
return a / b |
|
|
|
@tool |
|
def modulus(a: int | float, b: int | float) -> int | float: |
|
""" |
|
    Returns the remainder when dividing the first number by the second.

    Args:
        a (int | float): Dividend
        b (int | float): Divisor

    Returns:
        int | float: Remainder of a divided by b
|
""" |
|
return a % b |
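# Note on invocation: functions decorated with @tool become LangChain tools,
# called with a dict of arguments rather than positionally. A minimal sketch
# (values illustrative):
#   multiply.invoke({"a": 6, "b": 7})   # -> 42
#   divide.invoke({"a": 1, "b": 0})     # raises ValueError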
|
|
|
@tool |
|
def optimized_web_search(query: str) -> str: |
|
""" |
|
Performs an optimized web search using TavilySearchResults. |
|
|
|
Args: |
|
query (str): Search query string |
|
|
|
Returns: |
|
str: Concatenated search results with URLs and content snippets |
|
""" |
|
try: |
|
time.sleep(random.uniform(0.7, 1.5)) |
|
        docs = TavilySearchResults(max_results=3).invoke({"query": query})
|
return "\n\n---\n\n".join( |
|
f"<Doc url='{d.get('url','')}'>{d.get('content','')[:800]}</Doc>" |
|
for d in docs |
|
) |
|
except Exception as e: |
|
return f"Web search failed: {e}" |
|
|
|
@tool |
|
def optimized_wiki_search(query: str) -> str: |
|
""" |
|
Performs an optimized Wikipedia search and returns content snippets. |
|
|
|
Args: |
|
query (str): Wikipedia search query |
|
|
|
Returns: |
|
str: Wikipedia content with source attribution |
|
""" |
|
try: |
|
time.sleep(random.uniform(0.3, 1)) |
|
docs = WikipediaLoader(query=query, load_max_docs=2).load() |
|
return "\n\n---\n\n".join( |
|
f"<Doc src='{d.metadata.get('source','Wikipedia')}'>{d.page_content[:1000]}</Doc>" |
|
for d in docs |
|
) |
|
except Exception as e: |
|
return f"Wikipedia search failed: {e}" |
|
|
|
# ---- LLM Provider Integrations ---- |
|
try: |
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA |
|
NVIDIA_AVAILABLE = True |
|
except ImportError: |
|
NVIDIA_AVAILABLE = False |
|
|
|
try: |
|
import google.generativeai as genai |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
GOOGLE_AVAILABLE = True |
|
except ImportError: |
|
GOOGLE_AVAILABLE = False |
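# A minimal sketch, assuming valid API keys, of how these availability flags
# could gate optional provider construction (model ids are illustrative; the
# routing graph below currently uses Groq-hosted models only):
#   if NVIDIA_AVAILABLE:
#       nvidia_llm = ChatNVIDIA(model="meta/llama3-70b-instruct")
#   if GOOGLE_AVAILABLE:
#       gemini_llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash")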
|
|
|
# ---- Enhanced Agent State ---- |
|
class EnhancedAgentState(TypedDict): |
|
""" |
|
State structure for the enhanced multi-LLM agent system. |
|
|
|
Attributes: |
|
messages: List of conversation messages |
|
query: Current query string |
|
agent_type: Selected agent/LLM type |
|
final_answer: Generated response |
|
perf: Performance metrics |
|
agno_resp: Agno-style response metadata |
|
tools_used: List of tools used in processing |
|
reasoning: Step-by-step reasoning process |
|
""" |
|
messages: Annotated[List[HumanMessage | AIMessage], operator.add] |
|
query: str |
|
agent_type: str |
|
final_answer: str |
|
perf: Dict[str, Any] |
|
agno_resp: str |
|
tools_used: List[str] |
|
reasoning: str |
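# Fields annotated with operator.add are merged by concatenation when nodes
# return partial state, so LangGraph appends to "messages" rather than
# overwriting it. Illustrative sketch:
#   applying {"messages": [msg2]} to {"messages": [msg1]}
#   yields {"messages": [msg1, msg2]}; plain fields like "query" are replaced.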
|
|
|
# ---- Enhanced Multi-LLM System ---- |
|
class EnhancedQuestionAnsweringSystem: |
|
""" |
|
Advanced question-answering system that routes queries to appropriate LLM providers |
|
and uses tools to gather information for comprehensive answers. |
|
|
|
Features: |
|
- Multi-LLM routing (Groq, Google, NVIDIA) |
|
- Tool integration for web search and calculations |
|
- Structured reasoning and answer formatting |
|
- Performance monitoring |
|
""" |
|
|
|
def __init__(self): |
|
"""Initialize the enhanced question-answering system.""" |
|
self.tools = [ |
|
multiply, add, subtract, divide, modulus, |
|
optimized_web_search, optimized_wiki_search |
|
] |
|
self.graph = self._build_graph() |
|
|
|
def _llm(self, model_name: str) -> ChatGroq: |
|
""" |
|
Create a Groq LLM instance. |
|
|
|
Args: |
|
model_name (str): Model identifier |
|
|
|
Returns: |
|
ChatGroq: Configured Groq LLM instance |
|
""" |
|
return ChatGroq( |
|
model=model_name, |
|
temperature=0, |
|
api_key=os.getenv("GROQ_API_KEY") |
|
) |
|
|
|
def _build_graph(self) -> StateGraph: |
|
""" |
|
Build the LangGraph state machine with enhanced question-answering capabilities. |
|
|
|
Returns: |
|
StateGraph: Compiled graph with routing logic |
|
""" |
|
# Initialize LLMs |
|
llama8_llm = self._llm("llama3-8b-8192") |
|
llama70_llm = self._llm("llama3-70b-8192") |
|
        # "deepseek-chat" is a DeepSeek API model id, not a Groq one; the
        # DeepSeek model hosted on Groq is the R1 Llama distill.
        deepseek_llm = self._llm("deepseek-r1-distill-llama-70b")
|
|
|
def router(st: EnhancedAgentState) -> EnhancedAgentState: |
|
""" |
|
Route queries to appropriate LLM based on complexity and content. |
|
|
|
Args: |
|
st (EnhancedAgentState): Current state |
|
|
|
Returns: |
|
EnhancedAgentState: Updated state with agent selection |
|
""" |
|
q = st["query"].lower() |
|
|
|
# Route based on query characteristics |
|
if any(keyword in q for keyword in ["calculate", "compute", "math", "number"]): |
|
t = "llama70" # Use more powerful model for calculations |
|
elif any(keyword in q for keyword in ["search", "find", "lookup", "wikipedia"]): |
|
t = "search_enhanced" # Use search-enhanced processing |
|
elif "deepseek" in q or any(keyword in q for keyword in ["analyze", "reasoning", "complex"]): |
|
t = "deepseek" |
|
elif len(q.split()) > 20: # Complex queries |
|
t = "llama70" |
|
else: |
|
t = "llama8" # Default for simple queries |
|
|
|
return {**st, "agent_type": t, "tools_used": [], "reasoning": ""} |
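        # Routing examples (illustrative): "calculate 12 * 9" hits the math
        # keywords and goes to llama70; "find the capital of France on
        # wikipedia" goes to search_enhanced; a short greeting with no
        # keywords falls through to llama8.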
|
|
|
def llama8_node(st: EnhancedAgentState) -> EnhancedAgentState: |
|
"""Process query with Llama-3 8B model.""" |
|
t0 = time.time() |
|
try: |
|
sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT) |
|
res = llama8_llm.invoke([sys, HumanMessage(content=st["query"])]) |
|
|
|
reasoning = "Used Llama-3 8B for efficient processing of straightforward query." |
|
|
|
return {**st, |
|
"final_answer": res.content, |
|
"reasoning": reasoning, |
|
"perf": {"time": time.time() - t0, "prov": "Groq-Llama3-8B"}} |
|
except Exception as e: |
|
return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}} |
|
|
|
def llama70_node(st: EnhancedAgentState) -> EnhancedAgentState: |
|
"""Process query with Llama-3 70B model.""" |
|
t0 = time.time() |
|
try: |
|
sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT) |
|
res = llama70_llm.invoke([sys, HumanMessage(content=st["query"])]) |
|
|
|
reasoning = "Used Llama-3 70B for complex reasoning and detailed analysis." |
|
|
|
return {**st, |
|
"final_answer": res.content, |
|
"reasoning": reasoning, |
|
"perf": {"time": time.time() - t0, "prov": "Groq-Llama3-70B"}} |
|
except Exception as e: |
|
return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}} |
|
|
|
def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState: |
|
"""Process query with DeepSeek model.""" |
|
t0 = time.time() |
|
try: |
|
sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT) |
|
res = deepseek_llm.invoke([sys, HumanMessage(content=st["query"])]) |
|
|
|
reasoning = "Used DeepSeek for advanced reasoning and analytical tasks." |
|
|
|
return {**st, |
|
"final_answer": res.content, |
|
"reasoning": reasoning, |
|
"perf": {"time": time.time() - t0, "prov": "Groq-DeepSeek"}} |
|
except Exception as e: |
|
return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}} |
|
|
|
def search_enhanced_node(st: EnhancedAgentState) -> EnhancedAgentState: |
|
"""Process query with search enhancement.""" |
|
t0 = time.time() |
|
tools_used = [] |
|
reasoning_steps = [] |
|
|
|
try: |
|
# Determine if we need web search or Wikipedia |
|
query = st["query"] |
|
search_results = "" |
|
|
|
if any(keyword in query.lower() for keyword in ["wikipedia", "wiki"]): |
|
search_results = optimized_wiki_search.invoke({"query": query}) |
|
tools_used.append("wikipedia_search") |
|
reasoning_steps.append("Searched Wikipedia for relevant information") |
|
else: |
|
search_results = optimized_web_search.invoke({"query": query}) |
|
tools_used.append("web_search") |
|
reasoning_steps.append("Performed web search for current information") |
|
|
|
# Enhance query with search results |
|
enhanced_query = f""" |
|
Original Query: {query} |
|
|
|
Search Results: |
|
{search_results} |
|
|
|
Based on the search results above, please provide a comprehensive answer to the original query. |
|
""" |
|
|
|
sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT) |
|
res = llama70_llm.invoke([sys, HumanMessage(content=enhanced_query)]) |
|
|
|
reasoning_steps.append("Used Llama-3 70B to analyze search results and generate comprehensive answer") |
|
reasoning = " -> ".join(reasoning_steps) |
|
|
|
return {**st, |
|
"final_answer": res.content, |
|
"tools_used": tools_used, |
|
"reasoning": reasoning, |
|
"perf": {"time": time.time() - t0, "prov": "Search-Enhanced-Llama70"}} |
|
except Exception as e: |
|
return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}} |
|
|
|
# Build graph |
|
g = StateGraph(EnhancedAgentState) |
|
g.add_node("router", router) |
|
g.add_node("llama8", llama8_node) |
|
g.add_node("llama70", llama70_node) |
|
g.add_node("deepseek", deepseek_node) |
|
g.add_node("search_enhanced", search_enhanced_node) |
|
|
|
g.set_entry_point("router") |
|
g.add_conditional_edges("router", lambda s: s["agent_type"], { |
|
"llama8": "llama8", |
|
"llama70": "llama70", |
|
"deepseek": "deepseek", |
|
"search_enhanced": "search_enhanced" |
|
}) |
|
|
|
for node in ["llama8", "llama70", "deepseek", "search_enhanced"]: |
|
g.add_edge(node, END) |
|
|
|
return g.compile(checkpointer=MemorySaver()) |
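    # Note: compiling with a MemorySaver checkpointer means every graph.invoke()
    # call must pass a config carrying a thread id, e.g.
    # {"configurable": {"thread_id": "qa_1"}}, as process_query below supplies.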
|
|
|
def process_query(self, q: str) -> str: |
|
""" |
|
Process a query through the enhanced question-answering system. |
|
|
|
Args: |
|
q (str): Input query |
|
|
|
Returns: |
|
str: Generated response with proper formatting |
|
""" |
|
state = { |
|
"messages": [HumanMessage(content=q)], |
|
"query": q, |
|
"agent_type": "", |
|
"final_answer": "", |
|
"perf": {}, |
|
"agno_resp": "", |
|
"tools_used": [], |
|
"reasoning": "" |
|
} |
|
cfg = {"configurable": {"thread_id": f"qa_{hash(q)}"}} |
|
|
|
try: |
|
out = self.graph.invoke(state, cfg) |
|
answer = out.get("final_answer", "").strip() |
|
|
|
# Ensure proper formatting |
|
            if not answer.startswith("FINAL ANSWER:"):
                # Extract the actual answer if it's buried in explanation
                if "FINAL ANSWER:" in answer:
                    answer = answer.split("FINAL ANSWER:")[-1].strip()
                answer = f"FINAL ANSWER: {answer}"

            return answer
|
except Exception as e: |
|
return f"FINAL ANSWER: Error processing query: {e}" |
|
|
|
def build_graph(provider: str | None = None) -> StateGraph: |
|
""" |
|
Build and return the graph for the enhanced question-answering system. |
|
|
|
Args: |
|
        provider (str | None): Provider preference (currently ignored;
            accepted for API compatibility)
|
|
|
Returns: |
|
StateGraph: Compiled graph instance |
|
""" |
|
return EnhancedQuestionAnsweringSystem().graph |
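# A hedged sketch of driving the compiled graph directly rather than through
# QuestionAnsweringAgent; the initial-state keys mirror EnhancedAgentState:
#   graph = build_graph()
#   out = graph.invoke(
#       {"messages": [], "query": "What is 2 + 2?", "agent_type": "",
#        "final_answer": "", "perf": {}, "agno_resp": "",
#        "tools_used": [], "reasoning": ""},
#       {"configurable": {"thread_id": "demo"}},
#   )
#   print(out["final_answer"])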
|
|
|
|
|
class QuestionAnsweringAgent: |
|
""" |
|
Main interface for the question-answering agent system. |
|
""" |
|
|
|
def __init__(self): |
|
"""Initialize the question-answering agent.""" |
|
self.system = EnhancedQuestionAnsweringSystem() |
|
|
|
def answer_question(self, question: str) -> str: |
|
""" |
|
Answer a question using the enhanced multi-LLM system. |
|
|
|
Args: |
|
question (str): The question to answer |
|
|
|
Returns: |
|
str: Formatted answer with FINAL ANSWER prefix |
|
""" |
|
return self.system.process_query(question) |
|
|
|
if __name__ == "__main__": |
|
|
|
qa_agent = QuestionAnsweringAgent() |
|
|
|
|
|
test_questions = [ |
|
"How many studio albums were published by Mercedes Sosa between 2000 and 2009?", |
|
"What is 25 multiplied by 17?", |
|
"Find information about the capital of France on Wikipedia", |
|
"What is the population of Tokyo according to recent data?" |
|
] |
|
|
|
print("=" * 80) |
|
print("Enhanced Question-Answering Agent System") |
|
print("=" * 80) |
|
|
|
for i, question in enumerate(test_questions, 1): |
|
print(f"\nQuestion {i}: {question}") |
|
print("-" * 60) |
|
answer = qa_agent.answer_question(question) |
|
print(answer) |
|
print() |
|
|