Spaces:
Running
Running
Commit
·
b953016
1
Parent(s):
9700f95
Enhanced the support for the excel file and added endpoint to have optimized vector store and Rag for the Excel.
Browse files- config/__pycache__/config.cpython-312.pyc +0 -0
- config/config.py +2 -0
- src/__pycache__/main.cpython-312.pyc +0 -0
- src/agents/__pycache__/excel_aware_rag.cpython-312.pyc +0 -0
- src/agents/__pycache__/rag_agent.cpython-312.pyc +0 -0
- src/agents/excel_aware_rag.py +237 -0
- src/agents/rag_agent.py +25 -2
- src/db/__pycache__/mongodb_store.cpython-312.pyc +0 -0
- src/db/mongodb_store.py +50 -3
- src/main.py +173 -14
- src/models/UserContact.py +28 -0
- src/models/__pycache__/UserContact.cpython-312.pyc +0 -0
- src/utils/__pycache__/database_cleanup.cpython-312.pyc +0 -0
- src/utils/__pycache__/document_processor.cpython-312.pyc +0 -0
- src/utils/__pycache__/enhanced_excel_processor.cpython-312.pyc +0 -0
- src/utils/__pycache__/llm_utils.cpython-312.pyc +0 -0
- src/utils/database_cleanup.py +182 -0
- src/utils/document_processor.py +169 -71
- src/utils/enhanced_excel_processor.py +187 -0
- src/utils/excel_integration +139 -0
- src/utils/llm_utils.py +11 -9
- src/vectorstores/__pycache__/optimized_vectorstore.cpython-312.pyc +0 -0
- src/vectorstores/optimized_vectorstore.py +137 -0
config/__pycache__/config.cpython-312.pyc
CHANGED
Binary files a/config/__pycache__/config.cpython-312.pyc and b/config/__pycache__/config.cpython-312.pyc differ
|
|
config/config.py
CHANGED
@@ -10,6 +10,8 @@ class Settings:
|
|
10 |
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')
|
11 |
OPENAI_MODEL = os.getenv('OPENAI_MODEL', 'gpt-3.5-turbo')
|
12 |
|
|
|
|
|
13 |
# Ollama Configuration
|
14 |
OLLAMA_BASE_URL = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
|
15 |
OLLAMA_MODEL = os.getenv('OLLAMA_MODEL', 'llama2')
|
|
|
10 |
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')
|
11 |
OPENAI_MODEL = os.getenv('OPENAI_MODEL', 'gpt-3.5-turbo')
|
12 |
|
13 |
+
ADMIN_API_KEY = 'aca4081f-6ff2-434c-843b-98f60285c499'
|
14 |
+
|
15 |
# Ollama Configuration
|
16 |
OLLAMA_BASE_URL = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
|
17 |
OLLAMA_MODEL = os.getenv('OLLAMA_MODEL', 'llama2')
|
src/__pycache__/main.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/main.cpython-312.pyc and b/src/__pycache__/main.cpython-312.pyc differ
|
|
src/agents/__pycache__/excel_aware_rag.cpython-312.pyc
ADDED
Binary file (10.1 kB). View file
|
|
src/agents/__pycache__/rag_agent.cpython-312.pyc
CHANGED
Binary files a/src/agents/__pycache__/rag_agent.cpython-312.pyc and b/src/agents/__pycache__/rag_agent.cpython-312.pyc differ
|
|
src/agents/excel_aware_rag.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/agents/excel_aware_rag.py
|
2 |
+
from typing import List, Dict, Optional, Set
|
3 |
+
from src.utils.logger import logger
|
4 |
+
|
5 |
+
class ExcelAwareRAGAgent:
|
6 |
+
"""Extension of RAGAgent with enhanced Excel handling"""
|
7 |
+
|
8 |
+
def _process_excel_context(self, context_docs: List[str], query: str) -> List[str]:
|
9 |
+
"""
|
10 |
+
Process and enhance context for Excel-related queries
|
11 |
+
|
12 |
+
Args:
|
13 |
+
context_docs (List[str]): Original context documents
|
14 |
+
query (str): User query
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
List[str]: Enhanced context documents
|
18 |
+
"""
|
19 |
+
excel_context = []
|
20 |
+
|
21 |
+
for doc in context_docs:
|
22 |
+
if 'Sheet:' in doc: # Identify Excel content
|
23 |
+
# Extract relevant sheet context based on query
|
24 |
+
relevant_sheets = self._identify_relevant_sheets(doc, query)
|
25 |
+
for sheet in relevant_sheets:
|
26 |
+
sheet_context = self._extract_sheet_context(doc, sheet)
|
27 |
+
if sheet_context:
|
28 |
+
excel_context.append(sheet_context)
|
29 |
+
|
30 |
+
# Add relationship context if query suggests multi-sheet analysis
|
31 |
+
if self._needs_relationship_context(query):
|
32 |
+
relationship_context = self._extract_relationship_context(doc)
|
33 |
+
if relationship_context:
|
34 |
+
excel_context.append(relationship_context)
|
35 |
+
else:
|
36 |
+
excel_context.append(doc)
|
37 |
+
|
38 |
+
return excel_context
|
39 |
+
|
40 |
+
def _identify_relevant_sheets(self, doc: str, query: str) -> List[str]:
|
41 |
+
"""Identify sheets relevant to the query"""
|
42 |
+
sheets = []
|
43 |
+
current_sheet = None
|
44 |
+
|
45 |
+
# Extract sheet names from the document
|
46 |
+
for line in doc.split('\n'):
|
47 |
+
if line.startswith('Sheet: '):
|
48 |
+
current_sheet = line.replace('Sheet: ', '').strip()
|
49 |
+
# Check if sheet name or its contents are relevant to query
|
50 |
+
if self._is_relevant_to_query(current_sheet, query):
|
51 |
+
sheets.append(current_sheet)
|
52 |
+
|
53 |
+
return sheets
|
54 |
+
|
55 |
+
def _is_relevant_to_query(self, sheet_name: str, query: str) -> bool:
|
56 |
+
"""Check if a sheet is relevant to the query"""
|
57 |
+
# Convert to lower case for comparison
|
58 |
+
query_lower = query.lower()
|
59 |
+
sheet_lower = sheet_name.lower()
|
60 |
+
|
61 |
+
# Direct mention of sheet name
|
62 |
+
if sheet_lower in query_lower:
|
63 |
+
return True
|
64 |
+
|
65 |
+
# Check for related terms
|
66 |
+
sheet_terms = set(sheet_lower.split())
|
67 |
+
query_terms = set(query_lower.split())
|
68 |
+
|
69 |
+
# If there's significant term overlap
|
70 |
+
common_terms = sheet_terms.intersection(query_terms)
|
71 |
+
if len(common_terms) > 0:
|
72 |
+
return True
|
73 |
+
|
74 |
+
return False
|
75 |
+
|
76 |
+
def _extract_sheet_context(self, doc: str, sheet_name: str) -> Optional[str]:
|
77 |
+
"""Extract context for a specific sheet"""
|
78 |
+
lines = doc.split('\n')
|
79 |
+
sheet_context = []
|
80 |
+
in_target_sheet = False
|
81 |
+
|
82 |
+
for line in lines:
|
83 |
+
if line.startswith(f'Sheet: {sheet_name}'):
|
84 |
+
in_target_sheet = True
|
85 |
+
sheet_context.append(line)
|
86 |
+
elif line.startswith('Sheet: '):
|
87 |
+
in_target_sheet = False
|
88 |
+
elif in_target_sheet:
|
89 |
+
sheet_context.append(line)
|
90 |
+
|
91 |
+
return '\n'.join(sheet_context) if sheet_context else None
|
92 |
+
|
93 |
+
def _needs_relationship_context(self, query: str) -> bool:
|
94 |
+
"""Determine if query needs relationship context"""
|
95 |
+
relationship_indicators = [
|
96 |
+
'compare', 'relationship', 'between', 'across', 'correlation',
|
97 |
+
'related', 'connection', 'link', 'join', 'combine', 'multiple sheets',
|
98 |
+
'all sheets', 'different sheets'
|
99 |
+
]
|
100 |
+
|
101 |
+
query_lower = query.lower()
|
102 |
+
return any(indicator in query_lower for indicator in relationship_indicators)
|
103 |
+
|
104 |
+
def _extract_relationship_context(self, doc: str) -> Optional[str]:
|
105 |
+
"""Extract relationship context from document"""
|
106 |
+
lines = doc.split('\n')
|
107 |
+
relationship_context = []
|
108 |
+
in_relationships = False
|
109 |
+
|
110 |
+
for line in lines:
|
111 |
+
if 'Sheet Relationships:' in line:
|
112 |
+
in_relationships = True
|
113 |
+
relationship_context.append(line)
|
114 |
+
elif in_relationships and line.strip() and not line.startswith('Sheet: '):
|
115 |
+
relationship_context.append(line)
|
116 |
+
elif in_relationships and line.startswith('Sheet: '):
|
117 |
+
break
|
118 |
+
|
119 |
+
return '\n'.join(relationship_context) if relationship_context else None
|
120 |
+
|
121 |
+
async def enhance_excel_response(
|
122 |
+
self,
|
123 |
+
query: str,
|
124 |
+
response: str,
|
125 |
+
context_docs: List[str]
|
126 |
+
) -> str:
|
127 |
+
"""
|
128 |
+
Enhance response for Excel-related queries
|
129 |
+
|
130 |
+
Args:
|
131 |
+
query (str): Original query
|
132 |
+
response (str): Generated response
|
133 |
+
context_docs (List[str]): Context documents
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
str: Enhanced response
|
137 |
+
"""
|
138 |
+
if not any('Sheet:' in doc for doc in context_docs):
|
139 |
+
return response
|
140 |
+
|
141 |
+
try:
|
142 |
+
# Enhance response with specific Excel insights
|
143 |
+
enhanced_parts = [response]
|
144 |
+
|
145 |
+
# Add sheet-specific insights if relevant
|
146 |
+
if self._needs_sheet_specific_insights(query):
|
147 |
+
insights = self._generate_sheet_insights(query, context_docs)
|
148 |
+
if insights:
|
149 |
+
enhanced_parts.append("\nAdditional Sheet Insights:")
|
150 |
+
enhanced_parts.extend(insights)
|
151 |
+
|
152 |
+
# Add relationship insights if relevant
|
153 |
+
if self._needs_relationship_context(query):
|
154 |
+
relationship_insights = self._generate_relationship_insights(context_docs)
|
155 |
+
if relationship_insights:
|
156 |
+
enhanced_parts.append("\nSheet Relationship Insights:")
|
157 |
+
enhanced_parts.extend(relationship_insights)
|
158 |
+
|
159 |
+
return "\n".join(enhanced_parts)
|
160 |
+
except Exception as e:
|
161 |
+
logger.error(f"Error enhancing Excel response: {str(e)}")
|
162 |
+
return response # Fall back to original response if enhancement fails
|
163 |
+
|
164 |
+
def _needs_sheet_specific_insights(self, query: str) -> bool:
|
165 |
+
"""Determine if query needs sheet-specific insights"""
|
166 |
+
insight_indicators = [
|
167 |
+
'analyze', 'summarize', 'tell me about', 'what is in',
|
168 |
+
'show me', 'describe', 'explain', 'give me details'
|
169 |
+
]
|
170 |
+
|
171 |
+
query_lower = query.lower()
|
172 |
+
return any(indicator in query_lower for indicator in insight_indicators)
|
173 |
+
|
174 |
+
def _generate_sheet_insights(self, query: str, context_docs: List[str]) -> List[str]:
|
175 |
+
"""Generate insights for relevant sheets"""
|
176 |
+
insights = []
|
177 |
+
relevant_sheets = set()
|
178 |
+
|
179 |
+
# Collect relevant sheets from context
|
180 |
+
for doc in context_docs:
|
181 |
+
if 'Sheet:' in doc:
|
182 |
+
sheets = self._identify_relevant_sheets(doc, query)
|
183 |
+
relevant_sheets.update(sheets)
|
184 |
+
|
185 |
+
# Generate insights for each relevant sheet
|
186 |
+
for sheet in relevant_sheets:
|
187 |
+
sheet_insights = self._generate_single_sheet_insights(sheet, context_docs)
|
188 |
+
if sheet_insights:
|
189 |
+
insights.extend(sheet_insights)
|
190 |
+
|
191 |
+
return insights
|
192 |
+
|
193 |
+
def _generate_single_sheet_insights(
|
194 |
+
self,
|
195 |
+
sheet_name: str,
|
196 |
+
context_docs: List[str]
|
197 |
+
) -> List[str]:
|
198 |
+
"""Generate insights for a single sheet"""
|
199 |
+
insights = []
|
200 |
+
sheet_context = None
|
201 |
+
|
202 |
+
# Find context for this sheet
|
203 |
+
for doc in context_docs:
|
204 |
+
if f'Sheet: {sheet_name}' in doc:
|
205 |
+
sheet_context = self._extract_sheet_context(doc, sheet_name)
|
206 |
+
break
|
207 |
+
|
208 |
+
if not sheet_context:
|
209 |
+
return insights
|
210 |
+
|
211 |
+
# Extract and summarize key information
|
212 |
+
if 'Numeric Columns Summary:' in sheet_context:
|
213 |
+
numeric_insights = self._extract_numeric_insights(sheet_context)
|
214 |
+
if numeric_insights:
|
215 |
+
insights.extend(numeric_insights)
|
216 |
+
|
217 |
+
if 'Categorical Columns Summary:' in sheet_context:
|
218 |
+
categorical_insights = self._extract_categorical_insights(sheet_context)
|
219 |
+
if categorical_insights:
|
220 |
+
insights.extend(categorical_insights)
|
221 |
+
|
222 |
+
return insights
|
223 |
+
|
224 |
+
def _generate_relationship_insights(self, context_docs: List[str]) -> List[str]:
|
225 |
+
"""Generate insights about relationships between sheets"""
|
226 |
+
insights = []
|
227 |
+
|
228 |
+
for doc in context_docs:
|
229 |
+
relationship_context = self._extract_relationship_context(doc)
|
230 |
+
if relationship_context:
|
231 |
+
# Process and format relationship information
|
232 |
+
relationships = relationship_context.split('\n')[1:] # Skip header
|
233 |
+
for rel in relationships:
|
234 |
+
if rel.strip():
|
235 |
+
insights.append(f"- {rel.strip()}")
|
236 |
+
|
237 |
+
return insights
|
src/agents/rag_agent.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2 |
from typing import List, Optional, Tuple, Dict
|
3 |
import uuid
|
4 |
|
|
|
5 |
from ..llms.base_llm import BaseLLM
|
6 |
from src.embeddings.base_embedding import BaseEmbedding
|
7 |
from src.vectorstores.base_vectorstore import BaseVectorStore
|
@@ -10,7 +11,7 @@ from src.db.mongodb_store import MongoDBStore
|
|
10 |
from src.models.rag import RAGResponse
|
11 |
from src.utils.logger import logger
|
12 |
|
13 |
-
class RAGAgent:
|
14 |
def __init__(
|
15 |
self,
|
16 |
llm: BaseLLM,
|
@@ -31,6 +32,7 @@ class RAGAgent:
|
|
31 |
max_history_tokens (int): Maximum tokens in conversation history
|
32 |
max_history_messages (int): Maximum messages to keep in history
|
33 |
"""
|
|
|
34 |
self.llm = llm
|
35 |
self.embedding = embedding
|
36 |
self.vector_store = vector_store
|
@@ -77,6 +79,15 @@ class RAGAgent:
|
|
77 |
sources = None
|
78 |
scores = None
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
# Generate prompt with context and history
|
81 |
augmented_prompt = self.conversation_manager.generate_prompt_with_history(
|
82 |
current_query=query,
|
@@ -84,13 +95,25 @@ class RAGAgent:
|
|
84 |
context_docs=context_docs
|
85 |
)
|
86 |
|
87 |
-
# Generate response using LLM
|
88 |
response = self.llm.generate(
|
89 |
augmented_prompt,
|
90 |
temperature=temperature,
|
91 |
max_tokens=max_tokens
|
92 |
)
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
return RAGResponse(
|
95 |
response=response,
|
96 |
context_docs=context_docs,
|
|
|
2 |
from typing import List, Optional, Tuple, Dict
|
3 |
import uuid
|
4 |
|
5 |
+
from .excel_aware_rag import ExcelAwareRAGAgent
|
6 |
from ..llms.base_llm import BaseLLM
|
7 |
from src.embeddings.base_embedding import BaseEmbedding
|
8 |
from src.vectorstores.base_vectorstore import BaseVectorStore
|
|
|
11 |
from src.models.rag import RAGResponse
|
12 |
from src.utils.logger import logger
|
13 |
|
14 |
+
class RAGAgent(ExcelAwareRAGAgent):
|
15 |
def __init__(
|
16 |
self,
|
17 |
llm: BaseLLM,
|
|
|
32 |
max_history_tokens (int): Maximum tokens in conversation history
|
33 |
max_history_messages (int): Maximum messages to keep in history
|
34 |
"""
|
35 |
+
super().__init__() # Initialize ExcelAwareRAGAgent
|
36 |
self.llm = llm
|
37 |
self.embedding = embedding
|
38 |
self.vector_store = vector_store
|
|
|
79 |
sources = None
|
80 |
scores = None
|
81 |
|
82 |
+
# Check if this is an Excel-related query and enhance context if needed
|
83 |
+
has_excel_content = any('Sheet:' in doc for doc in (context_docs or []))
|
84 |
+
if has_excel_content:
|
85 |
+
try:
|
86 |
+
context_docs = self._process_excel_context(context_docs, query)
|
87 |
+
except Exception as e:
|
88 |
+
logger.warning(f"Error processing Excel context: {str(e)}")
|
89 |
+
# Continue with original context if Excel processing fails
|
90 |
+
|
91 |
# Generate prompt with context and history
|
92 |
augmented_prompt = self.conversation_manager.generate_prompt_with_history(
|
93 |
current_query=query,
|
|
|
95 |
context_docs=context_docs
|
96 |
)
|
97 |
|
98 |
+
# Generate initial response using LLM
|
99 |
response = self.llm.generate(
|
100 |
augmented_prompt,
|
101 |
temperature=temperature,
|
102 |
max_tokens=max_tokens
|
103 |
)
|
104 |
|
105 |
+
# Enhance response for Excel queries if applicable
|
106 |
+
if has_excel_content:
|
107 |
+
try:
|
108 |
+
response = await self.enhance_excel_response(
|
109 |
+
query=query,
|
110 |
+
response=response,
|
111 |
+
context_docs=context_docs
|
112 |
+
)
|
113 |
+
except Exception as e:
|
114 |
+
logger.warning(f"Error enhancing Excel response: {str(e)}")
|
115 |
+
# Continue with original response if enhancement fails
|
116 |
+
|
117 |
return RAGResponse(
|
118 |
response=response,
|
119 |
context_docs=context_docs,
|
src/db/__pycache__/mongodb_store.cpython-312.pyc
CHANGED
Binary files a/src/db/__pycache__/mongodb_store.cpython-312.pyc and b/src/db/__pycache__/mongodb_store.cpython-312.pyc differ
|
|
src/db/mongodb_store.py
CHANGED
@@ -62,14 +62,53 @@ class MongoDBStore:
|
|
62 |
"""Delete document from MongoDB"""
|
63 |
result = await self.documents.delete_one({"document_id": document_id})
|
64 |
return result.deleted_count > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
# Conversation and chat history methods
|
67 |
async def create_conversation(
|
68 |
self,
|
69 |
conversation_id: str,
|
70 |
-
metadata: Optional[Dict] = None
|
|
|
|
|
|
|
71 |
) -> str:
|
72 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
conversation = {
|
74 |
"conversation_id": conversation_id,
|
75 |
"created_at": datetime.now(),
|
@@ -77,7 +116,15 @@ class MongoDBStore:
|
|
77 |
"message_count": 0,
|
78 |
"metadata": metadata or {}
|
79 |
}
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
await self.conversations.insert_one(conversation)
|
82 |
return conversation_id
|
83 |
|
|
|
62 |
"""Delete document from MongoDB"""
|
63 |
result = await self.documents.delete_one({"document_id": document_id})
|
64 |
return result.deleted_count > 0
|
65 |
+
|
66 |
+
async def find_existing_user(
|
67 |
+
self,
|
68 |
+
email: str,
|
69 |
+
phone_number: str
|
70 |
+
) -> Optional[str]:
|
71 |
+
"""
|
72 |
+
Find existing user by email or phone number
|
73 |
+
|
74 |
+
Args:
|
75 |
+
email (str): User's email
|
76 |
+
phone_number (str): User's phone number
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
Optional[str]: Conversation ID if found, None otherwise
|
80 |
+
"""
|
81 |
+
result = await self.conversations.find_one({
|
82 |
+
"$or": [
|
83 |
+
{"email": email},
|
84 |
+
{"phone_number": phone_number}
|
85 |
+
]
|
86 |
+
})
|
87 |
+
|
88 |
+
return result["conversation_id"] if result else None
|
89 |
|
90 |
# Conversation and chat history methods
|
91 |
async def create_conversation(
|
92 |
self,
|
93 |
conversation_id: str,
|
94 |
+
metadata: Optional[Dict] = None,
|
95 |
+
full_name: Optional[str] = None,
|
96 |
+
email: Optional[str] = None,
|
97 |
+
phone_number: Optional[str] = None
|
98 |
) -> str:
|
99 |
+
"""
|
100 |
+
Create a new conversation
|
101 |
+
|
102 |
+
Args:
|
103 |
+
conversation_id (str): Unique conversation ID
|
104 |
+
metadata (Optional[Dict]): Additional metadata
|
105 |
+
full_name (Optional[str]): User's full name
|
106 |
+
email (Optional[str]): User's email
|
107 |
+
phone_number (Optional[str]): User's phone number
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
str: Conversation ID
|
111 |
+
"""
|
112 |
conversation = {
|
113 |
"conversation_id": conversation_id,
|
114 |
"created_at": datetime.now(),
|
|
|
116 |
"message_count": 0,
|
117 |
"metadata": metadata or {}
|
118 |
}
|
119 |
+
|
120 |
+
# Add user information if provided
|
121 |
+
if full_name:
|
122 |
+
conversation["full_name"] = full_name
|
123 |
+
if email:
|
124 |
+
conversation["email"] = email
|
125 |
+
if phone_number:
|
126 |
+
conversation["phone_number"] = phone_number
|
127 |
+
|
128 |
await self.conversations.insert_one(conversation)
|
129 |
return conversation_id
|
130 |
|
src/main.py
CHANGED
@@ -12,6 +12,7 @@ import os
|
|
12 |
# Import custom modules1
|
13 |
from src.agents.rag_agent import RAGAgent
|
14 |
from src.models.document import AllDocumentsResponse, StoredDocument
|
|
|
15 |
from src.utils.document_processor import DocumentProcessor
|
16 |
from src.utils.conversation_summarizer import ConversationSummarizer
|
17 |
from src.utils.logger import logger
|
@@ -21,12 +22,15 @@ from src.implementations.document_service import DocumentService
|
|
21 |
from src.models import (
|
22 |
ChatRequest,
|
23 |
ChatResponse,
|
24 |
-
DocumentResponse,
|
25 |
BatchUploadResponse,
|
26 |
SummarizeRequest,
|
27 |
SummaryResponse,
|
28 |
FeedbackRequest
|
29 |
)
|
|
|
|
|
|
|
|
|
30 |
from config.config import settings
|
31 |
|
32 |
app = FastAPI(title="Chatbot API")
|
@@ -54,6 +58,18 @@ UPLOADS_DIR.mkdir(exist_ok=True)
|
|
54 |
# Mount the uploads directory for static file serving
|
55 |
app.mount("/docs", StaticFiles(directory=str(UPLOADS_DIR)), name="documents")
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
@app.get("/documents")
|
58 |
async def get_all_documents():
|
59 |
"""Get all documents from MongoDB"""
|
@@ -190,18 +206,72 @@ async def delete_document(document_id: str):
|
|
190 |
except Exception as e:
|
191 |
logger.error(f"Error in delete_document endpoint: {str(e)}")
|
192 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
@app.post("/chat", response_model=ChatResponse)
|
195 |
async def chat_endpoint(
|
196 |
request: ChatRequest,
|
197 |
background_tasks: BackgroundTasks
|
198 |
):
|
199 |
-
"""Chat endpoint with RAG support"""
|
200 |
try:
|
|
|
|
|
201 |
vector_store, embedding_model = await get_vector_store()
|
|
|
|
|
202 |
llm = get_llm_instance(request.llm_provider)
|
203 |
|
204 |
-
# Initialize RAG agent
|
205 |
rag_agent = RAGAgent(
|
206 |
llm=llm,
|
207 |
embedding=embedding_model,
|
@@ -211,14 +281,69 @@ async def chat_endpoint(
|
|
211 |
|
212 |
# Use provided conversation ID or create new one
|
213 |
conversation_id = request.conversation_id or str(uuid.uuid4())
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
# Generate response
|
216 |
-
response
|
217 |
-
query=query,
|
218 |
-
conversation_id=conversation_id,
|
219 |
-
temperature=request.temperature
|
220 |
-
)
|
221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
# Store message in chat history
|
223 |
await mongodb.store_message(
|
224 |
conversation_id=conversation_id,
|
@@ -228,19 +353,32 @@ async def chat_endpoint(
|
|
228 |
sources=response.sources,
|
229 |
llm_provider=request.llm_provider
|
230 |
)
|
231 |
-
|
232 |
-
return
|
|
|
233 |
response=response.response,
|
234 |
context=response.context_docs,
|
235 |
sources=response.sources,
|
236 |
conversation_id=conversation_id,
|
237 |
timestamp=datetime.now(),
|
238 |
-
relevant_doc_scores=response.scores if hasattr(response, 'scores') else None
|
|
|
239 |
)
|
|
|
|
|
|
|
240 |
|
|
|
|
|
241 |
except Exception as e:
|
242 |
-
logger.error(f"Error in chat endpoint: {str(e)}")
|
243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
@app.get("/chat/history/{conversation_id}")
|
246 |
async def get_conversation_history(conversation_id: str):
|
@@ -347,6 +485,27 @@ async def debug_config():
|
|
347 |
|
348 |
return debug_info
|
349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
@app.get("/health")
|
351 |
async def health_check():
|
352 |
"""Health check endpoint"""
|
|
|
12 |
# Import custom modules1
|
13 |
from src.agents.rag_agent import RAGAgent
|
14 |
from src.models.document import AllDocumentsResponse, StoredDocument
|
15 |
+
from src.models.UserContact import UserContactRequest
|
16 |
from src.utils.document_processor import DocumentProcessor
|
17 |
from src.utils.conversation_summarizer import ConversationSummarizer
|
18 |
from src.utils.logger import logger
|
|
|
22 |
from src.models import (
|
23 |
ChatRequest,
|
24 |
ChatResponse,
|
|
|
25 |
BatchUploadResponse,
|
26 |
SummarizeRequest,
|
27 |
SummaryResponse,
|
28 |
FeedbackRequest
|
29 |
)
|
30 |
+
from fastapi import HTTPException, Depends
|
31 |
+
from fastapi.security import APIKeyHeader
|
32 |
+
from src.utils.database_cleanup import perform_cleanup
|
33 |
+
|
34 |
from config.config import settings
|
35 |
|
36 |
app = FastAPI(title="Chatbot API")
|
|
|
58 |
# Mount the uploads directory for static file serving
|
59 |
app.mount("/docs", StaticFiles(directory=str(UPLOADS_DIR)), name="documents")
|
60 |
|
61 |
+
# Security setup
|
62 |
+
API_KEY_HEADER = APIKeyHeader(name="ADMIN_API_KEY")
|
63 |
+
|
64 |
+
async def verify_api_key(api_key: str = Depends(API_KEY_HEADER)):
|
65 |
+
"""Verify admin API key"""
|
66 |
+
if not settings.ADMIN_API_KEY or api_key != settings.ADMIN_API_KEY:
|
67 |
+
raise HTTPException(
|
68 |
+
status_code=403,
|
69 |
+
detail="Invalid or missing API key"
|
70 |
+
)
|
71 |
+
return api_key
|
72 |
+
|
73 |
@app.get("/documents")
|
74 |
async def get_all_documents():
|
75 |
"""Get all documents from MongoDB"""
|
|
|
206 |
except Exception as e:
|
207 |
logger.error(f"Error in delete_document endpoint: {str(e)}")
|
208 |
raise HTTPException(status_code=500, detail=str(e))
|
209 |
+
|
210 |
+
# src/main.py
|
211 |
+
|
212 |
+
@app.post("/user/contact", response_model=ChatResponse)
|
213 |
+
async def create_user_contact(
|
214 |
+
request: UserContactRequest,
|
215 |
+
background_tasks: BackgroundTasks
|
216 |
+
):
|
217 |
+
"""Create or retrieve user conversation based on contact information"""
|
218 |
+
try:
|
219 |
+
# Check for existing user
|
220 |
+
existing_conversation_id = await mongodb.find_existing_user(
|
221 |
+
email=request.email,
|
222 |
+
phone_number=request.phone_number
|
223 |
+
)
|
224 |
+
|
225 |
+
if existing_conversation_id:
|
226 |
+
chat_request = ChatRequest(
|
227 |
+
query=f'An old user with name: "{request.full_name}", email: "{request.email}" and phone number: "{request.phone_number}" wants support again. Create a welcome back message for him and ask how i can help you today?',
|
228 |
+
llm_provider="openai",
|
229 |
+
max_context_docs=3,
|
230 |
+
temperature=1.0,
|
231 |
+
stream=False,
|
232 |
+
conversation_id=existing_conversation_id
|
233 |
+
)
|
234 |
+
else:
|
235 |
+
# Create new conversation with user information
|
236 |
+
new_conversation_id = str(uuid.uuid4())
|
237 |
+
await mongodb.create_conversation(
|
238 |
+
conversation_id=new_conversation_id,
|
239 |
+
full_name=request.full_name,
|
240 |
+
email=request.email,
|
241 |
+
phone_number=request.phone_number
|
242 |
+
)
|
243 |
+
|
244 |
+
chat_request = ChatRequest(
|
245 |
+
query=f'A new user with name: "{request.full_name}", email: "{request.email}" and phone number: "{request.phone_number}" wants support. Create a welcome message for him and ask how i can help you today?',
|
246 |
+
llm_provider="openai",
|
247 |
+
max_context_docs=3,
|
248 |
+
temperature=1.0,
|
249 |
+
stream=False,
|
250 |
+
conversation_id=new_conversation_id
|
251 |
+
)
|
252 |
+
|
253 |
+
# Call chat_endpoint with the prepared request
|
254 |
+
return await chat_endpoint(chat_request, background_tasks)
|
255 |
+
|
256 |
+
except Exception as e:
|
257 |
+
logger.error(f"Error in create_user_contact: {str(e)}")
|
258 |
+
raise HTTPException(status_code=500, detail=str(e))
|
259 |
|
260 |
@app.post("/chat", response_model=ChatResponse)
|
261 |
async def chat_endpoint(
|
262 |
request: ChatRequest,
|
263 |
background_tasks: BackgroundTasks
|
264 |
):
|
265 |
+
"""Chat endpoint with RAG support and enhanced Excel handling"""
|
266 |
try:
|
267 |
+
# Initialize core components
|
268 |
+
logger.info(f"Initializing vector store and embedding: {str(datetime.now())}")
|
269 |
vector_store, embedding_model = await get_vector_store()
|
270 |
+
|
271 |
+
logger.info(f"Initializing LLM: {str(datetime.now())}")
|
272 |
llm = get_llm_instance(request.llm_provider)
|
273 |
|
274 |
+
# Initialize RAG agent
|
275 |
rag_agent = RAGAgent(
|
276 |
llm=llm,
|
277 |
embedding=embedding_model,
|
|
|
281 |
|
282 |
# Use provided conversation ID or create new one
|
283 |
conversation_id = request.conversation_id or str(uuid.uuid4())
|
284 |
+
|
285 |
+
# Process the query
|
286 |
+
query = request.query
|
287 |
+
|
288 |
+
# Add specific instructions for certain types of queries
|
289 |
+
#if "introduce" in query.lower() or "name" in query.lower() or "email" in query.lower():
|
290 |
+
query += ". The response should be short and to the point. Make sure to not add any irrelevant information. Keep the introduction concise and friendly."
|
291 |
+
|
292 |
# Generate response
|
293 |
+
logger.info(f"Generating response: {str(datetime.now())}")
|
|
|
|
|
|
|
|
|
294 |
|
295 |
+
max_retries = 3
|
296 |
+
retry_count = 0
|
297 |
+
response = None
|
298 |
+
last_error = None
|
299 |
+
|
300 |
+
while retry_count < max_retries and response is None:
|
301 |
+
try:
|
302 |
+
response = await rag_agent.generate_response(
|
303 |
+
query=query,
|
304 |
+
conversation_id=conversation_id,
|
305 |
+
temperature=request.temperature,
|
306 |
+
max_tokens=request.max_tokens if hasattr(request, 'max_tokens') else None
|
307 |
+
)
|
308 |
+
break
|
309 |
+
except Exception as e:
|
310 |
+
last_error = e
|
311 |
+
retry_count += 1
|
312 |
+
logger.warning(f"Attempt {retry_count} failed: {str(e)}")
|
313 |
+
await asyncio.sleep(1) # Brief pause before retry
|
314 |
+
|
315 |
+
if response is None:
|
316 |
+
raise last_error or Exception("Failed to generate response after retries")
|
317 |
+
|
318 |
+
logger.info(f"Response generated: {str(datetime.now())}")
|
319 |
+
|
320 |
+
# Prepare response metadata
|
321 |
+
metadata = {
|
322 |
+
'llm_provider': request.llm_provider,
|
323 |
+
'temperature': request.temperature,
|
324 |
+
'conversation_id': conversation_id
|
325 |
+
}
|
326 |
+
|
327 |
+
# Add Excel-specific metadata if present
|
328 |
+
has_excel_content = any(
|
329 |
+
doc and 'Sheet:' in doc
|
330 |
+
for doc in (response.context_docs or [])
|
331 |
+
)
|
332 |
+
if has_excel_content:
|
333 |
+
try:
|
334 |
+
metadata['excel_content'] = True
|
335 |
+
|
336 |
+
# Extract Excel-specific insights if available
|
337 |
+
if hasattr(rag_agent, 'get_excel_insights'):
|
338 |
+
excel_insights = rag_agent.get_excel_insights(
|
339 |
+
query=query,
|
340 |
+
context_docs=response.context_docs
|
341 |
+
)
|
342 |
+
if excel_insights:
|
343 |
+
metadata['excel_insights'] = excel_insights
|
344 |
+
except Exception as e:
|
345 |
+
logger.warning(f"Error processing Excel metadata: {str(e)}")
|
346 |
+
|
347 |
# Store message in chat history
|
348 |
await mongodb.store_message(
|
349 |
conversation_id=conversation_id,
|
|
|
353 |
sources=response.sources,
|
354 |
llm_provider=request.llm_provider
|
355 |
)
|
356 |
+
|
357 |
+
# Prepare and return response
|
358 |
+
chat_response = ChatResponse(
|
359 |
response=response.response,
|
360 |
context=response.context_docs,
|
361 |
sources=response.sources,
|
362 |
conversation_id=conversation_id,
|
363 |
timestamp=datetime.now(),
|
364 |
+
relevant_doc_scores=response.scores if hasattr(response, 'scores') else None,
|
365 |
+
metadata=metadata
|
366 |
)
|
367 |
+
|
368 |
+
# Log completion
|
369 |
+
logger.info(f"Chat response completed: {str(datetime.now())}")
|
370 |
|
371 |
+
return chat_response
|
372 |
+
|
373 |
except Exception as e:
|
374 |
+
logger.error(f"Error in chat endpoint: {str(e)}", exc_info=True)
|
375 |
+
# Convert known errors to HTTPException with appropriate status codes
|
376 |
+
if isinstance(e, ValueError):
|
377 |
+
raise HTTPException(status_code=400, detail=str(e))
|
378 |
+
elif isinstance(e, (KeyError, AttributeError)):
|
379 |
+
raise HTTPException(status_code=500, detail="Internal processing error")
|
380 |
+
else:
|
381 |
+
raise HTTPException(status_code=500, detail=str(e))
|
382 |
|
383 |
@app.get("/chat/history/{conversation_id}")
|
384 |
async def get_conversation_history(conversation_id: str):
|
|
|
485 |
|
486 |
return debug_info
|
487 |
|
488 |
+
@app.post("/admin/cleanup")
|
489 |
+
async def cleanup_databases(
|
490 |
+
include_files: bool = True,
|
491 |
+
api_key: str = Depends(verify_api_key)
|
492 |
+
):
|
493 |
+
"""
|
494 |
+
Clean up all data from ChromaDB and MongoDB
|
495 |
+
|
496 |
+
Args:
|
497 |
+
include_files (bool): Whether to also delete uploaded files
|
498 |
+
"""
|
499 |
+
try:
|
500 |
+
result = await perform_cleanup(mongodb, include_files)
|
501 |
+
return result
|
502 |
+
except Exception as e:
|
503 |
+
logger.error(f"Error in cleanup operation: {str(e)}")
|
504 |
+
raise HTTPException(
|
505 |
+
status_code=500,
|
506 |
+
detail=f"Error during cleanup: {str(e)}"
|
507 |
+
)
|
508 |
+
|
509 |
@app.get("/health")
|
510 |
async def health_check():
|
511 |
"""Health check endpoint"""
|
src/models/UserContact.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, EmailStr, validator
|
2 |
+
import re
|
3 |
+
|
4 |
+
class UserContactRequest(BaseModel):
|
5 |
+
"""Request model for user contact information"""
|
6 |
+
full_name: str
|
7 |
+
email: EmailStr
|
8 |
+
phone_number: str
|
9 |
+
|
10 |
+
@validator('phone_number')
|
11 |
+
def validate_phone(cls, v):
|
12 |
+
# Remove any non-digit characters
|
13 |
+
phone = re.sub(r'\D', '', v)
|
14 |
+
if not 8 <= len(phone) <= 15: # Standard phone number length globally
|
15 |
+
raise ValueError('Invalid phone number length')
|
16 |
+
return phone
|
17 |
+
|
18 |
+
@validator('full_name')
|
19 |
+
def validate_name(cls, v):
|
20 |
+
if not v.strip():
|
21 |
+
raise ValueError('Name cannot be empty')
|
22 |
+
return v.strip()
|
23 |
+
|
24 |
+
class UserContactResponse(BaseModel):
|
25 |
+
"""Response model for user contact endpoint"""
|
26 |
+
conversation_id: str
|
27 |
+
is_existing: bool
|
28 |
+
message: str
|
src/models/__pycache__/UserContact.cpython-312.pyc
ADDED
Binary file (1.79 kB). View file
|
|
src/utils/__pycache__/database_cleanup.cpython-312.pyc
ADDED
Binary file (7.04 kB). View file
|
|
src/utils/__pycache__/document_processor.cpython-312.pyc
CHANGED
Binary files a/src/utils/__pycache__/document_processor.cpython-312.pyc and b/src/utils/__pycache__/document_processor.cpython-312.pyc differ
|
|
src/utils/__pycache__/enhanced_excel_processor.cpython-312.pyc
ADDED
Binary file (10.5 kB). View file
|
|
src/utils/__pycache__/llm_utils.cpython-312.pyc
CHANGED
Binary files a/src/utils/__pycache__/llm_utils.cpython-312.pyc and b/src/utils/__pycache__/llm_utils.cpython-312.pyc differ
|
|
src/utils/database_cleanup.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/utils/database_cleanup.py
|
2 |
+
from typing import List, Dict
|
3 |
+
import chromadb
|
4 |
+
import shutil
|
5 |
+
from pathlib import Path
|
6 |
+
from src.utils.logger import logger
|
7 |
+
from config.config import settings
|
8 |
+
|
9 |
+
async def cleanup_chroma():
|
10 |
+
"""Clean up ChromaDB vector store"""
|
11 |
+
try:
|
12 |
+
# Initialize client with allow_reset=True
|
13 |
+
client = chromadb.PersistentClient(
|
14 |
+
path=settings.CHROMA_PATH,
|
15 |
+
settings=chromadb.Settings(
|
16 |
+
allow_reset=True,
|
17 |
+
is_persistent=True
|
18 |
+
)
|
19 |
+
)
|
20 |
+
|
21 |
+
# Get collection names
|
22 |
+
collection_names = client.list_collections()
|
23 |
+
|
24 |
+
# Delete each collection by name
|
25 |
+
for name in collection_names:
|
26 |
+
client.delete_collection(name)
|
27 |
+
|
28 |
+
# Reset client
|
29 |
+
client.reset()
|
30 |
+
|
31 |
+
# Remove persistence directory
|
32 |
+
path = Path(settings.CHROMA_PATH)
|
33 |
+
if path.exists():
|
34 |
+
shutil.rmtree(path)
|
35 |
+
|
36 |
+
return ["All vector store data cleared"]
|
37 |
+
except Exception as e:
|
38 |
+
raise Exception(f"ChromaDB cleanup failed: {str(e)}")
|
39 |
+
|
40 |
+
async def cleanup_mongodb(mongodb) -> List[str]:
|
41 |
+
"""
|
42 |
+
Clean up MongoDB collections
|
43 |
+
|
44 |
+
Args:
|
45 |
+
mongodb: MongoDB store instance
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
List[str]: Details of cleanup operations
|
49 |
+
"""
|
50 |
+
details = []
|
51 |
+
|
52 |
+
try:
|
53 |
+
# Drop all collections
|
54 |
+
await mongodb.chat_history.delete_many({})
|
55 |
+
details.append("Cleared chat history")
|
56 |
+
|
57 |
+
await mongodb.conversations.delete_many({})
|
58 |
+
details.append("Cleared conversations")
|
59 |
+
|
60 |
+
await mongodb.documents.delete_many({})
|
61 |
+
details.append("Cleared document metadata")
|
62 |
+
|
63 |
+
await mongodb.knowledge_base.delete_many({})
|
64 |
+
details.append("Cleared knowledge base")
|
65 |
+
|
66 |
+
if hasattr(mongodb.db, 'vector_metadata'):
|
67 |
+
await mongodb.db.vector_metadata.delete_many({})
|
68 |
+
details.append("Cleared vector metadata")
|
69 |
+
|
70 |
+
return details
|
71 |
+
except Exception as e:
|
72 |
+
raise Exception(f"MongoDB cleanup failed: {str(e)}")
|
73 |
+
|
74 |
+
async def cleanup_files() -> List[str]:
|
75 |
+
"""
|
76 |
+
Clean up uploaded files
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
List[str]: Details of cleanup operations
|
80 |
+
"""
|
81 |
+
details = []
|
82 |
+
uploads_dir = Path("uploads")
|
83 |
+
|
84 |
+
if uploads_dir.exists():
|
85 |
+
# Get list of files before deletion
|
86 |
+
files = list(uploads_dir.glob('*'))
|
87 |
+
|
88 |
+
# Delete all files
|
89 |
+
for file in files:
|
90 |
+
if file.is_file():
|
91 |
+
file.unlink()
|
92 |
+
details.append(f"Deleted file: {file.name}")
|
93 |
+
|
94 |
+
# Try to remove the directory itself
|
95 |
+
if not any(uploads_dir.iterdir()):
|
96 |
+
uploads_dir.rmdir()
|
97 |
+
details.append("Removed empty uploads directory")
|
98 |
+
else:
|
99 |
+
details.append("No uploads directory found")
|
100 |
+
|
101 |
+
return details
|
102 |
+
|
103 |
+
async def perform_cleanup(
|
104 |
+
mongodb,
|
105 |
+
include_files: bool = True
|
106 |
+
) -> Dict:
|
107 |
+
"""
|
108 |
+
Perform comprehensive cleanup of all databases
|
109 |
+
|
110 |
+
Args:
|
111 |
+
mongodb: MongoDB store instance
|
112 |
+
include_files (bool): Whether to also delete uploaded files
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
Dict: Cleanup operation summary
|
116 |
+
"""
|
117 |
+
cleanup_summary = {
|
118 |
+
"chroma_db": {"status": "not_started", "details": []},
|
119 |
+
"mongodb": {"status": "not_started", "details": []},
|
120 |
+
"files": {"status": "not_started", "details": []}
|
121 |
+
}
|
122 |
+
|
123 |
+
try:
|
124 |
+
# Clean ChromaDB
|
125 |
+
try:
|
126 |
+
details = await cleanup_chroma()
|
127 |
+
cleanup_summary["chroma_db"] = {
|
128 |
+
"status": "success",
|
129 |
+
"details": details
|
130 |
+
}
|
131 |
+
except Exception as e:
|
132 |
+
logger.error(f"Error cleaning ChromaDB: {str(e)}")
|
133 |
+
cleanup_summary["chroma_db"] = {
|
134 |
+
"status": "error",
|
135 |
+
"details": [str(e)]
|
136 |
+
}
|
137 |
+
|
138 |
+
# Clean MongoDB
|
139 |
+
try:
|
140 |
+
details = await cleanup_mongodb(mongodb)
|
141 |
+
cleanup_summary["mongodb"] = {
|
142 |
+
"status": "success",
|
143 |
+
"details": details
|
144 |
+
}
|
145 |
+
except Exception as e:
|
146 |
+
logger.error(f"Error cleaning MongoDB: {str(e)}")
|
147 |
+
cleanup_summary["mongodb"] = {
|
148 |
+
"status": "error",
|
149 |
+
"details": [str(e)]
|
150 |
+
}
|
151 |
+
|
152 |
+
# Clean files if requested
|
153 |
+
if include_files:
|
154 |
+
try:
|
155 |
+
details = await cleanup_files()
|
156 |
+
cleanup_summary["files"] = {
|
157 |
+
"status": "success",
|
158 |
+
"details": details
|
159 |
+
}
|
160 |
+
except Exception as e:
|
161 |
+
logger.error(f"Error cleaning files: {str(e)}")
|
162 |
+
cleanup_summary["files"] = {
|
163 |
+
"status": "error",
|
164 |
+
"details": [str(e)]
|
165 |
+
}
|
166 |
+
|
167 |
+
# Determine overall status
|
168 |
+
overall_status = "success"
|
169 |
+
if any(item["status"] == "error" for item in cleanup_summary.values()):
|
170 |
+
overall_status = "partial_success"
|
171 |
+
if all(item["status"] == "error" for item in cleanup_summary.values()):
|
172 |
+
overall_status = "error"
|
173 |
+
|
174 |
+
return {
|
175 |
+
"status": overall_status,
|
176 |
+
"message": "Cleanup operation completed",
|
177 |
+
"details": cleanup_summary
|
178 |
+
}
|
179 |
+
|
180 |
+
except Exception as e:
|
181 |
+
logger.error(f"Error in cleanup operation: {str(e)}")
|
182 |
+
raise
|
src/utils/document_processor.py
CHANGED
@@ -8,13 +8,15 @@ from pathlib import Path
|
|
8 |
import hashlib
|
9 |
import magic # python-magic library for file type detection
|
10 |
from bs4 import BeautifulSoup
|
11 |
-
import requests
|
12 |
import csv
|
13 |
from datetime import datetime
|
14 |
import threading
|
15 |
from queue import Queue
|
16 |
import tiktoken
|
17 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
|
|
|
18 |
|
19 |
class DocumentProcessor:
|
20 |
def __init__(
|
@@ -29,11 +31,26 @@ class DocumentProcessor:
|
|
29 |
self.max_file_size = max_file_size
|
30 |
self.supported_formats = supported_formats or [
|
31 |
'.txt', '.pdf', '.docx', '.csv', '.json',
|
32 |
-
'.html', '.md', '.xml', '.rtf'
|
33 |
]
|
34 |
self.processing_queue = Queue()
|
35 |
self.processed_docs = {}
|
36 |
self._initialize_text_splitter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def _initialize_text_splitter(self):
|
39 |
"""Initialize the text splitter with custom settings"""
|
@@ -44,65 +61,10 @@ class DocumentProcessor:
|
|
44 |
separators=["\n\n", "\n", " ", ""]
|
45 |
)
|
46 |
|
47 |
-
async def process_document(
|
48 |
-
self,
|
49 |
-
file_path: Union[str, Path],
|
50 |
-
metadata: Optional[Dict] = None
|
51 |
-
) -> Dict:
|
52 |
-
"""
|
53 |
-
Process a document with metadata and content extraction
|
54 |
-
"""
|
55 |
-
file_path = Path(file_path)
|
56 |
-
|
57 |
-
# Basic validation
|
58 |
-
if not self._validate_file(file_path):
|
59 |
-
raise ValueError(f"Invalid file: {file_path}")
|
60 |
-
|
61 |
-
# Extract content based on file type
|
62 |
-
content = self._extract_content(file_path)
|
63 |
-
|
64 |
-
# Generate document metadata
|
65 |
-
doc_metadata = self._generate_metadata(file_path, content, metadata)
|
66 |
-
|
67 |
-
# Split content into chunks
|
68 |
-
chunks = self.text_splitter.split_text(content)
|
69 |
-
|
70 |
-
# Calculate embeddings chunk hashes
|
71 |
-
chunk_hashes = [self._calculate_hash(chunk) for chunk in chunks]
|
72 |
-
|
73 |
-
return {
|
74 |
-
'content': content,
|
75 |
-
'chunks': chunks,
|
76 |
-
'chunk_hashes': chunk_hashes,
|
77 |
-
'metadata': doc_metadata,
|
78 |
-
'statistics': self._generate_statistics(content, chunks)
|
79 |
-
}
|
80 |
-
|
81 |
-
def _validate_file(self, file_path: Path) -> bool:
|
82 |
-
"""
|
83 |
-
Validate file type, size, and content
|
84 |
-
"""
|
85 |
-
if not file_path.exists():
|
86 |
-
raise FileNotFoundError(f"File not found: {file_path}")
|
87 |
-
|
88 |
-
if file_path.suffix.lower() not in self.supported_formats:
|
89 |
-
raise ValueError(f"Unsupported file format: {file_path.suffix}")
|
90 |
-
|
91 |
-
if file_path.stat().st_size > self.max_file_size:
|
92 |
-
raise ValueError(f"File too large: {file_path}")
|
93 |
-
|
94 |
-
# Check if file is not empty
|
95 |
-
if file_path.stat().st_size == 0:
|
96 |
-
raise ValueError(f"Empty file: {file_path}")
|
97 |
-
|
98 |
-
return True
|
99 |
-
|
100 |
def _extract_content(self, file_path: Path) -> str:
|
101 |
-
"""
|
102 |
-
Extract content from different file formats
|
103 |
-
"""
|
104 |
suffix = file_path.suffix.lower()
|
105 |
-
|
106 |
try:
|
107 |
if suffix == '.pdf':
|
108 |
return self._extract_pdf(file_path)
|
@@ -114,13 +76,28 @@ class DocumentProcessor:
|
|
114 |
return self._extract_json(file_path)
|
115 |
elif suffix == '.html':
|
116 |
return self._extract_html(file_path)
|
117 |
-
elif suffix == '.txt':
|
118 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
else:
|
120 |
raise ValueError(f"Unsupported format: {suffix}")
|
121 |
except Exception as e:
|
122 |
raise Exception(f"Error extracting content from {file_path}: {str(e)}")
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
def _extract_pdf(self, file_path: Path) -> str:
|
125 |
"""Extract text from PDF with advanced features"""
|
126 |
text = ""
|
@@ -135,7 +112,6 @@ class DocumentProcessor:
|
|
135 |
if '/XObject' in page['/Resources']:
|
136 |
for obj in page['/Resources']['/XObject'].get_object():
|
137 |
if page['/Resources']['/XObject'][obj]['/Subtype'] == '/Image':
|
138 |
-
# Process images if needed
|
139 |
pass
|
140 |
|
141 |
return text.strip()
|
@@ -148,7 +124,6 @@ class DocumentProcessor:
|
|
148 |
for para in doc.paragraphs:
|
149 |
full_text.append(para.text)
|
150 |
|
151 |
-
# Extract tables if present
|
152 |
for table in doc.tables:
|
153 |
for row in table.rows:
|
154 |
row_text = [cell.text for cell in row.cells]
|
@@ -172,7 +147,6 @@ class DocumentProcessor:
|
|
172 |
with open(file_path) as f:
|
173 |
soup = BeautifulSoup(f, 'html.parser')
|
174 |
|
175 |
-
# Remove script and style elements
|
176 |
for script in soup(["script", "style"]):
|
177 |
script.decompose()
|
178 |
|
@@ -180,6 +154,83 @@ class DocumentProcessor:
|
|
180 |
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
181 |
return "\n\n".join(lines)
|
182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
def _generate_metadata(
|
184 |
self,
|
185 |
file_path: Path,
|
@@ -202,11 +253,64 @@ class DocumentProcessor:
|
|
202 |
'processing_timestamp': datetime.now().isoformat()
|
203 |
}
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
if additional_metadata:
|
206 |
metadata.update(additional_metadata)
|
207 |
|
208 |
return metadata
|
209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
def _generate_statistics(self, content: str, chunks: List[str]) -> Dict:
|
211 |
"""Generate document statistics"""
|
212 |
return {
|
@@ -217,18 +321,12 @@ class DocumentProcessor:
|
|
217 |
'sentences': len([s for s in content.split('.') if s.strip()]),
|
218 |
}
|
219 |
|
220 |
-
def _calculate_hash(self, text: str) -> str:
|
221 |
-
"""Calculate SHA-256 hash of text"""
|
222 |
-
return hashlib.sha256(text.encode()).hexdigest()
|
223 |
-
|
224 |
async def batch_process(
|
225 |
self,
|
226 |
file_paths: List[Union[str, Path]],
|
227 |
parallel: bool = True
|
228 |
) -> Dict[str, Dict]:
|
229 |
-
"""
|
230 |
-
Process multiple documents in parallel
|
231 |
-
"""
|
232 |
results = {}
|
233 |
|
234 |
if parallel:
|
|
|
8 |
import hashlib
|
9 |
import magic # python-magic library for file type detection
|
10 |
from bs4 import BeautifulSoup
|
|
|
11 |
import csv
|
12 |
from datetime import datetime
|
13 |
import threading
|
14 |
from queue import Queue
|
15 |
import tiktoken
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
+
import logging
|
18 |
+
from bs4.element import ProcessingInstruction
|
19 |
+
from .enhanced_excel_processor import EnhancedExcelProcessor
|
20 |
|
21 |
class DocumentProcessor:
|
22 |
def __init__(
|
|
|
31 |
self.max_file_size = max_file_size
|
32 |
self.supported_formats = supported_formats or [
|
33 |
'.txt', '.pdf', '.docx', '.csv', '.json',
|
34 |
+
'.html', '.md', '.xml', '.rtf', '.xlsx', '.xls'
|
35 |
]
|
36 |
self.processing_queue = Queue()
|
37 |
self.processed_docs = {}
|
38 |
self._initialize_text_splitter()
|
39 |
+
|
40 |
+
# Initialize Excel processor
|
41 |
+
self.excel_processor = EnhancedExcelProcessor()
|
42 |
+
|
43 |
+
# Check for required packages
|
44 |
+
try:
|
45 |
+
import striprtf.striprtf
|
46 |
+
except ImportError:
|
47 |
+
logging.warning("Warning: striprtf package not found. RTF support will be limited.")
|
48 |
+
|
49 |
+
try:
|
50 |
+
from bs4 import BeautifulSoup
|
51 |
+
import lxml
|
52 |
+
except ImportError:
|
53 |
+
logging.warning("Warning: beautifulsoup4 or lxml package not found. XML support will be limited.")
|
54 |
|
55 |
def _initialize_text_splitter(self):
|
56 |
"""Initialize the text splitter with custom settings"""
|
|
|
61 |
separators=["\n\n", "\n", " ", ""]
|
62 |
)
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def _extract_content(self, file_path: Path) -> str:
|
65 |
+
"""Extract content from different file formats"""
|
|
|
|
|
66 |
suffix = file_path.suffix.lower()
|
67 |
+
|
68 |
try:
|
69 |
if suffix == '.pdf':
|
70 |
return self._extract_pdf(file_path)
|
|
|
76 |
return self._extract_json(file_path)
|
77 |
elif suffix == '.html':
|
78 |
return self._extract_html(file_path)
|
79 |
+
elif suffix == '.txt' or suffix == '.md':
|
80 |
+
return self._extract_text(file_path)
|
81 |
+
elif suffix == '.xml':
|
82 |
+
return self._extract_xml(file_path)
|
83 |
+
elif suffix == '.rtf':
|
84 |
+
return self._extract_rtf(file_path)
|
85 |
+
elif suffix in ['.xlsx', '.xls']:
|
86 |
+
return self._extract_excel(file_path)
|
87 |
else:
|
88 |
raise ValueError(f"Unsupported format: {suffix}")
|
89 |
except Exception as e:
|
90 |
raise Exception(f"Error extracting content from {file_path}: {str(e)}")
|
91 |
|
92 |
+
def _extract_text(self, file_path: Path) -> str:
|
93 |
+
"""Extract content from text-based files"""
|
94 |
+
try:
|
95 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
96 |
+
return f.read()
|
97 |
+
except UnicodeDecodeError:
|
98 |
+
with open(file_path, 'r', encoding='latin-1') as f:
|
99 |
+
return f.read()
|
100 |
+
|
101 |
def _extract_pdf(self, file_path: Path) -> str:
|
102 |
"""Extract text from PDF with advanced features"""
|
103 |
text = ""
|
|
|
112 |
if '/XObject' in page['/Resources']:
|
113 |
for obj in page['/Resources']['/XObject'].get_object():
|
114 |
if page['/Resources']['/XObject'][obj]['/Subtype'] == '/Image':
|
|
|
115 |
pass
|
116 |
|
117 |
return text.strip()
|
|
|
124 |
for para in doc.paragraphs:
|
125 |
full_text.append(para.text)
|
126 |
|
|
|
127 |
for table in doc.tables:
|
128 |
for row in table.rows:
|
129 |
row_text = [cell.text for cell in row.cells]
|
|
|
147 |
with open(file_path) as f:
|
148 |
soup = BeautifulSoup(f, 'html.parser')
|
149 |
|
|
|
150 |
for script in soup(["script", "style"]):
|
151 |
script.decompose()
|
152 |
|
|
|
154 |
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
155 |
return "\n\n".join(lines)
|
156 |
|
157 |
+
def _extract_xml(self, file_path: Path) -> str:
|
158 |
+
"""Extract text from XML with structure preservation"""
|
159 |
+
try:
|
160 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
161 |
+
soup = BeautifulSoup(f, 'xml')
|
162 |
+
|
163 |
+
for pi in soup.find_all(text=lambda text: isinstance(text, ProcessingInstruction)):
|
164 |
+
pi.extract()
|
165 |
+
|
166 |
+
text = soup.get_text(separator='\n')
|
167 |
+
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
168 |
+
return "\n\n".join(lines)
|
169 |
+
except Exception as e:
|
170 |
+
raise Exception(f"Error processing XML file: {str(e)}")
|
171 |
+
|
172 |
+
def _extract_rtf(self, file_path: Path) -> str:
|
173 |
+
"""Extract text from RTF files"""
|
174 |
+
try:
|
175 |
+
import striprtf.striprtf as striprtf
|
176 |
+
|
177 |
+
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
178 |
+
rtf_text = f.read()
|
179 |
+
|
180 |
+
plain_text = striprtf.rtf_to_text(rtf_text)
|
181 |
+
lines = [line.strip() for line in plain_text.splitlines() if line.strip()]
|
182 |
+
return "\n\n".join(lines)
|
183 |
+
except ImportError:
|
184 |
+
raise ImportError("striprtf package is required for RTF support.")
|
185 |
+
except Exception as e:
|
186 |
+
raise Exception(f"Error processing RTF file: {str(e)}")
|
187 |
+
|
188 |
+
def _extract_excel(self, file_path: Path) -> str:
|
189 |
+
"""Extract content from Excel files with enhanced processing"""
|
190 |
+
try:
|
191 |
+
# Use enhanced Excel processor
|
192 |
+
processed_content = self.excel_processor.process_excel(file_path)
|
193 |
+
|
194 |
+
# If processing fails, fall back to basic processing
|
195 |
+
if not processed_content:
|
196 |
+
logging.warning(f"Enhanced Excel processing failed for {file_path}, falling back to basic processing")
|
197 |
+
return self._basic_excel_extract(file_path)
|
198 |
+
|
199 |
+
return processed_content
|
200 |
+
|
201 |
+
except Exception as e:
|
202 |
+
logging.error(f"Error in enhanced Excel processing: {str(e)}")
|
203 |
+
# Fall back to basic Excel processing
|
204 |
+
return self._basic_excel_extract(file_path)
|
205 |
+
|
206 |
+
def _basic_excel_extract(self, file_path: Path) -> str:
|
207 |
+
"""Basic Excel extraction as fallback"""
|
208 |
+
try:
|
209 |
+
excel_file = pd.ExcelFile(file_path)
|
210 |
+
sheets_data = []
|
211 |
+
|
212 |
+
for sheet_name in excel_file.sheet_names:
|
213 |
+
df = pd.read_excel(excel_file, sheet_name=sheet_name)
|
214 |
+
sheet_content = f"\nSheet: {sheet_name}\n"
|
215 |
+
sheet_content += "=" * (len(sheet_name) + 7) + "\n"
|
216 |
+
|
217 |
+
if df.empty:
|
218 |
+
sheet_content += "Empty Sheet\n"
|
219 |
+
else:
|
220 |
+
sheet_content += df.fillna('').to_string(
|
221 |
+
index=False,
|
222 |
+
max_rows=None,
|
223 |
+
max_cols=None,
|
224 |
+
line_width=120
|
225 |
+
) + "\n"
|
226 |
+
|
227 |
+
sheets_data.append(sheet_content)
|
228 |
+
|
229 |
+
return "\n\n".join(sheets_data)
|
230 |
+
|
231 |
+
except Exception as e:
|
232 |
+
raise Exception(f"Error in basic Excel processing: {str(e)}")
|
233 |
+
|
234 |
def _generate_metadata(
|
235 |
self,
|
236 |
file_path: Path,
|
|
|
253 |
'processing_timestamp': datetime.now().isoformat()
|
254 |
}
|
255 |
|
256 |
+
# Add Excel-specific metadata if applicable
|
257 |
+
if file_path.suffix.lower() in ['.xlsx', '.xls']:
|
258 |
+
try:
|
259 |
+
if hasattr(self.excel_processor, 'get_metadata'):
|
260 |
+
excel_metadata = self.excel_processor.get_metadata()
|
261 |
+
metadata.update({'excel_metadata': excel_metadata})
|
262 |
+
except Exception as e:
|
263 |
+
logging.warning(f"Could not extract Excel metadata: {str(e)}")
|
264 |
+
|
265 |
if additional_metadata:
|
266 |
metadata.update(additional_metadata)
|
267 |
|
268 |
return metadata
|
269 |
|
270 |
+
def _calculate_hash(self, text: str) -> str:
|
271 |
+
"""Calculate SHA-256 hash of text"""
|
272 |
+
return hashlib.sha256(text.encode()).hexdigest()
|
273 |
+
|
274 |
+
async def process_document(
|
275 |
+
self,
|
276 |
+
file_path: Union[str, Path],
|
277 |
+
metadata: Optional[Dict] = None
|
278 |
+
) -> Dict:
|
279 |
+
"""Process a document with metadata and content extraction"""
|
280 |
+
file_path = Path(file_path)
|
281 |
+
|
282 |
+
if not self._validate_file(file_path):
|
283 |
+
raise ValueError(f"Invalid file: {file_path}")
|
284 |
+
|
285 |
+
content = self._extract_content(file_path)
|
286 |
+
doc_metadata = self._generate_metadata(file_path, content, metadata)
|
287 |
+
chunks = self.text_splitter.split_text(content)
|
288 |
+
chunk_hashes = [self._calculate_hash(chunk) for chunk in chunks]
|
289 |
+
|
290 |
+
return {
|
291 |
+
'content': content,
|
292 |
+
'chunks': chunks,
|
293 |
+
'chunk_hashes': chunk_hashes,
|
294 |
+
'metadata': doc_metadata,
|
295 |
+
'statistics': self._generate_statistics(content, chunks)
|
296 |
+
}
|
297 |
+
|
298 |
+
def _validate_file(self, file_path: Path) -> bool:
|
299 |
+
"""Validate file type, size, and content"""
|
300 |
+
if not file_path.exists():
|
301 |
+
raise FileNotFoundError(f"File not found: {file_path}")
|
302 |
+
|
303 |
+
if file_path.suffix.lower() not in self.supported_formats:
|
304 |
+
raise ValueError(f"Unsupported file format: {file_path.suffix}")
|
305 |
+
|
306 |
+
if file_path.stat().st_size > self.max_file_size:
|
307 |
+
raise ValueError(f"File too large: {file_path}")
|
308 |
+
|
309 |
+
if file_path.stat().st_size == 0:
|
310 |
+
raise ValueError(f"Empty file: {file_path}")
|
311 |
+
|
312 |
+
return True
|
313 |
+
|
314 |
def _generate_statistics(self, content: str, chunks: List[str]) -> Dict:
|
315 |
"""Generate document statistics"""
|
316 |
return {
|
|
|
321 |
'sentences': len([s for s in content.split('.') if s.strip()]),
|
322 |
}
|
323 |
|
|
|
|
|
|
|
|
|
324 |
async def batch_process(
|
325 |
self,
|
326 |
file_paths: List[Union[str, Path]],
|
327 |
parallel: bool = True
|
328 |
) -> Dict[str, Dict]:
|
329 |
+
"""Process multiple documents in parallel"""
|
|
|
|
|
330 |
results = {}
|
331 |
|
332 |
if parallel:
|
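For orientation, a minimal usage sketch of the processing flow added above. It is not part of the commit: the DocumentProcessor class name and its no-argument constructor are assumptions from context, and the file path is illustrative.

# Hypothetical usage sketch - DocumentProcessor and its constructor are assumed;
# only the process_document / _validate_file behaviour shown in the diff is relied on.
import asyncio
from src.utils.document_processor import DocumentProcessor

async def main():
    processor = DocumentProcessor()
    result = await processor.process_document("reports/sales.xlsx")   # illustrative path
    print(result['statistics'])                        # counts from _generate_statistics
    print(result['metadata'].get('excel_metadata'))    # present only for .xlsx/.xls inputs
    print(len(result['chunks']), "chunks")

asyncio.run(main())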
src/utils/enhanced_excel_processor.py
ADDED
@@ -0,0 +1,187 @@
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np
from pathlib import Path
import json

class EnhancedExcelProcessor:
    def __init__(self):
        """Initialize the enhanced Excel processor"""
        self.sheet_summaries = {}
        self.relationships = {}
        self.sheet_metadata = {}

    def process_excel(self, file_path: Path) -> str:
        """
        Process Excel file with enhanced multi-sheet handling

        Args:
            file_path (Path): Path to Excel file

        Returns:
            str: Structured text representation of Excel content
        """
        # Read all sheets
        excel_file = pd.ExcelFile(file_path)
        sheets_data = {}

        for sheet_name in excel_file.sheet_names:
            df = pd.read_excel(excel_file, sheet_name=sheet_name)
            sheets_data[sheet_name] = df

            # Generate sheet summary
            self.sheet_summaries[sheet_name] = self._generate_sheet_summary(df)

            # Extract sheet metadata
            self.sheet_metadata[sheet_name] = {
                'columns': list(df.columns),
                'rows': len(df),
                'numeric_columns': df.select_dtypes(include=[np.number]).columns.tolist(),
                'date_columns': df.select_dtypes(include=['datetime64']).columns.tolist(),
                'categorical_columns': df.select_dtypes(include=['object']).columns.tolist()
            }

        # Detect relationships between sheets
        self.relationships = self._detect_relationships(sheets_data)

        # Generate structured text representation
        return self._generate_structured_text(sheets_data)

    def _generate_sheet_summary(self, df: pd.DataFrame) -> Dict:
        """Generate statistical summary for a sheet"""
        summary = {
            'total_rows': len(df),
            'total_columns': len(df.columns),
            'column_types': {},
            'numeric_summaries': {},
            'categorical_summaries': {},
            'null_counts': df.isnull().sum().to_dict()
        }

        # Process numeric columns
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        for col in numeric_cols:
            summary['numeric_summaries'][col] = {
                'mean': float(df[col].mean()),
                'median': float(df[col].median()),
                'std': float(df[col].std()),
                'min': float(df[col].min()),
                'max': float(df[col].max())
            }
            summary['column_types'][col] = 'numeric'

        # Process categorical columns
        categorical_cols = df.select_dtypes(include=['object']).columns
        for col in categorical_cols:
            value_counts = df[col].value_counts()
            summary['categorical_summaries'][col] = {
                'unique_values': int(len(value_counts)),
                'top_values': value_counts.head(5).to_dict()
            }
            summary['column_types'][col] = 'categorical'

        return summary

    def _detect_relationships(self, sheets_data: Dict[str, pd.DataFrame]) -> Dict:
        """Detect potential relationships between sheets"""
        relationships = {}
        sheet_names = list(sheets_data.keys())

        for i, sheet1 in enumerate(sheet_names):
            for sheet2 in sheet_names[i+1:]:
                common_cols = set(sheets_data[sheet1].columns) & set(sheets_data[sheet2].columns)
                if common_cols:
                    relationships[f"{sheet1}__{sheet2}"] = {
                        'common_columns': list(common_cols),
                        'type': 'potential_join'
                    }

                # Check for foreign key relationships
                for col1 in sheets_data[sheet1].columns:
                    for col2 in sheets_data[sheet2].columns:
                        if (col1.lower().endswith('_id') or col2.lower().endswith('_id')):
                            unique_vals1 = set(sheets_data[sheet1][col1].dropna())
                            unique_vals2 = set(sheets_data[sheet2][col2].dropna())
                            if unique_vals1 & unique_vals2:
                                relationships[f"{sheet1}__{sheet2}__{col1}__{col2}"] = {
                                    'type': 'foreign_key',
                                    'columns': [col1, col2]
                                }

        return relationships

    def _generate_structured_text(self, sheets_data: Dict[str, pd.DataFrame]) -> str:
        """Generate structured text representation of Excel content"""
        output_parts = []

        # Overall summary
        output_parts.append(f"Excel File Overview:")
        output_parts.append(f"Total Sheets: {len(sheets_data)}")
        output_parts.append("")

        # Sheet details
        for sheet_name, df in sheets_data.items():
            output_parts.append(f"Sheet: {sheet_name}")
            output_parts.append("=" * (len(sheet_name) + 7))

            metadata = self.sheet_metadata[sheet_name]
            summary = self.sheet_summaries[sheet_name]

            # Basic info
            output_parts.append(f"Rows: {metadata['rows']}")
            output_parts.append(f"Columns: {', '.join(metadata['columns'])}")
            output_parts.append("")

            # Column summaries
            if metadata['numeric_columns']:
                output_parts.append("Numeric Columns Summary:")
                for col in metadata['numeric_columns']:
                    stats = summary['numeric_summaries'][col]
                    output_parts.append(f"  {col}:")
                    output_parts.append(f"    Range: {stats['min']} to {stats['max']}")
                    output_parts.append(f"    Average: {stats['mean']:.2f}")
                output_parts.append("")

            if metadata['categorical_columns']:
                output_parts.append("Categorical Columns Summary:")
                for col in metadata['categorical_columns']:
                    cats = summary['categorical_summaries'][col]
                    output_parts.append(f"  {col}:")
                    output_parts.append(f"    Unique Values: {cats['unique_values']}")
                    if cats['top_values']:
                        output_parts.append("    Top Values: " +
                                            ", ".join(f"{k} ({v})" for k, v in
                                                      list(cats['top_values'].items())[:3]))
                output_parts.append("")

            # Sample data
            output_parts.append("Sample Data:")
            output_parts.append(df.head(3).to_string())
            output_parts.append("\n")

        # Relationships
        if self.relationships:
            output_parts.append("Sheet Relationships:")
            for rel_key, rel_info in self.relationships.items():
                if rel_info['type'] == 'potential_join':
                    sheets = rel_key.split('__')
                    output_parts.append(f"- {sheets[0]} and {sheets[1]} share columns: " +
                                        f"{', '.join(rel_info['common_columns'])}")
                elif rel_info['type'] == 'foreign_key':
                    parts = rel_key.split('__')
                    output_parts.append(f"- Potential foreign key relationship between " +
                                        f"{parts[0]}.{parts[2]} and {parts[1]}.{parts[3]}")

        return "\n".join(output_parts)

    def get_sheet_summary(self, sheet_name: str) -> Optional[Dict]:
        """Get summary for a specific sheet"""
        return self.sheet_summaries.get(sheet_name)

    def get_relationships(self) -> Dict:
        """Get detected relationships between sheets"""
        return self.relationships

    def get_metadata(self) -> Dict:
        """Get complete metadata for all sheets"""
        return self.sheet_metadata
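A short usage sketch of the new processor (not part of the commit; the workbook path is illustrative and pandas needs an Excel engine such as openpyxl installed):

# Illustrative usage of EnhancedExcelProcessor as added above.
from pathlib import Path
from src.utils.enhanced_excel_processor import EnhancedExcelProcessor

processor = EnhancedExcelProcessor()
structured_text = processor.process_excel(Path("workbook.xlsx"))  # illustrative file

print(structured_text[:500])           # overview, per-sheet summaries, sample rows
print(processor.get_metadata())        # columns, row counts and dtypes per sheet
print(processor.get_relationships())   # shared-column / foreign-key candidates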
src/utils/excel_integration
ADDED
@@ -0,0 +1,139 @@
from typing import Dict, Any
from pathlib import Path

class ExcelIntegration:
    def __init__(self, enhanced_processor):
        """
        Initialize Excel integration

        Args:
            enhanced_processor: Instance of EnhancedExcelProcessor
        """
        self.processor = enhanced_processor

    def process_for_rag(self, file_path: Path) -> Dict[str, Any]:
        """
        Process Excel file for RAG system

        Args:
            file_path (Path): Path to Excel file

        Returns:
            Dict[str, Any]: Processed content and metadata
        """
        # Process Excel file
        content = self.processor.process_excel(file_path)

        # Get all metadata
        metadata = {
            'sheet_summaries': self.processor.sheet_summaries,
            'relationships': self.processor.relationships,
            'sheet_metadata': self.processor.sheet_metadata
        }

        # Create chunks based on logical divisions
        chunks = self._create_semantic_chunks(content)

        return {
            'content': content,
            'chunks': chunks,
            'metadata': metadata
        }

    def _create_semantic_chunks(self, content: str) -> list:
        """
        Create meaningful chunks from Excel content

        Args:
            content (str): Processed Excel content

        Returns:
            list: List of content chunks
        """
        chunks = []
        current_chunk = []
        current_sheet = None

        for line in content.split('\n'):
            # Start new chunk for each sheet
            if line.startswith('Sheet: '):
                if current_chunk:
                    chunks.append('\n'.join(current_chunk))
                    current_chunk = []
                current_sheet = line
                current_chunk.append(line)

            # Start new chunk for major sections within sheet
            elif any(line.startswith(section) for section in
                     ['Numeric Columns Summary:', 'Categorical Columns Summary:',
                      'Sample Data:', 'Sheet Relationships:']):
                if current_chunk:
                    chunks.append('\n'.join(current_chunk))
                    current_chunk = []
                if current_sheet:
                    current_chunk.append(current_sheet)
                current_chunk.append(line)

            else:
                current_chunk.append(line)

        # Add final chunk
        if current_chunk:
            chunks.append('\n'.join(current_chunk))

        return chunks

    def get_sheet_context(self, sheet_name: str) -> str:
        """
        Get specific context for a sheet

        Args:
            sheet_name (str): Name of the sheet

        Returns:
            str: Contextual information about the sheet
        """
        if sheet_name not in self.processor.sheet_metadata:
            return ""

        metadata = self.processor.sheet_metadata[sheet_name]
        summary = self.processor.sheet_summaries[sheet_name]

        context_parts = [
            f"Sheet: {sheet_name}",
            f"Total Rows: {metadata['rows']}",
            f"Columns: {', '.join(metadata['columns'])}",
        ]

        # Add numeric column summaries
        if metadata['numeric_columns']:
            context_parts.append("\nNumeric Columns:")
            for col in metadata['numeric_columns']:
                stats = summary['numeric_summaries'][col]
                context_parts.append(f"- {col}: Range {stats['min']} to {stats['max']}, "
                                     f"Average {stats['mean']:.2f}")

        # Add categorical column summaries
        if metadata['categorical_columns']:
            context_parts.append("\nCategorical Columns:")
            for col in metadata['categorical_columns']:
                cats = summary['categorical_summaries'][col]
                context_parts.append(f"- {col}: {cats['unique_values']} unique values")

        return "\n".join(context_parts)

    def get_relationship_context(self) -> str:
        """
        Get context about relationships between sheets

        Returns:
            str: Information about sheet relationships
        """
        if not self.processor.relationships:
            return "No relationships detected between sheets."

        context_parts = ["Sheet Relationships:"]

        for rel_key, rel_info in self.processor.relationships.items():
            if rel_info['type'] == 'potential_join':
                sheets = rel_
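A usage sketch for the integration layer (not part of the commit; the rendered diff cuts off inside get_relationship_context, and the file is committed without a .py extension, so the import below assumes it is renamed to excel_integration.py):

# Illustrative usage of ExcelIntegration, assuming the module is importable.
from pathlib import Path
from src.utils.enhanced_excel_processor import EnhancedExcelProcessor
from src.utils.excel_integration import ExcelIntegration  # assumes a .py extension

integration = ExcelIntegration(EnhancedExcelProcessor())
payload = integration.process_for_rag(Path("workbook.xlsx"))  # illustrative file

for chunk in payload['chunks'][:3]:
    print(chunk.splitlines()[0])   # each chunk starts with its sheet/section header
print(integration.get_sheet_context(next(iter(payload['metadata']['sheet_metadata']))))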
src/utils/llm_utils.py
CHANGED
@@ -9,6 +9,7 @@ from src.llms.falcon_llm import FalconLanguageModel
 from src.llms.llama_llm import LlamaLanguageModel
 from src.embeddings.huggingface_embedding import HuggingFaceEmbedding
 from src.vectorstores.chroma_vectorstore import ChromaVectorStore
+from src.vectorstores.optimized_vectorstore import get_optimized_vector_store
 from src.utils.logger import logger
 from config.config import settings
 
@@ -39,21 +40,22 @@ def get_llm_instance(provider: str):
 
 async def get_vector_store() -> Tuple[ChromaVectorStore, HuggingFaceEmbedding]:
     """
+    Get vector store and embedding model instances
+    Uses optimized implementation while maintaining backward compatibility
 
     Returns:
-        Tuple[ChromaVectorStore, HuggingFaceEmbedding]:
-
-    Raises:
-        HTTPException: If vector store initialization fails
+        Tuple[ChromaVectorStore, HuggingFaceEmbedding]:
+            Vector store and embedding model instances
     """
     try:
+        return await get_optimized_vector_store()
+    except Exception as e:
+        logger.error(f"Error getting optimized vector store: {str(e)}")
+        # Fallback to original implementation if optimization fails
+        logger.warning("Falling back to standard vector store implementation")
         embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
         vector_store = ChromaVectorStore(
             embedding_function=embedding.embed_documents,
             persist_directory=settings.CHROMA_PATH
         )
-        return vector_store, embedding
-    except Exception as e:
-        logger.error(f"Error initializing vector store: {str(e)}")
-        raise HTTPException(status_code=500, detail="Failed to initialize vector store")
+        return vector_store, embedding
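Callers keep the same entry point; a minimal sketch of how the changed helper is now used (illustrative, not part of the commit):

# Illustrative caller of the updated get_vector_store helper.
from src.utils.llm_utils import get_vector_store

async def ensure_vector_store():
    # Tries the optimized singleton first; falls back to a plain
    # ChromaVectorStore construction if the optimized path fails.
    vector_store, embedding = await get_vector_store()
    return vector_store, embedding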
src/vectorstores/__pycache__/optimized_vectorstore.cpython-312.pyc
ADDED
Binary file (6.85 kB).
src/vectorstores/optimized_vectorstore.py
ADDED
@@ -0,0 +1,137 @@
# src/vectorstores/optimized_vectorstore.py
import asyncio
from typing import Tuple, Optional, List, Dict, Any, Callable
import concurrent.futures
from functools import lru_cache

from .base_vectorstore import BaseVectorStore
from .chroma_vectorstore import ChromaVectorStore
from src.embeddings.huggingface_embedding import HuggingFaceEmbedding
from src.utils.logger import logger
from config.config import settings

class OptimizedVectorStore(ChromaVectorStore):
    """
    Optimized vector store that maintains ChromaVectorStore compatibility
    while adding caching and async initialization
    """
    _instance: Optional['OptimizedVectorStore'] = None
    _lock = asyncio.Lock()
    _initialized = False
    _embedding_model: Optional[HuggingFaceEmbedding] = None
    _executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        embedding_function: Optional[Callable] = None,
        persist_directory: str = settings.CHROMA_PATH,
        collection_name: str = "documents",
        client_settings: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize the optimized vector store
        Note: The actual initialization is deferred until needed
        """
        if not self._initialized:
            self._persist_directory = persist_directory
            self._collection_name = collection_name
            self._client_settings = client_settings
            self._embedding_function = embedding_function
            # Don't call super().__init__() here - we'll do it in _initialize()

    @classmethod
    async def create(
        cls,
        persist_directory: str = settings.CHROMA_PATH,
        collection_name: str = "documents",
        client_settings: Optional[Dict[str, Any]] = None
    ) -> Tuple['OptimizedVectorStore', HuggingFaceEmbedding]:
        """
        Asynchronously create or get instance

        Returns:
            Tuple[OptimizedVectorStore, HuggingFaceEmbedding]:
                The vector store instance and embedding model
        """
        async with cls._lock:
            if not cls._instance or not cls._initialized:
                instance = cls(
                    persist_directory=persist_directory,
                    collection_name=collection_name,
                    client_settings=client_settings
                )
                await instance._initialize()
                cls._instance = instance
            return cls._instance, cls._instance._embedding_model

    async def _initialize(self) -> None:
        """Initialize the vector store and embedding model"""
        if self._initialized:
            return

        try:
            # Load embedding model in background thread
            self._embedding_model = await self._load_embedding_model()

            # Initialize ChromaVectorStore with the loaded model
            super().__init__(
                embedding_function=self._embedding_model.embed_documents,
                persist_directory=self._persist_directory,
                collection_name=self._collection_name,
                client_settings=self._client_settings
            )

            self._initialized = True

        except Exception as e:
            logger.error(f"Error initializing vector store: {str(e)}")
            raise

    async def _load_embedding_model(self) -> HuggingFaceEmbedding:
        """Load embedding model in background thread"""
        try:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self._executor,
                self._create_embedding_model
            )
        except Exception as e:
            logger.error(f"Error loading embedding model: {str(e)}")
            raise

    @staticmethod
    @lru_cache(maxsize=1)
    def _create_embedding_model() -> HuggingFaceEmbedding:
        """Create and cache embedding model"""
        return HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)

    def __getattribute__(self, name):
        """
        Ensure initialization before accessing any ChromaVectorStore methods
        """
        # Get the attribute from the class
        attr = super().__getattribute__(name)

        # If it's a method from ChromaVectorStore, ensure initialization
        if callable(attr) and name in ChromaVectorStore.__dict__:
            if not self._initialized:
                raise RuntimeError(
                    "Vector store not initialized. Please use 'await OptimizedVectorStore.create()'"
                )
        return attr

# Factory function for getting optimized vector store
async def get_optimized_vector_store() -> Tuple[ChromaVectorStore, HuggingFaceEmbedding]:
    """
    Get or create an optimized vector store instance

    Returns:
        Tuple[ChromaVectorStore, HuggingFaceEmbedding]:
            The vector store and embedding model instances
    """
    return await OptimizedVectorStore.create()
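A usage sketch for the new factory (not part of the commit; store method names are whatever ChromaVectorStore already exposes and are not shown here):

# Illustrative usage of get_optimized_vector_store.
import asyncio
from src.vectorstores.optimized_vectorstore import get_optimized_vector_store

async def main():
    store, embedding = await get_optimized_vector_store()   # first call loads the model in a worker thread
    store_again, _ = await get_optimized_vector_store()     # later calls reuse the singleton
    assert store is store_again                             # same initialized instance is returned
    print(type(store).__name__, type(embedding).__name__)

asyncio.run(main())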