# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict

from src.agents.excel_aware_rag import ExcelAwareRAGAgent
from src.llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.conversation_manager import ConversationManager
from src.db.mongodb_store import MongoDBStore
from src.models.rag import RAGResponse
from src.utils.logger import logger
from config.config import settings


class RAGAgent(ExcelAwareRAGAgent):
    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """
        Initialize RAG Agent.

        Args:
            llm (BaseLLM): Language model instance
            embedding (BaseEmbedding): Embedding model instance
            vector_store (BaseVectorStore): Vector store instance
            mongodb (MongoDBStore): MongoDB store instance
            max_history_tokens (int): Maximum tokens in conversation history
            max_history_messages (int): Maximum messages to keep in history
        """
        super().__init__()  # Initialize ExcelAwareRAGAgent
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
        self.mongodb = mongodb
        self.conversation_manager = ConversationManager(
            max_tokens=max_history_tokens,
            max_messages=max_history_messages
        )

    def _extract_markdown_section(self, docs: List[str], section_header: str) -> str:
        """Extract a complete section's content from markdown documents."""
        combined_text = '\n'.join(docs)
        section_start = combined_text.find(section_header)

        if section_start == -1:
            return ""

        next_section = combined_text.find(
            "\n\n**", section_start + len(section_header))

        if next_section == -1:
            section_content = combined_text[section_start:]
        else:
            section_content = combined_text[section_start:next_section]

        return self._clean_markdown_content(section_content)

    def _clean_markdown_content(self, content: str) -> str:
        """Clean and format markdown content."""
        lines = content.split('\n')
        seen_lines = set()
        cleaned_lines = []

        for line in lines:
            # Always keep headers and table formatting
            if '| :----' in line or line.startswith('**'):
                if line not in seen_lines:
                    cleaned_lines.append(line)
                    seen_lines.add(line)
                continue

            # Keep table rows and list items
            if line.strip().startswith('|') or line.strip().startswith('-'):
                cleaned_lines.append(line)
                continue

            # Remove duplicates for other content
            stripped = line.strip()
            if stripped and stripped not in seen_lines:
                cleaned_lines.append(line)
                seen_lines.add(stripped)

        return '\n'.join(cleaned_lines)
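    # Illustrative example (an assumption, not part of the original module):
    # given docs built from meeting-minutes markdown,
    #
    #     docs = [
    #         "**DISCUSSIONS AND ACTION ITEMS**\n"
    #         "- Review Q3 budget\n"
    #         "- Review Q3 budget\n\n"
    #         "**NEXT STEPS**\n- Schedule follow-up"
    #     ]
    #     agent._extract_markdown_section(docs, "**DISCUSSIONS AND ACTION ITEMS**")
    #
    # returns the header plus both list items: extraction stops at the next
    # "\n\n**" bold header, and _clean_markdown_content de-duplicates only
    # plain text lines, never table rows or "-" list items.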
    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str],
        temperature: float,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """Generate a response with improved markdown and conversation handling."""
        try:
            # Handle introduction/welcome message queries
            is_introduction = (
                "wants support" in query
                and "This is Introduction" in query
                and ("A new user with name:" in query
                     or "An old user with name:" in query)
            )

            if is_introduction:
                welcome_message = self._handle_contact_query(query)
                return RAGResponse(
                    response=welcome_message,
                    context_docs=[],
                    sources=[],
                    scores=None
                )

            # Get conversation history if a conversation_id exists
            history = []
            if conversation_id:
                history = await self.mongodb.get_recent_messages(
                    conversation_id,
                    limit=self.conversation_manager.max_messages
                )
                history = self.conversation_manager.get_relevant_history(
                    messages=history,
                    current_query=query
                )

            # Retrieve context if not provided
            if not context_docs:
                context_docs, sources, scores = await self.retrieve_context(
                    query=query,
                    conversation_history=history
                )
            else:
                sources = None
                scores = None

            # Special handling for markdown section queries
            if "DISCUSSIONS AND ACTION ITEMS" in query.upper():
                section_content = self._extract_markdown_section(
                    context_docs,
                    "**DISCUSSIONS AND ACTION ITEMS**"
                )
                if section_content:
                    return RAGResponse(
                        response=section_content.strip(),
                        context_docs=context_docs,
                        sources=sources,
                        scores=scores
                    )

            # Check whether we have any relevant context
            if not context_docs:
                return RAGResponse(
                    response=(
                        "Information about this is not available, do you "
                        "want to inquire about something else?"
                    ),
                    context_docs=[],
                    sources=[],
                    scores=None
                )

            # Generate prompt with context and history
            augmented_prompt = self.conversation_manager.generate_prompt_with_history(
                current_query=query,
                history=history,
                context_docs=context_docs
            )

            # Generate response
            response = self.llm.generate(
                prompt=augmented_prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Clean the response
            cleaned_response = self._clean_response(response)

            # Return the final response
            return RAGResponse(
                response=cleaned_response,
                context_docs=context_docs,
                sources=sources,
                scores=scores
            )

        except Exception as e:
            logger.error(f"Error in RAGAgent: {str(e)}")
            raise

    def _create_response_prompt(self, query: str, context_docs: List[str]) -> str:
        """
        Create the prompt for generating a response from context.

        Args:
            query (str): User query
            context_docs (List[str]): Retrieved context documents

        Returns:
            str: Formatted prompt for the LLM
        """
        if not context_docs:
            return (
                f"Query: {query}\nResponse: Information about this is not "
                "available, do you want to inquire about something else?"
            )

        # Format context documents
        formatted_context = "\n\n".join(
            f"Context {i+1}:\n{doc.strip()}"
            for i, doc in enumerate(context_docs)
            if doc and doc.strip()
        )

        # Build the prompt with detailed instructions
        prompt = f"""You are a knowledgeable assistant. Use the following context to answer the query accurately and informatively.

Context Information:
{formatted_context}

Query: {query}

Instructions:
1. Base your response ONLY on the information provided in the context above
2. If the context contains numbers, statistics, or specific details, include them in your response
3. Keep your response focused and relevant to the query
4. Use clear and professional language
5. If the context includes technical terms, explain them appropriately
6. Do not make assumptions or add information not present in the context
7. If specific sections of a report are mentioned, maintain their original structure
8. Format the response in a clear, readable manner
9. If the context includes chronological information, maintain the proper sequence

Response:"""

        return prompt
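    # Illustrative example (an assumption, not part of the original module):
    # _create_response_prompt("What was Q3 revenue?", ["Q3 revenue was $1.2M."])
    # yields a prompt whose context block is numbered per document:
    #
    #     Context Information:
    #     Context 1:
    #     Q3 revenue was $1.2M.
    #
    #     Query: What was Q3 revenue?
    #     ...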
    async def retrieve_context(
        self,
        query: str,
        conversation_history: Optional[List[Dict]] = None
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """Retrieve context, enhancing the query with conversation history."""
        # Prepend the two most recent queries so follow-up questions keep
        # their referents during retrieval
        if conversation_history:
            recent_queries = [
                msg['query']
                for msg in conversation_history[-2:]
                if msg.get('query')
            ]
            enhanced_query = " ".join([*recent_queries, query])
        else:
            enhanced_query = query

        # Log the enhanced query for debugging
        logger.info(f"Enhanced query: {enhanced_query}")

        # Embed the enhanced query
        query_embedding = self.embedding.embed_query(enhanced_query)

        # Log the embedding dimension for debugging
        logger.info(f"Query embedding dimension: {len(query_embedding)}")

        # Retrieve similar documents
        results = self.vector_store.similarity_search(
            query_embedding,
            top_k=settings.TOP_CHUNKS
        )

        # Log search results for debugging
        logger.info(f"Number of search results: {len(results)}")
        for i, result in enumerate(results):
            logger.info(f"Result {i} score: {result.get('score', 'N/A')}")
            logger.info(
                f"Result {i} text preview: {result.get('text', '')[:100]}...")

        # Process results
        documents = [doc['text'] for doc in results]
        sources = [self._convert_metadata_to_strings(doc['metadata'])
                   for doc in results]
        scores = [doc['score'] for doc in results
                  if doc.get('score') is not None]

        # Return scores only if they are available for every document
        if len(scores) != len(documents):
            scores = None

        return documents, sources, scores

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Convert numeric metadata values to strings."""
        converted = {}
        for key, value in metadata.items():
            if isinstance(value, (int, float)):
                converted[key] = str(value)
            else:
                converted[key] = value
        return converted
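# Minimal wiring sketch (illustrative only): the concrete LLM, embedding, and
# vector-store classes below are hypothetical stand-ins for whatever BaseLLM,
# BaseEmbedding, and BaseVectorStore implementations this codebase provides.
#
#     agent = RAGAgent(
#         llm=SomeLLM(),                   # any BaseLLM implementation
#         embedding=SomeEmbedding(),       # any BaseEmbedding implementation
#         vector_store=SomeVectorStore(),  # any BaseVectorStore implementation
#         mongodb=MongoDBStore(...),       # constructor args depend on the store
#     )
#     rag_response = await agent.generate_response(
#         query="Summarize the action items from the last meeting",
#         conversation_id=None,
#         temperature=0.2,
#     )
#     print(rag_response.response)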