euler314 committed on
Commit e07ee7b · verified · 1 Parent(s): 5462109

Create app/rag_search.py

Files changed (1)
  1. app/rag_search.py +410 -0
app/rag_search.py ADDED
@@ -0,0 +1,410 @@
+ import os
+ import re
+ import logging
+ import gc
+ from io import BytesIO
+ from functools import lru_cache
+
+ import nltk
+ import numpy as np
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import PyPDF2
+ import docx2txt
+
+ logger = logging.getLogger(__name__)
+
+ # sentence-transformers is optional; fall back to TF-IDF if it is missing
+ try:
+     from sentence_transformers import SentenceTransformer
+     HAVE_TRANSFORMERS = True
+ except ImportError:
+     HAVE_TRANSFORMERS = False
+
+ # Download the NLTK punkt tokenizer if it is not already present
+ try:
+     nltk.data.find('tokenizers/punkt')
+ except LookupError:
+     try:
+         nltk.download('punkt', quiet=True)
+     except Exception:
+         pass
+
+ # Load English stopwords, downloading the corpus if needed; STOPWORDS is always
+ # defined so the language detection below cannot fail with a NameError
+ try:
+     try:
+         nltk.data.find('corpora/stopwords')
+     except LookupError:
+         nltk.download('stopwords', quiet=True)
+     from nltk.corpus import stopwords
+     STOPWORDS = set(stopwords.words('english'))
+ except Exception:
+     STOPWORDS = {'the', 'and', 'a', 'in', 'to', 'of', 'is', 'it', 'that', 'for', 'with', 'as', 'on', 'by'}
+
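+ # Note: numpy, scikit-learn, PyPDF2, docx2txt and nltk are hard requirements of
+ # this module; sentence-transformers is optional and only enables the embedding
+ # path. Without it, the class below falls back to the TF-IDF vectorizer.
+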
+ class EnhancedRAGSearch:
+     def __init__(self):
+         self.file_texts = []
+         self.chunks = []  # Document chunks for more targeted search
+         self.chunk_metadata = []  # Metadata for each chunk
+         self.file_metadata = []
+         self.languages = []
+         self.model = None
+         self.use_transformer = False
+
+         # Try to load the sentence-transformer model if available
+         if HAVE_TRANSFORMERS:
+             try:
+                 # Use a small, efficient model
+                 self.model = SentenceTransformer('all-MiniLM-L6-v2')
+                 self.use_transformer = True
+                 logger.info("Using sentence-transformers for RAG")
+             except Exception as e:
+                 logger.warning(f"Error loading sentence-transformer: {e}")
+
+         # Fall back to TF-IDF if transformers are not available
+         if not self.use_transformer:
+             self.vectorizer = TfidfVectorizer(
+                 stop_words='english',
+                 ngram_range=(1, 2),   # Use bigrams for better context
+                 max_features=15000,   # More features for a richer representation
+                 min_df=1              # Include rare terms
+             )
+
+         self.vectors = None
+         self.chunk_vectors = None
+
+     def add_file(self, file_data, file_info):
+         """Add a file to the search index with improved processing."""
+         file_ext = os.path.splitext(file_info['filename'])[1].lower()
+         text = self.extract_text(file_data, file_ext)
+
+         if not text:
+             return False
+
+         # Store the whole document text
+         self.file_texts.append(text)
+         self.file_metadata.append(file_info)
+
+         # Simple language detection based on the ratio of English stopwords
+         try:
+             words = re.findall(r'\b\w+\b', text.lower())
+             sample = words[:100]
+             english_ratio = len([w for w in sample if w in STOPWORDS]) / max(1, len(sample))
+             self.languages.append('en' if english_ratio > 0.2 else 'unknown')
+         except Exception:
+             self.languages.append('en')  # Default to English
+
+         # Create overlapping chunks for more granular search
+         for chunk in self.create_chunks(text):
+             self.chunks.append(chunk)
+             self.chunk_metadata.append({
+                 'file_info': file_info,
+                 'chunk_size': len(chunk),
+                 'file_index': len(self.file_texts) - 1
+             })
+
+         return True
+
+     def create_chunks(self, text, chunk_size=1000, overlap=200):
+         """Split text into overlapping chunks for better search precision."""
+         try:
+             sentences = nltk.sent_tokenize(text)
+             chunks = []
+             current_chunk = ""
+
+             for sentence in sentences:
+                 if len(current_chunk) + len(sentence) <= chunk_size:
+                     current_chunk += sentence + " "
+                 else:
+                     # Add the current chunk if it has content
+                     if current_chunk:
+                         chunks.append(current_chunk.strip())
+
+                     # Start the new chunk with the tail of the previous one so
+                     # neighbouring chunks overlap
+                     if len(current_chunk) > overlap:
+                         overlap_text = current_chunk[-overlap:]
+                         last_space = overlap_text.rfind(' ')
+                         if last_space != -1:
+                             current_chunk = overlap_text[last_space + 1:] + sentence + " "
+                         else:
+                             current_chunk = sentence + " "
+                     else:
+                         current_chunk = sentence + " "
+
+             # Add the last chunk if it has content
+             if current_chunk:
+                 chunks.append(current_chunk.strip())
+
+             return chunks
+         except Exception:
+             # Fall back to simple fixed-size chunking if sentence tokenization fails
+             return [text[i:i + chunk_size]
+                     for i in range(0, len(text), chunk_size - overlap)
+                     if text[i:i + chunk_size]]
+
+     def extract_text(self, file_data, file_ext):
+         """Extract text from different file types with enhanced support."""
+         try:
+             file_ext = file_ext.lower()
+             if file_ext == '.pdf':
+                 reader = PyPDF2.PdfReader(BytesIO(file_data))
+                 text = ""
+                 for page in reader.pages:
+                     extracted = page.extract_text()
+                     if extracted:
+                         text += extracted + "\n"
+                 return text
+             elif file_ext in ['.docx', '.doc']:
+                 # docx2txt only parses .docx; binary .doc files typically fail
+                 # and fall through to the error handler below
+                 return docx2txt.process(BytesIO(file_data))
+             elif file_ext in ['.txt', '.csv', '.json', '.html', '.htm']:
+                 # Handle UTF-8 first, then other common encodings
+                 try:
+                     return file_data.decode('utf-8', errors='ignore')
+                 except Exception:
+                     for enc in ['latin-1', 'iso-8859-1', 'windows-1252']:
+                         try:
+                             return file_data.decode(enc, errors='ignore')
+                         except Exception:
+                             pass
+                     # Last-resort fallback
+                     return file_data.decode('utf-8', errors='ignore')
+             elif file_ext in ['.pptx', '.ppt', '.xlsx', '.xls']:
+                 return f"[Content of {file_ext} file - install additional libraries for full text extraction]"
+             else:
+                 return ""
+         except Exception as e:
+             logger.error(f"Error extracting text: {e}")
+             return ""
+
+     def build_index(self):
+         """Build both document- and chunk-level search indices."""
+         if not self.file_texts:
+             return False
+
+         try:
+             if self.use_transformer:
+                 # Use the sentence-transformer model for embeddings
+                 logger.info("Building document and chunk embeddings with transformer model...")
+                 self.vectors = self.model.encode(self.file_texts, show_progress_bar=False)
+
+                 # Build the chunk-level index if we have chunks
+                 if self.chunks:
+                     # Process in batches to avoid memory issues
+                     batch_size = 32
+                     chunk_vectors = []
+                     for i in range(0, len(self.chunks), batch_size):
+                         batch = self.chunks[i:i + batch_size]
+                         chunk_vectors.append(self.model.encode(batch, show_progress_bar=False))
+                     self.chunk_vectors = np.vstack(chunk_vectors)
+             else:
+                 # Build the document-level TF-IDF index
+                 self.vectors = self.vectorizer.fit_transform(self.file_texts)
+
+                 # Build the chunk-level index if we have chunks
+                 if self.chunks:
+                     self.chunk_vectors = self.vectorizer.transform(self.chunks)
+
+             return True
+         except Exception as e:
+             logger.error(f"Error building search index: {e}")
+             return False
+
+     def expand_query(self, query):
+         """Add related terms to the query for better recall (lightweight, dictionary-based expansion)."""
+         # Dictionary of related terms for common keywords
+         expansions = {
+             "exam": ["test", "assessment", "quiz", "paper", "exam paper", "past paper", "past exam"],
+             "test": ["exam", "quiz", "assessment", "paper"],
+             "document": ["file", "paper", "report", "doc", "documentation"],
+             "manual": ["guide", "instruction", "documentation", "handbook"],
+             "tutorial": ["guide", "instructions", "how-to", "lesson"],
+             "article": ["paper", "publication", "journal", "research"],
+             "research": ["study", "investigation", "paper", "analysis"],
+             "book": ["textbook", "publication", "volume", "edition"],
+             "thesis": ["dissertation", "paper", "research", "study"],
+             "report": ["document", "paper", "analysis", "summary"],
+             "assignment": ["homework", "task", "project", "work"],
+             "lecture": ["class", "presentation", "talk", "lesson"],
+             "notes": ["annotations", "summary", "outline", "study material"],
+             "syllabus": ["curriculum", "course outline", "program", "plan"],
+             "paper": ["document", "article", "publication", "exam", "test"],
+             "question": ["problem", "query", "exercise", "inquiry"],
+             "solution": ["answer", "resolution", "explanation", "result"],
+             "reference": ["source", "citation", "bibliography", "resource"],
+             "analysis": ["examination", "study", "evaluation", "assessment"],
+             "guide": ["manual", "instruction", "handbook", "tutorial"],
+             "worksheet": ["exercise", "activity", "handout", "practice"],
+             "review": ["evaluation", "assessment", "critique", "feedback"],
+             "material": ["resource", "content", "document", "information"],
+             "data": ["information", "statistics", "figures", "numbers"]
+         }
+
+         query_lower = query.lower()
+         query_words = re.findall(r'\b\w+\b', query_lower)
+         expanded_terms = set()
+
+         # Add expansions from the dictionary
+         for word in query_words:
+             if word in expansions:
+                 expanded_terms.update(expansions[word])
+
+         # Add common academic file formats if none are mentioned
+         if any(term in query_lower for term in ["file", "document", "download", "paper"]):
+             if not any(ext in query_lower for ext in ["pdf", "docx", "ppt", "excel"]):
+                 expanded_terms.update(["pdf", "docx", "pptx", "xlsx"])
+
+         # Add academic terms when the query seems related to education
+         if any(term in query_lower for term in ["course", "university", "college", "school", "class"]):
+             expanded_terms.update(["syllabus", "lecture", "notes", "textbook"])
+
+         # Return the original query plus the expanded terms
+         if expanded_terms:
+             expanded_query = f"{query} {' '.join(expanded_terms)}"
+             logger.info(f"Expanded query: '{query}' -> '{expanded_query}'")
+             return expanded_query
+         return query
+
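+     # search() results are memoized per instance and argument combination via
+     # lru_cache below; call clear_cache() after adding files or rebuilding the
+     # index so stale results are not returned.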
+     @lru_cache(maxsize=8)
+     def search(self, query, top_k=5, search_chunks=True):
+         """Enhanced search combining document-level and chunk-level matches."""
+         if self.vectors is None:
+             return []
+
+         # Expand the query with related terms to improve recall
+         expanded_query = self.expand_query(query)
+
+         try:
+             results = []
+
+             if self.use_transformer:
+                 # Embed the query with the transformer model
+                 query_vector = self.model.encode([expanded_query])[0]
+
+                 # First search at the document level for higher-level matches
+                 doc_similarities = cosine_similarity(
+                     query_vector.reshape(1, -1),
+                     self.vectors
+                 ).flatten()
+                 top_doc_indices = doc_similarities.argsort()[-top_k:][::-1]
+
+                 for i, idx in enumerate(top_doc_indices):
+                     if doc_similarities[idx] > 0.2:  # Threshold to exclude irrelevant results
+                         results.append({
+                             'file_info': self.file_metadata[idx],
+                             'score': float(doc_similarities[idx]),
+                             'rank': i + 1,
+                             'match_type': 'document',
+                             'language': self.languages[idx] if idx < len(self.languages) else 'unknown'
+                         })
+
+                 # Then search at the chunk level for more specific matches if enabled
+                 if search_chunks and self.chunk_vectors is not None:
+                     chunk_similarities = cosine_similarity(
+                         query_vector.reshape(1, -1),
+                         self.chunk_vectors
+                     ).flatten()
+                     top_chunk_indices = chunk_similarities.argsort()[-top_k * 2:][::-1]
+
+                     # Track files already returned to avoid duplicates
+                     seen_files = set(r['file_info']['url'] for r in results)
+
+                     for idx in top_chunk_indices:
+                         if chunk_similarities[idx] > 0.25:  # Higher threshold for chunks
+                             file_index = self.chunk_metadata[idx]['file_index']
+                             file_info = self.file_metadata[file_index]
+
+                             # Only add if we haven't already included this file
+                             if file_info['url'] not in seen_files:
+                                 seen_files.add(file_info['url'])
+                                 results.append({
+                                     'file_info': file_info,
+                                     'score': float(chunk_similarities[idx]),
+                                     'rank': len(results) + 1,
+                                     'match_type': 'chunk',
+                                     'language': self.languages[file_index] if file_index < len(self.languages) else 'unknown',
+                                     'chunk_preview': self.chunks[idx][:200] + "..." if len(self.chunks[idx]) > 200 else self.chunks[idx]
+                                 })
+
+                             # Stop once we have enough candidate results
+                             if len(results) >= top_k * 1.5:
+                                 break
+             else:
+                 # Fall back to TF-IDF when transformers are not available
+                 query_vector = self.vectorizer.transform([expanded_query])
+
+                 # First search at the document level
+                 doc_similarities = cosine_similarity(query_vector, self.vectors).flatten()
+                 top_doc_indices = doc_similarities.argsort()[-top_k:][::-1]
+
+                 for i, idx in enumerate(top_doc_indices):
+                     if doc_similarities[idx] > 0.1:  # Lower threshold for sparse TF-IDF scores
+                         results.append({
+                             'file_info': self.file_metadata[idx],
+                             'score': float(doc_similarities[idx]),
+                             'rank': i + 1,
+                             'match_type': 'document',
+                             'language': self.languages[idx] if idx < len(self.languages) else 'unknown'
+                         })
+
+                 # Then search at the chunk level if enabled
+                 if search_chunks and self.chunk_vectors is not None:
+                     chunk_similarities = cosine_similarity(query_vector, self.chunk_vectors).flatten()
+                     top_chunk_indices = chunk_similarities.argsort()[-top_k * 2:][::-1]
+
+                     # Avoid duplicate file results
+                     seen_files = set(r['file_info']['url'] for r in results)
+
+                     for idx in top_chunk_indices:
+                         if chunk_similarities[idx] > 0.15:
+                             file_index = self.chunk_metadata[idx]['file_index']
+                             file_info = self.file_metadata[file_index]
+
+                             if file_info['url'] not in seen_files:
+                                 seen_files.add(file_info['url'])
+                                 results.append({
+                                     'file_info': file_info,
+                                     'score': float(chunk_similarities[idx]),
+                                     'rank': len(results) + 1,
+                                     'match_type': 'chunk',
+                                     'language': self.languages[file_index] if file_index < len(self.languages) else 'unknown',
+                                     'chunk_preview': self.chunks[idx][:200] + "..." if len(self.chunks[idx]) > 200 else self.chunks[idx]
+                                 })
+
+                             if len(results) >= top_k * 1.5:
+                                 break
+
+             # Sort the combined results by score, then re-rank and truncate
+             results.sort(key=lambda x: x['score'], reverse=True)
+             for i, result in enumerate(results[:top_k]):
+                 result['rank'] = i + 1
+
+             return results[:top_k]
+         except Exception as e:
+             logger.error(f"Error during search: {e}")
+             return []
+
+     def clear_cache(self):
+         """Clear the memoized search results and free index memory."""
+         if hasattr(self.search, 'cache_clear'):
+             self.search.cache_clear()
+
+         # Drop the vectors to free memory (build_index must be called again before searching)
+         self.vectors = None
+         self.chunk_vectors = None
+
+         # Force garbage collection
+         gc.collect()
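
A minimal usage sketch of the class above, assuming app/ is importable as a package and that file_info carries at least the 'filename' and 'url' keys the search results rely on (the example file name, URL, and text are hypothetical):

from app.rag_search import EnhancedRAGSearch

searcher = EnhancedRAGSearch()

# file_data is raw bytes; file_info must provide 'filename' and 'url'
file_data = b"Past exam paper for the linear algebra course. Question 1: ..."
file_info = {'filename': 'linear_algebra_exam.txt', 'url': 'https://example.com/linear_algebra_exam.txt'}

if searcher.add_file(file_data, file_info) and searcher.build_index():
    for hit in searcher.search("past exam paper", top_k=3):
        print(hit['rank'], hit['match_type'], round(hit['score'], 3), hit['file_info']['url'])

searcher.clear_cache()  # drop memoized results and free the index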