# src/agents/rag_agent.py
import uuid
from typing import Dict, List, Optional, Tuple

from src.db.mongodb_store import MongoDBStore
from src.embeddings.base_embedding import BaseEmbedding
from src.models.rag import RAGResponse
from src.utils.conversation_manager import ConversationManager
from src.utils.logger import logger
from src.vectorstores.base_vectorstore import BaseVectorStore

from .excel_aware_rag import ExcelAwareRAGAgent
from ..llms.base_llm import BaseLLM
class RAGAgent(ExcelAwareRAGAgent):
    """RAG agent that answers queries from retrieved context plus conversation history.

    Orchestrates: MongoDB-backed conversation storage, token-bounded history
    selection (ConversationManager), embedding-based retrieval from a vector
    store, and LLM generation. Excel-specific context processing and response
    enhancement are inherited from ExcelAwareRAGAgent and applied best-effort
    (failures there are logged and the original context/response is kept).
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """
        Initialize RAG Agent.

        Args:
            llm (BaseLLM): Language model instance
            embedding (BaseEmbedding): Embedding model instance
            vector_store (BaseVectorStore): Vector store instance
            mongodb (MongoDBStore): MongoDB store instance
            max_history_tokens (int): Maximum tokens in conversation history
            max_history_messages (int): Maximum messages to keep in history
        """
        super().__init__()  # Initialize ExcelAwareRAGAgent
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
        self.mongodb = mongodb
        self.conversation_manager = ConversationManager(
            max_tokens=max_history_tokens,
            max_messages=max_history_messages
        )

    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """Generate a response using RAG with conversation history.

        Args:
            query (str): The user's current query.
            conversation_id (Optional[str]): Existing conversation to continue;
                a new conversation is created in MongoDB when omitted.
            temperature (float): LLM sampling temperature.
            max_tokens (Optional[int]): LLM completion token cap.
            context_docs (Optional[List[str]]): Pre-retrieved context. When
                provided (non-empty), retrieval is skipped and the returned
                sources/scores are None.

        Returns:
            RAGResponse: response text plus the context, sources and scores used.

        Raises:
            Exception: re-raises any failure after logging it.
        """
        try:
            # Create new conversation if no ID provided
            if not conversation_id:
                conversation_id = str(uuid.uuid4())
                await self.mongodb.create_conversation(conversation_id)

            # Get conversation history
            history = await self.mongodb.get_recent_messages(
                conversation_id,
                limit=self.conversation_manager.max_messages
            )

            # Trim history to what fits within the token budget
            relevant_history = self.conversation_manager.get_relevant_history(
                messages=history,
                current_query=query
            ) if history else []

            # Retrieve context if not provided (an empty list also triggers
            # retrieval since it is falsy — callers passing [] get retrieval)
            if not context_docs:
                context_docs, sources, scores = await self.retrieve_context(
                    query,
                    conversation_history=relevant_history
                )
            else:
                sources = None
                scores = None

            # Heuristic: docs produced from spreadsheets carry a 'Sheet:' marker
            has_excel_content = any('Sheet:' in doc for doc in (context_docs or []))
            if has_excel_content:
                try:
                    context_docs = self._process_excel_context(context_docs, query)
                except Exception as e:
                    logger.warning(f"Error processing Excel context: {str(e)}")
                    # Continue with original context if Excel processing fails

            # Generate prompt with context and history
            augmented_prompt = self.conversation_manager.generate_prompt_with_history(
                current_query=query,
                history=relevant_history,
                context_docs=context_docs
            )

            # Generate initial response using LLM.
            # NOTE(review): not awaited — assumes BaseLLM.generate is synchronous;
            # confirm against BaseLLM, since an un-awaited coroutine here would
            # propagate into RAGResponse.response.
            response = self.llm.generate(
                augmented_prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Enhance response for Excel queries if applicable
            if has_excel_content:
                try:
                    response = await self.enhance_excel_response(
                        query=query,
                        response=response,
                        context_docs=context_docs
                    )
                except Exception as e:
                    logger.warning(f"Error enhancing Excel response: {str(e)}")
                    # Continue with original response if enhancement fails

            return RAGResponse(
                response=response,
                context_docs=context_docs,
                sources=sources,
                scores=scores
            )

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise

    async def retrieve_context(
        self,
        query: str,
        conversation_history: Optional[List[Dict]] = None,
        top_k: int = 3
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """
        Retrieve context with conversation history enhancement.

        Args:
            query (str): Current query
            conversation_history (Optional[List[Dict]]): Recent conversation history
            top_k (int): Number of documents to retrieve

        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]:
                Retrieved documents, sources, and scores (scores is None unless
                every retrieved document carries one)
        """
        # Prepend the last two user queries so retrieval sees follow-up context
        if conversation_history:
            recent_queries = [
                msg['query'] for msg in conversation_history[-2:]
                if msg.get('query')
            ]
            enhanced_query = " ".join([*recent_queries, query])
        else:
            enhanced_query = query

        # Embed the enhanced query
        query_embedding = self.embedding.embed_query(enhanced_query)

        # Retrieve similar documents
        results = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k
        )

        # Process results
        documents = [doc['text'] for doc in results]
        sources = [self._convert_metadata_to_strings(doc['metadata'])
                   for doc in results]
        scores = [doc['score'] for doc in results
                  if doc.get('score') is not None]

        # Return scores only if available for all documents
        if len(scores) != len(documents):
            scores = None

        return documents, sources, scores

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Return a copy of *metadata* with int/float values stringified.

        Non-numeric values (including bools, which are ints in Python and thus
        also stringified) pass through unchanged.
        """
        converted = {}
        for key, value in metadata.items():
            if isinstance(value, (int, float)):
                converted[key] = str(value)
            else:
                converted[key] = value
        return converted