euler314 committed
Commit f0d7dcd · verified · 1 Parent(s): deb3f2f

Delete rag_search.py

Files changed (1)
  1. rag_search.py +0 -410
rag_search.py DELETED
@@ -1,410 +0,0 @@
- import os
- import re
- import logging
- import nltk
- from io import BytesIO
- import numpy as np
- from sklearn.feature_extraction.text import TfidfVectorizer
- from sklearn.metrics.pairwise import cosine_similarity
- import PyPDF2
- import docx2txt
- from functools import lru_cache
-
- logger = logging.getLogger(__name__)
-
- # Try to import sentence-transformers
- try:
-     from sentence_transformers import SentenceTransformer
-     HAVE_TRANSFORMERS = True
- except ImportError:
-     HAVE_TRANSFORMERS = False
-
- # Try to download NLTK data if not already present
- try:
-     nltk.data.find('tokenizers/punkt')
- except LookupError:
-     try:
-         nltk.download('punkt', quiet=True)
-     except Exception:
-         pass
-
- # Always define STOPWORDS so later language detection cannot hit a NameError;
- # prefer the NLTK list and fall back to a small built-in set
- STOPWORDS = set(['the', 'and', 'a', 'in', 'to', 'of', 'is', 'it', 'that', 'for', 'with', 'as', 'on', 'by'])
- try:
-     try:
-         nltk.data.find('corpora/stopwords')
-     except LookupError:
-         nltk.download('stopwords', quiet=True)
-     from nltk.corpus import stopwords
-     STOPWORDS = set(stopwords.words('english'))
- except Exception:
-     pass
-
- class EnhancedRAGSearch:
-     def __init__(self):
-         self.file_texts = []
-         self.chunks = []  # Document chunks for more targeted search
-         self.chunk_metadata = []  # Metadata for each chunk
-         self.file_metadata = []
-         self.languages = []
-         self.model = None
-
-         # Try to load the sentence transformer model if available
-         if HAVE_TRANSFORMERS:
-             try:
-                 # Use a small, efficient model
-                 self.model = SentenceTransformer('all-MiniLM-L6-v2')
-                 self.use_transformer = True
-                 logger.info("Using sentence-transformers for RAG")
-             except Exception as e:
-                 logger.warning(f"Error loading sentence-transformer: {e}")
-                 self.use_transformer = False
-         else:
-             self.use_transformer = False
-
-         # Fallback to TF-IDF if transformers not available
-         if not self.use_transformer:
-             self.vectorizer = TfidfVectorizer(
-                 stop_words='english',
-                 ngram_range=(1, 2),  # Use bigrams for better context
-                 max_features=15000,  # Use more features for better representation
-                 min_df=1  # Include rare terms
-             )
-
-         self.vectors = None
-         self.chunk_vectors = None
-
-     def add_file(self, file_data, file_info):
-         """Add a file to the search index with improved processing"""
-         file_ext = os.path.splitext(file_info['filename'])[1].lower()
-         text = self.extract_text(file_data, file_ext)
-
-         if text:
-             # Store the whole document text
-             self.file_texts.append(text)
-             self.file_metadata.append(file_info)
-
-             # Try to detect language
-             try:
-                 # Simple language detection based on stopwords
-                 words = re.findall(r'\b\w+\b', text.lower())
-                 english_stopwords_ratio = len([w for w in words[:100] if w in STOPWORDS]) / max(1, len(words[:100]))
-                 lang = 'en' if english_stopwords_ratio > 0.2 else 'unknown'
-                 self.languages.append(lang)
-             except Exception:
-                 self.languages.append('en')  # Default to English
-
-             # Create chunks for more granular search
-             chunks = self.create_chunks(text)
-             for chunk in chunks:
-                 self.chunks.append(chunk)
-                 self.chunk_metadata.append({
-                     'file_info': file_info,
-                     'chunk_size': len(chunk),
-                     'file_index': len(self.file_texts) - 1
-                 })
-
-             return True
-         return False
-
-     def create_chunks(self, text, chunk_size=1000, overlap=200):
-         """Split text into overlapping chunks for better search precision"""
-         try:
-             sentences = nltk.sent_tokenize(text)
-             chunks = []
-             current_chunk = ""
-
-             for sentence in sentences:
-                 if len(current_chunk) + len(sentence) <= chunk_size:
-                     current_chunk += sentence + " "
-                 else:
-                     # Add current chunk if it has content
-                     if current_chunk:
-                         chunks.append(current_chunk.strip())
-
-                     # Start new chunk with overlap from previous chunk
-                     if len(current_chunk) > overlap:
-                         # Find the last space within the overlap region
-                         overlap_text = current_chunk[-overlap:]
-                         last_space = overlap_text.rfind(' ')
-                         if last_space != -1:
-                             current_chunk = current_chunk[-(overlap - last_space):] + sentence + " "
-                         else:
-                             current_chunk = sentence + " "
-                     else:
-                         current_chunk = sentence + " "
-
-             # Add the last chunk if it has content
-             if current_chunk:
-                 chunks.append(current_chunk.strip())
-
-             return chunks
-         except Exception:
-             # Fallback to simpler chunking approach
-             chunks = []
-             for i in range(0, len(text), chunk_size - overlap):
-                 chunk = text[i:i + chunk_size]
-                 if chunk:
-                     chunks.append(chunk)
-             return chunks
-
-     def extract_text(self, file_data, file_ext):
-         """Extract text from different file types with enhanced support"""
-         try:
-             if file_ext.lower() == '.pdf':
-                 reader = PyPDF2.PdfReader(BytesIO(file_data))
-                 text = ""
-                 for page in reader.pages:
-                     extracted = page.extract_text()
-                     if extracted:
-                         text += extracted + "\n"
-                 return text
-             elif file_ext.lower() in ['.docx', '.doc']:
-                 return docx2txt.process(BytesIO(file_data))
-             elif file_ext.lower() in ['.txt', '.csv', '.json', '.html', '.htm']:
-                 # Handle both UTF-8 and other common encodings by trying strict
-                 # decodes in order before falling back to a lossy decode
-                 for enc in ['utf-8', 'latin-1', 'iso-8859-1', 'windows-1252']:
-                     try:
-                         return file_data.decode(enc)
-                     except UnicodeDecodeError:
-                         continue
-                 # Last resort fallback
-                 return file_data.decode('utf-8', errors='ignore')
-             elif file_ext.lower() in ['.pptx', '.ppt', '.xlsx', '.xls']:
-                 return f"[Content of {file_ext} file - install additional libraries for full text extraction]"
-             else:
-                 return ""
-         except Exception as e:
-             logger.error(f"Error extracting text: {e}")
-             return ""
-
-     def build_index(self):
-         """Build both document and chunk search indices"""
-         if not self.file_texts:
-             return False
-
-         try:
-             if self.use_transformer:
-                 # Use sentence transformer models for embeddings
-                 logger.info("Building document and chunk embeddings with transformer model...")
-                 self.vectors = self.model.encode(self.file_texts, show_progress_bar=False)
-
-                 # Build chunk-level index if we have chunks
-                 if self.chunks:
-                     # Process in batches to avoid memory issues
-                     batch_size = 32
-                     chunk_vectors = []
-                     for i in range(0, len(self.chunks), batch_size):
-                         batch = self.chunks[i:i + batch_size]
-                         batch_vectors = self.model.encode(batch, show_progress_bar=False)
-                         chunk_vectors.append(batch_vectors)
-                     self.chunk_vectors = np.vstack(chunk_vectors)
-             else:
-                 # Build document-level index
-                 self.vectors = self.vectorizer.fit_transform(self.file_texts)
-
-                 # Build chunk-level index if we have chunks
-                 if self.chunks:
-                     self.chunk_vectors = self.vectorizer.transform(self.chunks)
-
-             return True
-         except Exception as e:
-             logger.error(f"Error building search index: {e}")
-             return False
-
-     def expand_query(self, query):
-         """Add related terms to the query for better recall using a keyword dictionary"""
-         # Dictionary of related terms for common keywords
-         expansions = {
-             "exam": ["test", "assessment", "quiz", "paper", "exam paper", "past paper", "past exam"],
-             "test": ["exam", "quiz", "assessment", "paper"],
-             "document": ["file", "paper", "report", "doc", "documentation"],
-             "manual": ["guide", "instruction", "documentation", "handbook"],
-             "tutorial": ["guide", "instructions", "how-to", "lesson"],
-             "article": ["paper", "publication", "journal", "research"],
-             "research": ["study", "investigation", "paper", "analysis"],
-             "book": ["textbook", "publication", "volume", "edition"],
-             "thesis": ["dissertation", "paper", "research", "study"],
-             "report": ["document", "paper", "analysis", "summary"],
-             "assignment": ["homework", "task", "project", "work"],
-             "lecture": ["class", "presentation", "talk", "lesson"],
-             "notes": ["annotations", "summary", "outline", "study material"],
-             "syllabus": ["curriculum", "course outline", "program", "plan"],
-             "paper": ["document", "article", "publication", "exam", "test"],
-             "question": ["problem", "query", "exercise", "inquiry"],
-             "solution": ["answer", "resolution", "explanation", "result"],
-             "reference": ["source", "citation", "bibliography", "resource"],
-             "analysis": ["examination", "study", "evaluation", "assessment"],
-             "guide": ["manual", "instruction", "handbook", "tutorial"],
-             "worksheet": ["exercise", "activity", "handout", "practice"],
-             "review": ["evaluation", "assessment", "critique", "feedback"],
-             "material": ["resource", "content", "document", "information"],
-             "data": ["information", "statistics", "figures", "numbers"]
-         }
-
-         # Keyword-based query expansion
-         query_words = re.findall(r'\b\w+\b', query.lower())
-         expanded_terms = set()
-
-         # Directly add expansions from our dictionary
-         for word in query_words:
-             if word in expansions:
-                 expanded_terms.update(expansions[word])
-
-         # Add common academic file formats if not already included
-         if any(term in query.lower() for term in ["file", "document", "download", "paper"]):
-             if not any(ext in query.lower() for ext in ["pdf", "docx", "ppt", "excel"]):
-                 expanded_terms.update(["pdf", "docx", "pptx", "xlsx"])
-
-         # Add special academic terms when the query seems related to education
-         if any(term in query.lower() for term in ["course", "university", "college", "school", "class"]):
-             expanded_terms.update(["syllabus", "lecture", "notes", "textbook"])
-
-         # Return original query plus expanded terms
-         if expanded_terms:
-             expanded_query = f"{query} {' '.join(expanded_terms)}"
-             logger.info(f"Expanded query: '{query}' -> '{expanded_query}'")
-             return expanded_query
-         return query
-
-     @lru_cache(maxsize=8)
-     def search(self, query, top_k=5, search_chunks=True):
-         """Enhanced search with both document and chunk-level search"""
-         if self.vectors is None:
-             return []
-
-         # Expand the query with related terms to improve recall
-         expanded_query = self.expand_query(query)
-
-         try:
-             results = []
-
-             if self.use_transformer:
-                 # Transform the query to an embedding
-                 query_vector = self.model.encode([expanded_query])[0]
-
-                 # First search at document level for higher-level matches
-                 if self.vectors is not None:
-                     # Compute similarities between query and documents
-                     doc_similarities = cosine_similarity(
-                         query_vector.reshape(1, -1),
-                         self.vectors
-                     ).flatten()
-
-                     top_doc_indices = doc_similarities.argsort()[-top_k:][::-1]
-
-                     for i, idx in enumerate(top_doc_indices):
-                         if doc_similarities[idx] > 0.2:  # Threshold to exclude irrelevant results
-                             results.append({
-                                 'file_info': self.file_metadata[idx],
-                                 'score': float(doc_similarities[idx]),
-                                 'rank': i + 1,
-                                 'match_type': 'document',
-                                 'language': self.languages[idx] if idx < len(self.languages) else 'unknown'
-                             })
-
-                 # Then search at chunk level for more specific matches if enabled
-                 if search_chunks and self.chunk_vectors is not None:
-                     # Compute similarities between query and chunks
-                     chunk_similarities = cosine_similarity(
-                         query_vector.reshape(1, -1),
-                         self.chunk_vectors
-                     ).flatten()
-
-                     top_chunk_indices = chunk_similarities.argsort()[-top_k * 2:][::-1]  # Get more chunk results
-
-                     # Use a set to avoid duplicate file results
-                     seen_files = set(r['file_info']['url'] for r in results)
-
-                     for i, idx in enumerate(top_chunk_indices):
-                         if chunk_similarities[idx] > 0.25:  # Higher threshold for chunks
-                             file_index = self.chunk_metadata[idx]['file_index']
-                             file_info = self.file_metadata[file_index]
-
-                             # Only add if we haven't already included this file
-                             if file_info['url'] not in seen_files:
-                                 seen_files.add(file_info['url'])
-                                 results.append({
-                                     'file_info': file_info,
-                                     'score': float(chunk_similarities[idx]),
-                                     'rank': len(results) + 1,
-                                     'match_type': 'chunk',
-                                     'language': self.languages[file_index] if file_index < len(self.languages) else 'unknown',
-                                     'chunk_preview': self.chunks[idx][:200] + "..." if len(self.chunks[idx]) > 200 else self.chunks[idx]
-                                 })
-
-                                 # Stop after we've found enough results
-                                 if len(results) >= top_k * 1.5:
-                                     break
-
-             else:
-                 # Fallback to TF-IDF if transformers not available
-                 query_vector = self.vectorizer.transform([expanded_query])
-
-                 # First search at document level
-                 if self.vectors is not None:
-                     doc_similarities = cosine_similarity(query_vector, self.vectors).flatten()
-                     top_doc_indices = doc_similarities.argsort()[-top_k:][::-1]
-
-                     for i, idx in enumerate(top_doc_indices):
-                         if doc_similarities[idx] > 0.1:  # Threshold to exclude irrelevant results
-                             results.append({
-                                 'file_info': self.file_metadata[idx],
-                                 'score': float(doc_similarities[idx]),
-                                 'rank': i + 1,
-                                 'match_type': 'document',
-                                 'language': self.languages[idx] if idx < len(self.languages) else 'unknown'
-                             })
-
-                 # Then search at chunk level if enabled
-                 if search_chunks and self.chunk_vectors is not None:
-                     chunk_similarities = cosine_similarity(query_vector, self.chunk_vectors).flatten()
-                     top_chunk_indices = chunk_similarities.argsort()[-top_k * 2:][::-1]
-
-                     # Avoid duplicates
-                     seen_files = set(r['file_info']['url'] for r in results)
-
-                     for i, idx in enumerate(top_chunk_indices):
-                         if chunk_similarities[idx] > 0.15:
-                             file_index = self.chunk_metadata[idx]['file_index']
-                             file_info = self.file_metadata[file_index]
-
-                             if file_info['url'] not in seen_files:
-                                 seen_files.add(file_info['url'])
-                                 results.append({
-                                     'file_info': file_info,
-                                     'score': float(chunk_similarities[idx]),
-                                     'rank': len(results) + 1,
-                                     'match_type': 'chunk',
-                                     'language': self.languages[file_index] if file_index < len(self.languages) else 'unknown',
-                                     'chunk_preview': self.chunks[idx][:200] + "..." if len(self.chunks[idx]) > 200 else self.chunks[idx]
-                                 })
-
-                                 if len(results) >= top_k * 1.5:
-                                     break
-
-             # Sort combined results by score
-             results.sort(key=lambda x: x['score'], reverse=True)
-
-             # Re-rank and truncate
-             for i, result in enumerate(results[:top_k]):
-                 result['rank'] = i + 1
-
-             return results[:top_k]
-         except Exception as e:
-             logger.error(f"Error during search: {e}")
-             return []
-
-     def clear_cache(self):
-         """Clear search cache and free memory"""
-         if hasattr(self.search, 'cache_clear'):
-             self.search.cache_clear()
-
-         # Clear vectors to free memory
-         self.vectors = None
-         self.chunk_vectors = None
-
-         # Force garbage collection
-         import gc
-         gc.collect()
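
For reference, a minimal sketch of how the deleted class would be driven. This is not part of the commit: the downloads/*.pdf glob and the file:// URLs are illustrative assumptions, while the 'filename' and 'url' keys mirror what add_file and search read from file_info.

import glob

from rag_search import EnhancedRAGSearch  # the module removed by this commit

searcher = EnhancedRAGSearch()
for path in glob.glob("downloads/*.pdf"):  # hypothetical local files
    with open(path, "rb") as f:
        data = f.read()
    # 'filename' drives text extraction; 'url' is used to de-duplicate results
    searcher.add_file(data, {"filename": path, "url": f"file://{path}"})

if searcher.build_index():
    for hit in searcher.search("past exam papers", top_k=3):
        print(hit["rank"], round(hit["score"], 3), hit["match_type"], hit["file_info"]["filename"])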