chatbot-backend / src /agents /rag_agent.py
TalatMasood's picture
Update chatbot with deployment configurations on Render
415595f
raw
history blame
9.71 kB
# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict
import uuid
from .excel_aware_rag import ExcelAwareRAGAgent
from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.conversation_manager import ConversationManager
from src.db.mongodb_store import MongoDBStore
from src.models.rag import RAGResponse
from src.utils.logger import logger
class RAGAgent(ExcelAwareRAGAgent):
    """
    Retrieval-Augmented Generation agent.

    Answers user queries by combining vector-store retrieval, token-bounded
    conversation history and an LLM. Context documents that originate from
    Excel files (identified by a 'Sheet:' marker in their text) get extra
    best-effort pre- and post-processing via the ExcelAwareRAGAgent base.
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """
        Initialize RAG Agent.

        Args:
            llm (BaseLLM): Language model instance
            embedding (BaseEmbedding): Embedding model instance
            vector_store (BaseVectorStore): Vector store instance
            mongodb (MongoDBStore): MongoDB store instance
            max_history_tokens (int): Maximum tokens in conversation history
            max_history_messages (int): Maximum messages to keep in history
        """
        super().__init__()  # Initialize ExcelAwareRAGAgent
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
        self.mongodb = mongodb
        self.conversation_manager = ConversationManager(
            max_tokens=max_history_tokens,
            max_messages=max_history_messages
        )

    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str],
        temperature: float,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """
        Generate a response for ``query`` with handling for different query
        types: templated introduction messages, queries with no relevant
        context, and Excel-derived context.

        Args:
            query (str): User query (may be a templated introduction message)
            conversation_id (Optional[str]): Conversation whose history to load
            temperature (float): LLM sampling temperature
            max_tokens (Optional[int]): Optional cap on LLM output tokens
            context_docs (Optional[List[str]]): Pre-retrieved context; when
                omitted (or empty) context is retrieved from the vector store

        Returns:
            RAGResponse: Response text plus the context docs, sources and
            similarity scores used to produce it.

        Raises:
            Exception: Any unexpected error is logged and re-raised.
        """
        try:
            # Introduction/welcome messages are templated by the caller and
            # need no retrieval at all — answer them directly.
            is_introduction = (
                "wants support" in query and
                "This is Introduction" in query and
                ("A new user with name:" in query or "An old user with name:" in query)
            )

            if is_introduction:
                # Handle introduction message - no context needed
                welcome_message = self._handle_contact_query(query)
                return RAGResponse(
                    response=welcome_message,
                    context_docs=[],
                    sources=[],
                    scores=None
                )

            # Load recent history for the conversation (if any) and trim it
            # to the manager's token/message budget.
            history = []
            if conversation_id:
                history = await self.mongodb.get_recent_messages(
                    conversation_id,
                    limit=self.conversation_manager.max_messages
                )
                # Keep only the history that fits within token limits.
                history = self.conversation_manager.get_relevant_history(
                    messages=history,
                    current_query=query
                )

            # Retrieve context if the caller did not supply any. Sources and
            # scores are only known for retrieved (not caller-provided) docs.
            if not context_docs:
                context_docs, sources, scores = await self.retrieve_context(
                    query=query,
                    conversation_history=history
                )
            else:
                sources = None
                scores = None

            # With no relevant context, answer with a fixed fallback instead
            # of letting the LLM answer unsupported.
            if not context_docs:
                return RAGResponse(
                    response="Information about this is not available, do you want to inquire about something else?",
                    context_docs=[],
                    sources=[],
                    scores=None
                )

            # Excel-derived chunks carry a 'Sheet:' marker; give them
            # structure-aware pre-processing (best effort — failures only warn).
            has_excel_content = any('Sheet:' in doc for doc in context_docs)
            if has_excel_content:
                try:
                    context_docs = self._process_excel_context(
                        context_docs, query)
                except Exception as e:
                    logger.warning(f"Error processing Excel context: {str(e)}")

            # Generate prompt with context and history
            augmented_prompt = self.conversation_manager.generate_prompt_with_history(
                current_query=query,
                history=history,
                context_docs=context_docs
            )

            # Generate initial response
            response = self.llm.generate(
                prompt=augmented_prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Clean the response
            cleaned_response = self._clean_response(response)

            # For Excel queries, try to enhance the response; keep the cleaned
            # response if enhancement fails or returns nothing (best effort).
            if has_excel_content:
                try:
                    enhanced_response = await self.enhance_excel_response(
                        query=query,
                        response=cleaned_response,
                        context_docs=context_docs
                    )
                    if enhanced_response:
                        cleaned_response = enhanced_response
                except Exception as e:
                    logger.warning(f"Error enhancing Excel response: {str(e)}")

            # Return the final response
            return RAGResponse(
                response=cleaned_response,
                context_docs=context_docs,
                sources=sources,
                scores=scores
            )

        except Exception as e:
            # FIX: previously logged "Error in SystemInstructionsRAGAgent"
            # (copy-pasted from a sibling agent), misattributing every error
            # raised by this class.
            logger.error(f"Error in RAGAgent.generate_response: {str(e)}")
            raise

    def _create_response_prompt(self, query: str, context_docs: List[str]) -> str:
        """
        Create prompt for generating response from context.

        Args:
            query (str): User query
            context_docs (List[str]): Retrieved context documents

        Returns:
            str: Formatted prompt for the LLM
        """
        if not context_docs:
            return f"Query: {query}\nResponse: Information about this is not available, do you want to inquire about something else?"

        # Format context documents, skipping empty/whitespace-only entries.
        formatted_context = "\n\n".join(
            f"Context {i+1}:\n{doc.strip()}"
            for i, doc in enumerate(context_docs)
            if doc and doc.strip()
        )

        # Build the prompt with detailed instructions
        prompt = f"""You are a knowledgeable assistant. Use the following context to answer the query accurately and informatively.
