# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict

from .excel_aware_rag import ExcelAwareRAGAgent
from src.llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.conversation_manager import ConversationManager
from src.db.mongodb_store import MongoDBStore
from src.models.rag import RAGResponse
from src.utils.logger import logger
from config.config import settings


class RAGAgent(ExcelAwareRAGAgent):
    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """
        Initialize RAG Agent

        Args:
            llm (BaseLLM): Language model instance
            embedding (BaseEmbedding): Embedding model instance
            vector_store (BaseVectorStore): Vector store instance
            mongodb (MongoDBStore): MongoDB store instance
            max_history_tokens (int): Maximum tokens in conversation history
            max_history_messages (int): Maximum messages to keep in history
        """
        super().__init__()  # Initialize ExcelAwareRAGAgent
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
        self.mongodb = mongodb
        self.conversation_manager = ConversationManager(
            max_tokens=max_history_tokens,
            max_messages=max_history_messages
        )

    def _extract_markdown_section(self, docs: List[str], section_header: str) -> str:
        """Extract complete section content from markdown documents"""
        combined_text = '\n'.join(docs)
        section_start = combined_text.find(section_header)
        if section_start == -1:
            return ""

        # The next bold header ("\n\n**") marks the end of the current section;
        # if none is found, the section runs to the end of the text.
        next_section = combined_text.find(
            "\n\n**", section_start + len(section_header))
        if next_section == -1:
            section_content = combined_text[section_start:]
        else:
            section_content = combined_text[section_start:next_section]

        return self._clean_markdown_content(section_content)
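
    # Illustrative example (the document text is hypothetical): given docs
    # containing "**DISCUSSIONS AND ACTION ITEMS**\n- item 1\n\n**NEXT STEPS**...",
    # _extract_markdown_section(docs, "**DISCUSSIONS AND ACTION ITEMS**") returns
    # the header plus "- item 1", stopping at the "\n\n**" that opens NEXT STEPS.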

    def _clean_markdown_content(self, content: str) -> str:
        """Clean and format markdown content"""
        lines = content.split('\n')
        seen_lines = set()
        cleaned_lines = []

        for line in lines:
            # Always keep headers and table formatting
            if '| :----' in line or line.startswith('**'):
                if line not in seen_lines:
                    cleaned_lines.append(line)
                    seen_lines.add(line)
                continue

            # Keep table rows and list items
            if line.strip().startswith('|') or line.strip().startswith('-'):
                cleaned_lines.append(line)
                continue

            # Remove duplicates for other content
            stripped = line.strip()
            if stripped and stripped not in seen_lines:
                cleaned_lines.append(line)
                seen_lines.add(stripped)

        return '\n'.join(cleaned_lines)
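
    # For example (hypothetical input), cleaning "Overlap line\nOverlap line\n- item"
    # yields "Overlap line\n- item": repeated plain lines collapse to one occurrence,
    # while list items starting with "-" are always kept.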

    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str],
        temperature: float,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """Generate response with improved markdown and conversation handling"""
        try:
            # Handle introduction/welcome message queries
            is_introduction = (
                "wants support" in query and
                "This is Introduction" in query and
                ("A new user with name:" in query or "An old user with name:" in query)
            )
            if is_introduction:
                welcome_message = self._handle_contact_query(query)
                return RAGResponse(
                    response=welcome_message,
                    context_docs=[],
                    sources=[],
                    scores=None
                )

            # Get conversation history if conversation_id exists
            history = []
            if conversation_id:
                history = await self.mongodb.get_recent_messages(
                    conversation_id,
                    limit=self.conversation_manager.max_messages
                )
                history = self.conversation_manager.get_relevant_history(
                    messages=history,
                    current_query=query
                )

            # Retrieve context if not provided
            if not context_docs:
                context_docs, sources, scores = await self.retrieve_context(
                    query=query,
                    conversation_history=history
                )
            else:
                sources = None
                scores = None

            # Special handling for markdown section queries
            if "DISCUSSIONS AND ACTION ITEMS" in query.upper():
                section_content = self._extract_markdown_section(
                    context_docs,
                    "**DISCUSSIONS AND ACTION ITEMS**"
                )
                if section_content:
                    return RAGResponse(
                        response=section_content.strip(),
                        context_docs=context_docs,
                        sources=sources,
                        scores=scores
                    )

            # Check if we have any relevant context
            if not context_docs:
                return RAGResponse(
                    response="Information about this is not available. Would you like to ask about something else?",
                    context_docs=[],
                    sources=[],
                    scores=None
                )

            # Generate prompt with context and history
            augmented_prompt = self.conversation_manager.generate_prompt_with_history(
                current_query=query,
                history=history,
                context_docs=context_docs
            )

            # Generate and clean the response
            response = self.llm.generate(
                prompt=augmented_prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )
            cleaned_response = self._clean_response(response)

            # Return the final response
            return RAGResponse(
                response=cleaned_response,
                context_docs=context_docs,
                sources=sources,
                scores=scores
            )

        except Exception as e:
            logger.error(f"Error in RAGAgent: {str(e)}")
            raise

    def _create_response_prompt(self, query: str, context_docs: List[str]) -> str:
        """
        Create prompt for generating response from context

        Args:
            query (str): User query
            context_docs (List[str]): Retrieved context documents

        Returns:
            str: Formatted prompt for the LLM
        """
        if not context_docs:
            return f"Query: {query}\nResponse: Information about this is not available. Would you like to ask about something else?"

        # Format context documents, skipping empty entries
        formatted_context = "\n\n".join(
            f"Context {i+1}:\n{doc.strip()}"
            for i, doc in enumerate(context_docs)
            if doc and doc.strip()
        )

        # Build the prompt with detailed instructions
        prompt = f"""You are a knowledgeable assistant. Use the following context to answer the query accurately and informatively.

Context Information:
{formatted_context}

Query: {query}

Instructions:
1. Base your response ONLY on the information provided in the context above
2. If the context contains numbers, statistics, or specific details, include them in your response
3. Keep your response focused and relevant to the query
4. Use clear and professional language
5. If the context includes technical terms, explain them appropriately
6. Do not make assumptions or add information not present in the context
7. If specific sections of a report are mentioned, maintain their original structure
8. Format the response in a clear, readable manner
9. If the context includes chronological information, maintain the proper sequence

Response:"""

        return prompt
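
    # For illustration, two hypothetical context snippets would be formatted as:
    #
    #     Context 1:
    #     Revenue grew 12% in Q3.
    #
    #     Context 2:
    #     Headcount stayed flat quarter over quarter.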

    async def retrieve_context(
        self,
        query: str,
        conversation_history: Optional[List[Dict]] = None
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """Retrieve context, enhancing the query with recent conversation history."""
        # Prepend the last two user queries so follow-up questions keep their referents
        if conversation_history:
            recent_queries = [
                msg['query'] for msg in conversation_history[-2:]
                if msg.get('query')
            ]
            enhanced_query = " ".join([*recent_queries, query])
        else:
            enhanced_query = query

        logger.info(f"Enhanced query: {enhanced_query}")

        # Embed the enhanced query
        query_embedding = self.embedding.embed_query(enhanced_query)
        logger.info(f"Query embedding dimension: {len(query_embedding)}")

        # Retrieve similar documents
        results = self.vector_store.similarity_search(
            query_embedding,
            top_k=settings.TOP_CHUNKS
        )

        # Debug-log the search results
        logger.info(f"Number of search results: {len(results)}")
        for i, result in enumerate(results):
            logger.info(f"Result {i} score: {result.get('score', 'N/A')}")
            logger.info(
                f"Result {i} text preview: {result.get('text', '')[:100]}...")

        # Process results
        documents = [doc['text'] for doc in results]
        sources = [self._convert_metadata_to_strings(doc['metadata'])
                   for doc in results]
        scores = [doc['score'] for doc in results
                  if doc.get('score') is not None]

        # Return scores only if available for all documents
        if len(scores) != len(documents):
            scores = None

        return documents, sources, scores
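
    # Each similarity_search result is assumed (from the keys accessed above,
    # not from a documented contract) to be a dict shaped roughly like:
    #
    #     {"text": "chunk text...", "metadata": {"source": "report.md"}, "score": 0.87}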

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Convert numeric metadata values to strings"""
        converted = {}
        for key, value in metadata.items():
            if isinstance(value, (int, float)):
                converted[key] = str(value)
            else:
                converted[key] = value
        return converted
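

# --- Usage sketch ------------------------------------------------------------
# A minimal wiring example. The concrete classes (OpenAILLM, OpenAIEmbedding,
# ChromaVectorStore) and the connection string are hypothetical stand-ins for
# whatever implementations this project provides; only RAGAgent is defined in
# this module, so the sketch is kept in comments rather than executable code.
#
#     import asyncio
#
#     async def main() -> None:
#         agent = RAGAgent(
#             llm=OpenAILLM(),                                  # hypothetical
#             embedding=OpenAIEmbedding(),                      # hypothetical
#             vector_store=ChromaVectorStore(),                 # hypothetical
#             mongodb=MongoDBStore("mongodb://localhost:27017"),
#         )
#         result = await agent.generate_response(
#             query="What were the action items from the last meeting?",
#             conversation_id=None,
#             temperature=0.3,
#         )
#         print(result.response)
#
#     asyncio.run(main())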