File size: 6,867 Bytes
640b1c8
0739c8b
9700f95
640b1c8
b953016
640b1c8
e87abff
 
9700f95
 
0739c8b
9700f95
640b1c8
b953016
640b1c8
9700f95
 
 
 
 
 
 
640b1c8
 
9700f95
640b1c8
 
9700f95
 
 
 
 
 
640b1c8
b953016
9700f95
 
 
 
 
 
 
640b1c8
9700f95
0739c8b
9700f95
0739c8b
9700f95
0739c8b
 
640b1c8
 
9700f95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b953016
 
 
 
 
 
 
 
 
9700f95
 
 
 
 
 
 
b953016
9700f95
 
 
 
 
 
b953016
 
 
 
 
 
 
 
 
 
 
 
9700f95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
640b1c8
9700f95
640b1c8
 
9700f95
 
 
 
640b1c8
9700f95
 
640b1c8
9700f95
 
 
 
 
 
 
0739c8b
9700f95
 
 
 
 
 
 
 
 
640b1c8
9700f95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict
import uuid

from .excel_aware_rag import ExcelAwareRAGAgent
from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.conversation_manager import ConversationManager
from src.db.mongodb_store import MongoDBStore
from src.models.rag import RAGResponse
from src.utils.logger import logger

class RAGAgent(ExcelAwareRAGAgent):
    """Retrieval-augmented generation agent backed by MongoDB conversation history.

    Retrieves context from a vector store, builds a prompt together with recent
    conversation history, and asks the LLM for an answer. When the context looks
    like spreadsheet content, it additionally calls ``_process_excel_context`` and
    ``enhance_excel_response``. Those helpers are not defined in this class and are
    presumably provided by ``ExcelAwareRAGAgent``.
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """
        Initialize RAG Agent

        Args:
            llm (BaseLLM): Language model instance
            embedding (BaseEmbedding): Embedding model instance
            vector_store (BaseVectorStore): Vector store instance
            mongodb (MongoDBStore): MongoDB store instance
            max_history_tokens (int): Maximum tokens in conversation history
            max_history_messages (int): Maximum messages to keep in history
        """
        super().__init__()  # Initialize ExcelAwareRAGAgent
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
        self.mongodb = mongodb
        # The manager's limits drive both the DB fetch size (max_messages) and
        # the history trimming done before prompt construction.
        self.conversation_manager = ConversationManager(
            max_tokens=max_history_tokens,
            max_messages=max_history_messages
        )

    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """Generate a response using RAG with conversation history.

        Args:
            query (str): The user's current question.
            conversation_id (Optional[str]): Existing conversation to continue.
                If falsy (None or ""), a new UUID4 id is generated and
                ``mongodb.create_conversation`` is awaited for it.
            temperature (float): Passed through to ``llm.generate``.
            max_tokens (Optional[int]): Passed through to ``llm.generate``.
            context_docs (Optional[List[str]]): Pre-fetched context. If falsy
                (None *or* an empty list), context is retrieved with
                ``retrieve_context`` instead.

        Returns:
            RAGResponse: The LLM answer, the context documents used, and the
                retrieval ``sources``/``scores``. Both are None when the caller
                supplied ``context_docs``.

        Raises:
            Exception: Any error from history lookup, retrieval, prompt
                construction or generation is logged with its traceback and
                re-raised unchanged. Failures in the Excel-specific steps are
                only logged as warnings.

        NOTE(review): a newly generated conversation_id is not placed in the
        returned RAGResponse. Confirm how callers are expected to continue
        that conversation.
        """
        try:
            # Start a fresh conversation when no id was supplied.
            if not conversation_id:
                conversation_id = str(uuid.uuid4())
                await self.mongodb.create_conversation(conversation_id)

            # Fetch at most as many messages as the manager is allowed to keep.
            history = await self.mongodb.get_recent_messages(
                conversation_id,
                limit=self.conversation_manager.max_messages
            )

            # Trim history to what fits the manager's token/message budget.
            relevant_history = self.conversation_manager.get_relevant_history(
                messages=history,
                current_query=query
            ) if history else []

            # Retrieve context unless the caller supplied some. This is a
            # truthiness check, so an empty list also triggers retrieval.
            if not context_docs:
                context_docs, sources, scores = await self.retrieve_context(
                    query,
                    conversation_history=relevant_history
                )
            else:
                # Caller-supplied documents carry no provenance or scores.
                sources = None
                scores = None

            # Heuristic: any document containing 'Sheet:' is treated as Excel
            # content and routed through the Excel-specific helpers.
            has_excel_content = any('Sheet:' in doc for doc in (context_docs or []))
            if has_excel_content:
                try:
                    context_docs = self._process_excel_context(context_docs, query)
                except Exception as e:
                    # Best-effort: keep the original context if Excel processing fails.
                    logger.warning(f"Error processing Excel context: {str(e)}")

            # Generate prompt with context and history
            augmented_prompt = self.conversation_manager.generate_prompt_with_history(
                current_query=query,
                history=relevant_history,
                context_docs=context_docs
            )

            # NOTE(review): llm.generate is called without await inside an async
            # method. If it performs blocking network I/O it will stall the event
            # loop; confirm, and consider asyncio.to_thread if so.
            response = self.llm.generate(
                augmented_prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Enhance response for Excel queries if applicable
            if has_excel_content:
                try:
                    response = await self.enhance_excel_response(
                        query=query,
                        response=response,
                        context_docs=context_docs
                    )
                except Exception as e:
                    # Best-effort: keep the plain LLM answer if enhancement fails.
                    logger.warning(f"Error enhancing Excel response: {str(e)}")

            return RAGResponse(
                response=response,
                context_docs=context_docs,
                sources=sources,
                scores=scores
            )

        except Exception as e:
            # logger.exception logs at ERROR level *with* the traceback, which the
            # previous logger.error(...) call discarded. The exception still propagates.
            logger.exception(f"Error generating response: {str(e)}")
            raise

    async def retrieve_context(
        self,
        query: str,
        conversation_history: Optional[List[Dict]] = None,
        top_k: int = 3
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """
        Retrieve context with conversation history enhancement

        The search query is the space-joined concatenation of the ``'query'``
        fields (when present and truthy) of the last two history entries,
        followed by ``query``. It is embedded with ``embedding.embed_query`` and
        sent to ``vector_store.similarity_search``.

        Args:
            query (str): Current query
            conversation_history (Optional[List[Dict]]): Recent conversation history.
                Presumably ordered oldest to newest, since the *last* two
                entries are used; verify against ConversationManager.
            top_k (int): Number of documents to retrieve

        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]:
                Retrieved documents (each result's ``'text'``), sources (each
                result's ``'metadata'`` with int/float values stringified) and
                scores. ``scores`` is None unless every result has a non-None
                ``'score'``; it is an empty list when there are no results.

        Raises:
            KeyError: If a search result lacks a ``'text'`` or ``'metadata'`` key.
        """
        # Prepend the two most recent prior questions so follow-ups like
        # "what about last year?" still retrieve relevant documents.
        if conversation_history:
            recent_queries = [
                msg['query'] for msg in conversation_history[-2:]
                if msg.get('query')
            ]
            enhanced_query = " ".join([*recent_queries, query])
        else:
            enhanced_query = query

        # Embed the enhanced query
        query_embedding = self.embedding.embed_query(enhanced_query)

        # Retrieve similar documents
        results = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k
        )

        # Process results
        documents = [doc['text'] for doc in results]
        sources = [self._convert_metadata_to_strings(doc['metadata'])
                   for doc in results]
        scores = [doc['score'] for doc in results
                  if doc.get('score') is not None]

        # Partial scores would misalign with documents, so drop them entirely.
        if len(scores) != len(documents):
            scores = None

        return documents, sources, scores

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Return a shallow copy of ``metadata`` with int/float values converted to str.

        Other values, including nested dicts and lists, are passed through
        unchanged. Because ``bool`` is a subclass of ``int``, True/False also
        become ``'True'``/``'False'``.
        """
        return {
            key: str(value) if isinstance(value, (int, float)) else value
            for key, value in metadata.items()
        }