# chatbot-backend / src / agents / rag_agent.py
# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict
import uuid
from .excel_aware_rag import ExcelAwareRAGAgent
from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.conversation_manager import ConversationManager
from src.db.mongodb_store import MongoDBStore
from src.models.rag import RAGResponse
from src.utils.logger import logger
class RAGAgent(ExcelAwareRAGAgent):
    """Retrieval-augmented generation agent with conversation history.

    Combines vector-store retrieval, token-bounded conversation history
    (via ConversationManager), and Excel-aware context post-processing
    (inherited from ExcelAwareRAGAgent) to answer user queries.
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """
        Initialize RAG Agent.

        Args:
            llm (BaseLLM): Language model instance
            embedding (BaseEmbedding): Embedding model instance
            vector_store (BaseVectorStore): Vector store instance
            mongodb (MongoDBStore): MongoDB store instance
            max_history_tokens (int): Maximum tokens in conversation history
            max_history_messages (int): Maximum messages to keep in history
        """
        super().__init__()  # Initialize ExcelAwareRAGAgent
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
        self.mongodb = mongodb
        # Trims stored history down to the configured token/message budget.
        self.conversation_manager = ConversationManager(
            max_tokens=max_history_tokens,
            max_messages=max_history_messages
        )

    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """Generate a response using RAG with conversation history.

        Args:
            query (str): User query to answer
            conversation_id (Optional[str]): Existing conversation to
                continue; a new one is created and persisted when omitted
            temperature (float): LLM sampling temperature
            max_tokens (Optional[int]): Cap on generated tokens, if any
            context_docs (Optional[List[str]]): Pre-fetched context documents;
                when omitted, context is retrieved from the vector store

        Returns:
            RAGResponse: Generated response plus the context docs, sources and
                scores used. `sources`/`scores` are None when the caller
                supplied `context_docs`.

        Raises:
            Exception: Any failure is logged (with traceback) and re-raised.
        """
        try:
            # Create and persist a new conversation if no ID was provided
            if not conversation_id:
                conversation_id = str(uuid.uuid4())
                await self.mongodb.create_conversation(conversation_id)

            # Fetch the most recent stored messages for this conversation
            history = await self.mongodb.get_recent_messages(
                conversation_id,
                limit=self.conversation_manager.max_messages
            )

            # Keep only as much history as fits within the token limits
            relevant_history = self.conversation_manager.get_relevant_history(
                messages=history,
                current_query=query
            ) if history else []

            # Retrieve context unless the caller supplied it; sources and
            # scores are only known for retrieved (not caller-provided) docs
            if not context_docs:
                context_docs, sources, scores = await self.retrieve_context(
                    query,
                    conversation_history=relevant_history
                )
            else:
                sources = None
                scores = None

            # Excel-derived chunks contain a 'Sheet:' marker; give them
            # specialised handling when present
            has_excel_content = any('Sheet:' in doc for doc in (context_docs or []))
            if has_excel_content:
                try:
                    context_docs = self._process_excel_context(context_docs, query)
                except Exception as e:
                    # Best-effort: continue with the unprocessed context
                    logger.warning("Error processing Excel context: %s", e)

            # Build the final prompt from context and trimmed history
            augmented_prompt = self.conversation_manager.generate_prompt_with_history(
                current_query=query,
                history=relevant_history,
                context_docs=context_docs
            )

            # Generate the initial response using the LLM.
            # NOTE(review): generate() is called without await inside an async
            # method — confirm BaseLLM.generate is synchronous.
            response = self.llm.generate(
                augmented_prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Optionally refine the answer for Excel-specific queries
            if has_excel_content:
                try:
                    response = await self.enhance_excel_response(
                        query=query,
                        response=response,
                        context_docs=context_docs
                    )
                except Exception as e:
                    # Best-effort: keep the unenhanced response
                    logger.warning("Error enhancing Excel response: %s", e)

            return RAGResponse(
                response=response,
                context_docs=context_docs,
                sources=sources,
                scores=scores
            )

        except Exception:
            # logger.exception records the full traceback with the message
            logger.exception("Error generating response")
            raise

    async def retrieve_context(
        self,
        query: str,
        conversation_history: Optional[List[Dict]] = None,
        top_k: int = 3
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """
        Retrieve context with conversation history enhancement.

        Args:
            query (str): Current query
            conversation_history (Optional[List[Dict]]): Recent conversation history
            top_k (int): Number of documents to retrieve

        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]:
                Retrieved documents, sources, and scores. Scores is None
                unless every retrieved document carries a score.
        """
        # Prepend the last two user queries so follow-up questions are
        # embedded together with their conversational context
        if conversation_history:
            recent_queries = [
                msg['query'] for msg in conversation_history[-2:]
                if msg.get('query')
            ]
            enhanced_query = " ".join([*recent_queries, query])
        else:
            enhanced_query = query

        # Embed the enhanced query and search the vector store
        query_embedding = self.embedding.embed_query(enhanced_query)
        results = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k
        )

        # Process results into parallel lists
        documents = [doc['text'] for doc in results]
        sources = [self._convert_metadata_to_strings(doc['metadata'])
                   for doc in results]
        scores = [doc['score'] for doc in results
                  if doc.get('score') is not None]

        # Scores are only meaningful if every document has one; otherwise
        # the filtered list would silently misalign with `documents`
        if len(scores) != len(documents):
            scores = None

        return documents, sources, scores

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Return a copy of `metadata` with numeric values converted to strings.

        Note: bool is a subclass of int, so boolean values are stringified too
        (matches the original isinstance check).
        """
        return {
            key: str(value) if isinstance(value, (int, float)) else value
            for key, value in metadata.items()
        }