|
""" |
|
Enhanced Multi-LLM Agent System with Supabase FAISS Integration |
|
Complete system for document insertion, retrieval, and question answering |
|
""" |
|
|
|
import os |
|
import time |
|
import random |
|
import operator |
|
from typing import List, Dict, Any, TypedDict, Annotated, Optional |
|
from dotenv import load_dotenv |
|
|
|
from langchain_core.tools import tool |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_community.document_loaders import WikipediaLoader |
|
from langgraph.graph import StateGraph, END |
|
from langgraph.checkpoint.memory import MemorySaver |
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage |
|
from langchain_groq import ChatGroq |
|
|
|
|
|
import faiss |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from supabase import create_client, Client |
|
import pandas as pd |
|
import json |
|
import pickle |
|
|
|
load_dotenv() |
|
|
|
|
|
ENHANCED_SYSTEM_PROMPT = ( |
|
"You are a helpful assistant tasked with answering questions using a set of tools. " |
|
"You must provide accurate, comprehensive answers based on available information. " |
|
"When answering questions, follow these guidelines:\n" |
|
"1. Use available tools to gather information when needed\n" |
|
"2. Provide precise, factual answers\n" |
|
"3. For numbers: don't use commas or units unless specified\n" |
|
"4. For strings: don't use articles or abbreviations, write digits in plain text\n" |
|
"5. For lists: apply above rules based on element type\n" |
|
"6. Always end with 'FINAL ANSWER: [YOUR ANSWER]'\n" |
|
"7. Be concise but thorough in your reasoning\n" |
|
"8. If you cannot find the answer, state that clearly" |
|
) |
|
|
|
|
|
@tool |
|
def multiply(a: int, b: int) -> int: |
|
"""Multiply two integers and return the product.""" |
|
return a * b |
|
|
|
@tool |
|
def add(a: int, b: int) -> int: |
|
"""Add two integers and return the sum.""" |
|
return a + b |
|
|
|
@tool |
|
def subtract(a: int, b: int) -> int: |
|
"""Subtract the second integer from the first and return the difference.""" |
|
return a - b |
|
|
|
@tool |
|
def divide(a: int, b: int) -> float: |
|
"""Divide the first integer by the second and return the quotient.""" |
|
if b == 0: |
|
raise ValueError("Cannot divide by zero.") |
|
return a / b |
|
|
|
@tool |
|
def modulus(a: int, b: int) -> int: |
|
"""Return the remainder when dividing the first integer by the second.""" |
|
return a % b |
|
|
|
@tool |
|
def optimized_web_search(query: str) -> str: |
|
"""Perform an optimized web search using TavilySearchResults.""" |
|
try: |
|
time.sleep(random.uniform(0.7, 1.5)) |
|
search_tool = TavilySearchResults(max_results=3) |
|
docs = search_tool.invoke({"query": query}) |
|
return "\n\n---\n\n".join( |
|
f"<Doc url='{d.get('url','')}'>{d.get('content','')[:800]}</Doc>" |
|
for d in docs |
|
) |
|
except Exception as e: |
|
return f"Web search failed: {e}" |
|
|
|
@tool |
|
def optimized_wiki_search(query: str) -> str: |
|
"""Perform an optimized Wikipedia search and return content snippets.""" |
|
try: |
|
time.sleep(random.uniform(0.3, 1)) |
|
docs = WikipediaLoader(query=query, load_max_docs=2).load() |
|
return "\n\n---\n\n".join( |
|
f"<Doc src='{d.metadata.get('source','Wikipedia')}'>{d.page_content[:1000]}</Doc>" |
|
for d in docs |
|
) |
|
except Exception as e: |
|
return f"Wikipedia search failed: {e}" |
|
|
|
|
|
class SupabaseFAISSVectorDB: |
|
"""Enhanced vector database combining FAISS with Supabase for persistent storage""" |
|
|
|
def __init__(self): |
|
|
|
self.supabase_url = os.getenv("SUPABASE_URL") |
|
self.supabase_key = os.getenv("SUPABASE_SERVICE_KEY") |
|
if self.supabase_url and self.supabase_key: |
|
self.supabase: Client = create_client(self.supabase_url, self.supabase_key) |
|
else: |
|
self.supabase = None |
|
print("Supabase credentials not found, running without vector database") |
|
|
|
|
|
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') |
|
self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension() |
|
|
|
|
|
self.index = faiss.IndexFlatL2(self.embedding_dim) |
|
self.document_store = [] |
|
|
|
def insert_question_data(self, data: Dict[str, Any]) -> bool: |
|
"""Insert question data into both Supabase and FAISS""" |
|
try: |
|
question_text = data.get("Question", "") |
|
embedding = self.embedding_model.encode([question_text])[0] |
|
|
|
|
|
if self.supabase: |
|
question_data = { |
|
"task_id": data.get("task_id"), |
|
"question": question_text, |
|
"final_answer": data.get("Final answer"), |
|
"level": data.get("Level"), |
|
"file_name": data.get("file_name", ""), |
|
"embedding": embedding.tolist() |
|
} |
|
self.supabase.table("questions").insert(question_data).execute() |
|
|
|
|
|
self.index.add(embedding.reshape(1, -1).astype('float32')) |
|
self.document_store.append({ |
|
"task_id": data.get("task_id"), |
|
"question": question_text, |
|
"answer": data.get("Final answer"), |
|
"level": data.get("Level") |
|
}) |
|
|
|
return True |
|
except Exception as e: |
|
print(f"Error inserting data: {e}") |
|
return False |
|
|
|
def search_similar_questions(self, query: str, k: int = 3) -> List[Dict[str, Any]]: |
|
"""Search for similar questions using vector similarity""" |
|
try: |
|
if self.index.ntotal == 0: |
|
return [] |
|
|
|
query_embedding = self.embedding_model.encode([query])[0] |
|
k = min(k, self.index.ntotal) |
|
distances, indices = self.index.search( |
|
query_embedding.reshape(1, -1).astype('float32'), k |
|
) |
|
|
|
results = [] |
|
for i, idx in enumerate(indices[0]): |
|
if 0 <= idx < len(self.document_store): |
|
doc = self.document_store[idx] |
|
results.append({ |
|
"task_id": doc["task_id"], |
|
"question": doc["question"], |
|
"answer": doc["answer"], |
|
"similarity_score": 1 / (1 + distances[0][i]), |
|
"distance": float(distances[0][i]) |
|
}) |
|
|
|
return results |
|
except Exception as e: |
|
print(f"Error searching similar questions: {e}") |
|
return [] |
|
|
|
|
|
class EnhancedAgentState(TypedDict): |
|
"""State structure for the enhanced multi-LLM agent system.""" |
|
messages: Annotated[List[HumanMessage | AIMessage], operator.add] |
|
query: str |
|
agent_type: str |
|
final_answer: str |
|
perf: Dict[str, Any] |
|
agno_resp: str |
|
tools_used: List[str] |
|
reasoning: str |
|
similar_questions: List[Dict[str, Any]] |
|
|
|
|
|
class HybridLangGraphMultiLLMSystem: |
|
""" |
|
Advanced question-answering system with multi-LLM support and vector database integration |
|
""" |
|
|
|
def __init__(self, provider="groq"): |
|
self.provider = provider |
|
self.tools = [ |
|
multiply, add, subtract, divide, modulus, |
|
optimized_web_search, optimized_wiki_search |
|
] |
|
|
|
|
|
self.vector_db = SupabaseFAISSVectorDB() |
|
|
|
self.graph = self._build_graph() |
|
|
|
def _llm(self, model_name: str) -> ChatGroq: |
|
"""Create a Groq LLM instance.""" |
|
return ChatGroq( |
|
model=model_name, |
|
temperature=0, |
|
api_key=os.getenv("GROQ_API_KEY") |
|
) |
|
|
|
def _build_graph(self) -> StateGraph: |
|
"""Build the LangGraph state machine with enhanced capabilities.""" |
|
|
|
llama8_llm = self._llm("llama3-8b-8192") |
|
llama70_llm = self._llm("llama3-70b-8192") |
|
deepseek_llm = self._llm("deepseek-chat") |
|
|
|
def router(st: EnhancedAgentState) -> EnhancedAgentState: |
|
"""Route queries to appropriate LLM based on complexity and content analysis.""" |
|
q = st["query"].lower() |
|
|
|
|
|
if any(keyword in q for keyword in ["calculate", "compute", "math", "multiply", "add", "subtract", "divide"]): |
|
t = "llama70" |
|
elif any(keyword in q for keyword in ["search", "find", "lookup", "wikipedia", "information about"]): |
|
t = "search_enhanced" |
|
elif "deepseek" in q or any(keyword in q for keyword in ["analyze", "reasoning", "complex"]): |
|
t = "deepseek" |
|
elif "llama-8" in q: |
|
t = "llama8" |
|
elif len(q.split()) > 20: |
|
t = "llama70" |
|
else: |
|
t = "llama8" |
|
|
|
|
|
similar_questions = self.vector_db.search_similar_questions(st["query"], k=3) |
|
|
|
return {**st, "agent_type": t, "tools_used": [], "reasoning": "", "similar_questions": similar_questions} |
|
|
|
def llama8_node(st: EnhancedAgentState) -> EnhancedAgentState: |
|
"""Process query with Llama-3 8B model.""" |
|
t0 = time.time() |
|
try: |
|
|
|
context = "" |
|
if st.get("similar_questions"): |
|
context = "\n\nSimilar questions for reference:\n" |
|
for sq in st["similar_questions"][:2]: |
|
context += f"Q: {sq['question']}\nA: {sq['answer']}\n" |
|
|
|
enhanced_query = f""" |
|
Question: {st["query"]} |
|
{context} |
|
Please provide a direct, accurate answer to this question. |
|
""" |
|
|
|
sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT) |
|
res = llama8_llm.invoke([sys, HumanMessage(content=enhanced_query)]) |
|
|
|
answer = res.content.strip() |
|
if "FINAL ANSWER:" in answer: |
|
answer = answer.split("FINAL ANSWER:")[-1].strip() |
|
|
|
return {**st, |
|
"final_answer": answer, |
|
"reasoning": "Used Llama-3 8B with similar questions context", |
|
"perf": {"time": time.time() - t0, "prov": "Groq-Llama3-8B"}} |
|
except Exception as e: |
|
return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}} |
|
|
|
def llama70_node(st: EnhancedAgentState) -> EnhancedAgentState: |
|
"""Process query with Llama-3 70B model.""" |
|
t0 = time.time() |
|
try: |
|
|
|
context = "" |
|
if st.get("similar_questions"): |
|
context = "\n\nSimilar questions for reference:\n" |
|
for sq in st["similar_questions"][:2]: |
|
context += f"Q: {sq['question']}\nA: {sq['answer']}\n" |
|
|
|
enhanced_query = f""" |
|
Question: {st["query"]} |
|
{context} |
|
Please provide a direct, accurate answer to this question. |
|
""" |
|
|
|
sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT) |
|
res = llama70_llm.invoke([sys, HumanMessage(content=enhanced_query)]) |
|
|
|
answer = res.content.strip() |
|
if "FINAL ANSWER:" in answer: |
|
answer = answer.split("FINAL ANSWER:")[-1].strip() |
|
|
|
return {**st, |
|
"final_answer": answer, |
|
"reasoning": "Used Llama-3 70B for complex reasoning with context", |
|
"perf": {"time": time.time() - t0, "prov": "Groq-Llama3-70B"}} |
|
except Exception as e: |
|
return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}} |
|
|
|
def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState: |
|
"""Process query with DeepSeek model.""" |
|
t0 = time.time() |
|
try: |
|
|
|
context = "" |
|
if st.get("similar_questions"): |
|
context = "\n\nSimilar questions for reference:\n" |
|
for sq in st["similar_questions"][:2]: |
|
context += f"Q: {sq['question']}\nA: {sq['answer']}\n" |
|
|
|
enhanced_query = f""" |
|
Question: {st["query"]} |
|
{context} |
|
Please provide a direct, accurate answer to this question. |
|
""" |
|
|
|
sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT) |
|
res = deepseek_llm.invoke([sys, HumanMessage(content=enhanced_query)]) |
|
|
|
answer = res.content.strip() |
|
if "FINAL ANSWER:" in answer: |
|
answer = answer.split("FINAL ANSWER:")[-1].strip() |
|
|
|
return {**st, |
|
"final_answer": answer, |
|
"reasoning": "Used DeepSeek for advanced reasoning and analysis", |
|
"perf": {"time": time.time() - t0, "prov": "Groq-DeepSeek"}} |
|
except Exception as e: |
|
return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}} |
|
|
|
def search_enhanced_node(st: EnhancedAgentState) -> EnhancedAgentState: |
|
"""Process query with search enhancement.""" |
|
t0 = time.time() |
|
tools_used = [] |
|
|
|
try: |
|
|
|
query = st["query"] |
|
search_results = "" |
|
|
|
if any(keyword in query.lower() for keyword in ["wikipedia", "wiki"]): |
|
search_results = optimized_wiki_search.invoke({"query": query}) |
|
tools_used.append("wikipedia_search") |
|
else: |
|
search_results = optimized_web_search.invoke({"query": query}) |
|
tools_used.append("web_search") |
|
|
|
|
|
context = "" |
|
if st.get("similar_questions"): |
|
context = "\n\nSimilar questions for reference:\n" |
|
for sq in st["similar_questions"][:2]: |
|
context += f"Q: {sq['question']}\nA: {sq['answer']}\n" |
|
|
|
enhanced_query = f""" |
|
Original Question: {query} |
|
|
|
Search Results: |
|
{search_results} |
|
{context} |
|
|
|
Based on the search results and similar questions above, provide a direct answer to the original question. |
|
""" |
|
|
|
sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT) |
|
res = llama70_llm.invoke([sys, HumanMessage(content=enhanced_query)]) |
|
|
|
answer = res.content.strip() |
|
if "FINAL ANSWER:" in answer: |
|
answer = answer.split("FINAL ANSWER:")[-1].strip() |
|
|
|
return {**st, |
|
"final_answer": answer, |
|
"tools_used": tools_used, |
|
"reasoning": "Used search enhancement with similar questions context", |
|
"perf": {"time": time.time() - t0, "prov": "Search-Enhanced-Llama70"}} |
|
except Exception as e: |
|
return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}} |
|
|
|
|
|
g = StateGraph(EnhancedAgentState) |
|
g.add_node("router", router) |
|
g.add_node("llama8", llama8_node) |
|
g.add_node("llama70", llama70_node) |
|
g.add_node("deepseek", deepseek_node) |
|
g.add_node("search_enhanced", search_enhanced_node) |
|
|
|
g.set_entry_point("router") |
|
g.add_conditional_edges("router", lambda s: s["agent_type"], { |
|
"llama8": "llama8", |
|
"llama70": "llama70", |
|
"deepseek": "deepseek", |
|
"search_enhanced": "search_enhanced" |
|
}) |
|
|
|
for node in ["llama8", "llama70", "deepseek", "search_enhanced"]: |
|
g.add_edge(node, END) |
|
|
|
return g.compile(checkpointer=MemorySaver()) |
|
|
|
def process_query(self, q: str) -> str: |
|
"""Process a query through the enhanced multi-LLM system.""" |
|
state = { |
|
"messages": [HumanMessage(content=q)], |
|
"query": q, |
|
"agent_type": "", |
|
"final_answer": "", |
|
"perf": {}, |
|
"agno_resp": "", |
|
"tools_used": [], |
|
"reasoning": "", |
|
"similar_questions": [] |
|
} |
|
cfg = {"configurable": {"thread_id": f"enhanced_qa_{hash(q)}"}} |
|
|
|
try: |
|
out = self.graph.invoke(state, cfg) |
|
answer = out.get("final_answer", "").strip() |
|
|
|
|
|
if answer == q or answer.startswith(q): |
|
return "Information not available" |
|
|
|
return answer if answer else "No answer generated" |
|
except Exception as e: |
|
return f"Error processing query: {e}" |
|
|
|
def load_metadata_from_jsonl(self, jsonl_file_path: str) -> int: |
|
"""Load question metadata from JSONL file into vector database""" |
|
success_count = 0 |
|
|
|
try: |
|
with open(jsonl_file_path, 'r', encoding='utf-8') as file: |
|
for line_num, line in enumerate(file, 1): |
|
try: |
|
data = json.loads(line.strip()) |
|
if self.vector_db.insert_question_data(data): |
|
success_count += 1 |
|
|
|
if line_num % 10 == 0: |
|
print(f"Processed {line_num} records, {success_count} successful") |
|
|
|
except json.JSONDecodeError as e: |
|
print(f"JSON decode error on line {line_num}: {e}") |
|
except Exception as e: |
|
print(f"Error processing line {line_num}: {e}") |
|
|
|
except FileNotFoundError: |
|
print(f"File not found: {jsonl_file_path}") |
|
|
|
print(f"Loaded {success_count} questions into vector database") |
|
return success_count |
|
|
|
def build_graph(provider: str | None = None) -> StateGraph: |
|
"""Build and return the graph for the enhanced agent system.""" |
|
return HybridLangGraphMultiLLMSystem(provider or "groq").graph |
|
|
|
if __name__ == "__main__": |
|
|
|
system = HybridLangGraphMultiLLMSystem() |
|
|
|
|
|
if os.path.exists("metadata.jsonl"): |
|
system.load_metadata_from_jsonl("metadata.jsonl") |
|
|
|
|
|
test_questions = [ |
|
"How many studio albums were published by Mercedes Sosa between 2000 and 2009?", |
|
"What is 25 multiplied by 17?", |
|
"Find information about artificial intelligence on Wikipedia" |
|
] |
|
|
|
for question in test_questions: |
|
print(f"Question: {question}") |
|
answer = system.process_query(question) |
|
print(f"Answer: {answer}") |
|
print("-" * 50) |
|
|