Context Information:
{formatted_context}
Query: {query}
Instructions:
1. Base your response ONLY on the information provided in the context above
2. If the context contains numbers, statistics, or specific details, include them in your response
3. Keep your response focused and relevant to the query
4. Use clear and professional language
5. If the context includes technical terms, explain them appropriately
6. Do not make assumptions or add information not present in the context
7. If specific sections of a report are mentioned, maintain their original structure
8. Format the response in a clear, readable manner
9. If the context includes chronological information, maintain the proper sequence
Response:"""

        return prompt

    async def retrieve_context(
        self,
        query: str,
        conversation_history: Optional[List[Dict]] = None,
        top_k: int = 3
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """
        Retrieve context documents for a query, enhanced with recent
        conversation history.

        Args:
            query (str): Current user query
            conversation_history (Optional[List[Dict]]): Prior messages; the
                'query' field of the last two is prepended to the search query
            top_k (int): Number of documents to retrieve

        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]: Document
            texts, string-converted metadata dicts, and similarity scores
            (None unless a score is available for every document).
        """
        # Enhance query with conversation history
        if conversation_history:
            recent_queries = [
                msg['query'] for msg in conversation_history[-2:]
                if msg.get('query')
            ]
            enhanced_query = " ".join([*recent_queries, query])
        else:
            enhanced_query = query

        # Debug log the enhanced query
        logger.info(f"Enhanced query: {enhanced_query}")

        # Embed the enhanced query
        query_embedding = self.embedding.embed_query(enhanced_query)

        # Debug log embedding shape
        logger.info(f"Query embedding shape: {len(query_embedding)}")

        # Retrieve similar documents
        results = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k
        )

        # Debug log search results
        logger.info(f"Number of search results: {len(results)}")
        for i, result in enumerate(results):
            logger.info(f"Result {i} score: {result.get('score', 'N/A')}")
            logger.info(
                f"Result {i} text preview: {result.get('text', '')[:100]}...")

        # Process results
        documents = [doc['text'] for doc in results]
        sources = [self._convert_metadata_to_strings(doc['metadata'])
                   for doc in results]
        scores = [doc['score'] for doc in results
                  if doc.get('score') is not None]

        # Return scores only if available for all documents
        if len(scores) != len(documents):
            scores = None

        return documents, sources, scores

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Convert numeric metadata values to strings, leaving others intact."""
        converted = {}
        for key, value in metadata.items():
            if isinstance(value, (int, float)):
                converted[key] = str(value)
            else:
                converted[key] = value
        return converted