import os
import re
import logging
from functools import lru_cache
from io import BytesIO

import nltk
import numpy as np
import PyPDF2
import docx2txt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

logger = logging.getLogger(__name__)

# Optional dependency: use sentence-transformers embeddings when available,
# otherwise fall back to TF-IDF.
try:
    from sentence_transformers import SentenceTransformer
    HAVE_TRANSFORMERS = True
except ImportError:
    HAVE_TRANSFORMERS = False

# Make sure the required NLTK resources are present; downloads are best-effort.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    try:
        nltk.download('punkt', quiet=True)
    except Exception:
        pass

try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    try:
        nltk.download('stopwords', quiet=True)
    except Exception:
        pass

# Build the stopword set, falling back to a small built-in list if the NLTK
# corpus cannot be loaded.
try:
    from nltk.corpus import stopwords
    STOPWORDS = set(stopwords.words('english'))
except Exception:
    STOPWORDS = set(['the', 'and', 'a', 'in', 'to', 'of', 'is', 'it', 'that', 'for', 'with', 'as', 'on', 'by'])


class EnhancedRAGSearch:
    """Index uploaded files and search them at document and chunk level.

    Uses sentence-transformer embeddings when available; otherwise falls back
    to a TF-IDF vectorizer.
    """

    def __init__(self):
        self.file_texts = []
        self.chunks = []
        self.chunk_metadata = []
        self.file_metadata = []
        self.languages = []
        self.model = None

        if HAVE_TRANSFORMERS:
            try:
                # Small, fast general-purpose embedding model
                self.model = SentenceTransformer('all-MiniLM-L6-v2')
                self.use_transformer = True
                logger.info("Using sentence-transformers for RAG")
            except Exception as e:
                logger.warning(f"Error loading sentence-transformer: {e}")
                self.use_transformer = False
        else:
            self.use_transformer = False

        # TF-IDF fallback when no transformer model could be loaded
        if not self.use_transformer:
            self.vectorizer = TfidfVectorizer(
                stop_words='english',
                ngram_range=(1, 2),
                max_features=15000,
                min_df=1
            )

        self.vectors = None
        self.chunk_vectors = None

    def add_file(self, file_data, file_info):
        """Add a file to the search index with improved processing"""
        file_ext = os.path.splitext(file_info['filename'])[1].lower()
        text = self.extract_text(file_data, file_ext)

        if text:
            self.file_texts.append(text)
            self.file_metadata.append(file_info)

            # Rough language detection from the share of English stopwords
            # in the first 100 words
            try:
                words = re.findall(r'\b\w+\b', text.lower())
                english_stopwords_ratio = len([w for w in words[:100] if w in STOPWORDS]) / max(1, len(words[:100]))
                lang = 'en' if english_stopwords_ratio > 0.2 else 'unknown'
                self.languages.append(lang)
            except Exception:
                self.languages.append('en')

            # Split the document into overlapping chunks for finer-grained retrieval
            chunks = self.create_chunks(text)
            for chunk in chunks:
                self.chunks.append(chunk)
                self.chunk_metadata.append({
                    'file_info': file_info,
                    'chunk_size': len(chunk),
                    'file_index': len(self.file_texts) - 1
                })

            return True
        return False

    def create_chunks(self, text, chunk_size=1000, overlap=200):
        """Split text into overlapping chunks for better search precision"""
        try:
            sentences = nltk.sent_tokenize(text)
            chunks = []
            current_chunk = ""

            for sentence in sentences:
                if len(current_chunk) + len(sentence) <= chunk_size:
                    current_chunk += sentence + " "
                else:
                    if current_chunk:
                        chunks.append(current_chunk.strip())

                    # Start the next chunk with the tail of the previous one so
                    # neighbouring chunks overlap, preferably at a word boundary
                    if len(current_chunk) > overlap:
                        overlap_text = current_chunk[-overlap:]
                        last_space = overlap_text.rfind(' ')
                        if last_space != -1:
                            current_chunk = current_chunk[-(overlap - last_space):] + sentence + " "
                        else:
                            current_chunk = sentence + " "
                    else:
                        current_chunk = sentence + " "

            if current_chunk:
                chunks.append(current_chunk.strip())

            return chunks
        except Exception:
            # Fallback if sentence tokenization fails: fixed-size character windows
            chunks = []
            for i in range(0, len(text), chunk_size - overlap):
                chunk = text[i:i + chunk_size]
                if chunk:
                    chunks.append(chunk)
            return chunks

    def extract_text(self, file_data, file_ext):
        """Extract text from different file types with enhanced support"""
        try:
            if file_ext.lower() == '.pdf':
                reader = PyPDF2.PdfReader(BytesIO(file_data))
                text = ""
                for page in reader.pages:
                    extracted = page.extract_text()
                    if extracted:
                        text += extracted + "\n"
                return text
            elif file_ext.lower() in ['.docx', '.doc']:
                return docx2txt.process(BytesIO(file_data))
            elif file_ext.lower() in ['.txt', '.csv', '.json', '.html', '.htm']:
                # Try strict UTF-8 first, then common single-byte encodings
                try:
                    return file_data.decode('utf-8')
                except UnicodeDecodeError:
                    for enc in ['latin-1', 'iso-8859-1', 'windows-1252']:
                        try:
                            return file_data.decode(enc)
                        except UnicodeDecodeError:
                            pass
                    # Last resort: drop undecodable bytes
                    return file_data.decode('utf-8', errors='ignore')
            elif file_ext.lower() in ['.pptx', '.ppt', '.xlsx', '.xls']:
                return f"[Content of {file_ext} file - install additional libraries for full text extraction]"
            else:
                return ""
        except Exception as e:
            logger.error(f"Error extracting text: {e}")
            return ""

    def build_index(self):
        """Build both document and chunk search indices"""
        if not self.file_texts:
            return False

        try:
            if self.use_transformer:
                logger.info("Building document and chunk embeddings with transformer model...")
                self.vectors = self.model.encode(self.file_texts, show_progress_bar=False)

                # Encode chunks in batches to keep memory usage bounded
                if self.chunks:
                    batch_size = 32
                    chunk_vectors = []
                    for i in range(0, len(self.chunks), batch_size):
                        batch = self.chunks[i:i + batch_size]
                        batch_vectors = self.model.encode(batch, show_progress_bar=False)
                        chunk_vectors.append(batch_vectors)
                    self.chunk_vectors = np.vstack(chunk_vectors)
            else:
                # TF-IDF: fit on the full documents, then reuse the vocabulary for chunks
                self.vectors = self.vectorizer.fit_transform(self.file_texts)

                if self.chunks:
                    self.chunk_vectors = self.vectorizer.transform(self.chunks)

            return True
        except Exception as e:
            logger.error(f"Error building search index: {e}")
            return False

    def expand_query(self, query):
        """Add related terms to the query for better recall (lightweight synonym expansion)"""
        expansions = {
            "exam": ["test", "assessment", "quiz", "paper", "exam paper", "past paper", "past exam"],
            "test": ["exam", "quiz", "assessment", "paper"],
            "document": ["file", "paper", "report", "doc", "documentation"],
            "manual": ["guide", "instruction", "documentation", "handbook"],
            "tutorial": ["guide", "instructions", "how-to", "lesson"],
            "article": ["paper", "publication", "journal", "research"],
            "research": ["study", "investigation", "paper", "analysis"],
            "book": ["textbook", "publication", "volume", "edition"],
            "thesis": ["dissertation", "paper", "research", "study"],
            "report": ["document", "paper", "analysis", "summary"],
            "assignment": ["homework", "task", "project", "work"],
            "lecture": ["class", "presentation", "talk", "lesson"],
            "notes": ["annotations", "summary", "outline", "study material"],
            "syllabus": ["curriculum", "course outline", "program", "plan"],
            "paper": ["document", "article", "publication", "exam", "test"],
            "question": ["problem", "query", "exercise", "inquiry"],
            "solution": ["answer", "resolution", "explanation", "result"],
            "reference": ["source", "citation", "bibliography", "resource"],
            "analysis": ["examination", "study", "evaluation", "assessment"],
            "guide": ["manual", "instruction", "handbook", "tutorial"],
            "worksheet": ["exercise", "activity", "handout", "practice"],
            "review": ["evaluation", "assessment", "critique", "feedback"],
            "material": ["resource", "content", "document", "information"],
            "data": ["information", "statistics", "figures", "numbers"]
        }

        query_words = re.findall(r'\b\w+\b', query.lower())
        expanded_terms = set()

        # Collect synonyms for any query word found in the expansion table
        for word in query_words:
            if word in expansions:
                expanded_terms.update(expansions[word])

        # If the query asks for files but names no format, add common extensions
        if any(term in query.lower() for term in ["file", "document", "download", "paper"]):
            if not any(ext in query.lower() for ext in ["pdf", "docx", "ppt", "excel"]):
                expanded_terms.update(["pdf", "docx", "pptx", "xlsx"])

        # Academic queries: add typical course-material terms
        if any(term in query.lower() for term in ["course", "university", "college", "school", "class"]):
            expanded_terms.update(["syllabus", "lecture", "notes", "textbook"])

        if expanded_terms:
            expanded_query = f"{query} {' '.join(expanded_terms)}"
            logger.info(f"Expanded query: '{query}' -> '{expanded_query}'")
            return expanded_query
        return query

    @lru_cache(maxsize=8)
    def search(self, query, top_k=5, search_chunks=True):
        """Enhanced search with both document and chunk-level matching.

        Results are memoized via lru_cache; call clear_cache() after re-indexing.
        """
        if self.vectors is None:
            return []

        # Expand the query with related terms to improve recall
        expanded_query = self.expand_query(query)

        try:
            results = []

            if self.use_transformer:
                query_vector = self.model.encode([expanded_query])[0]

                # Document-level matches
                if self.vectors is not None:
                    doc_similarities = cosine_similarity(
                        query_vector.reshape(1, -1),
                        self.vectors
                    ).flatten()

                    top_doc_indices = doc_similarities.argsort()[-top_k:][::-1]

                    for i, idx in enumerate(top_doc_indices):
                        if doc_similarities[idx] > 0.2:
                            results.append({
                                'file_info': self.file_metadata[idx],
                                'score': float(doc_similarities[idx]),
                                'rank': i + 1,
                                'match_type': 'document',
                                'language': self.languages[idx] if idx < len(self.languages) else 'unknown'
                            })

                # Chunk-level matches for files not already returned
                if search_chunks and self.chunk_vectors is not None:
                    chunk_similarities = cosine_similarity(
                        query_vector.reshape(1, -1),
                        self.chunk_vectors
                    ).flatten()

                    top_chunk_indices = chunk_similarities.argsort()[-top_k * 2:][::-1]

                    seen_files = set(r['file_info']['url'] for r in results)

                    for i, idx in enumerate(top_chunk_indices):
                        if chunk_similarities[idx] > 0.25:
                            file_index = self.chunk_metadata[idx]['file_index']
                            file_info = self.file_metadata[file_index]

                            if file_info['url'] not in seen_files:
                                seen_files.add(file_info['url'])
                                results.append({
                                    'file_info': file_info,
                                    'score': float(chunk_similarities[idx]),
                                    'rank': len(results) + 1,
                                    'match_type': 'chunk',
                                    'language': self.languages[file_index] if file_index < len(self.languages) else 'unknown',
                                    'chunk_preview': self.chunks[idx][:200] + "..." if len(self.chunks[idx]) > 200 else self.chunks[idx]
                                })

                        if len(results) >= top_k * 1.5:
                            break
            else:
                # TF-IDF fallback path
                query_vector = self.vectorizer.transform([expanded_query])

                # Document-level matches
                if self.vectors is not None:
                    doc_similarities = cosine_similarity(query_vector, self.vectors).flatten()
                    top_doc_indices = doc_similarities.argsort()[-top_k:][::-1]

                    for i, idx in enumerate(top_doc_indices):
                        if doc_similarities[idx] > 0.1:
                            results.append({
                                'file_info': self.file_metadata[idx],
                                'score': float(doc_similarities[idx]),
                                'rank': i + 1,
                                'match_type': 'document',
                                'language': self.languages[idx] if idx < len(self.languages) else 'unknown'
                            })

                # Chunk-level matches for files not already returned
                if search_chunks and self.chunk_vectors is not None:
                    chunk_similarities = cosine_similarity(query_vector, self.chunk_vectors).flatten()
                    top_chunk_indices = chunk_similarities.argsort()[-top_k * 2:][::-1]

                    seen_files = set(r['file_info']['url'] for r in results)

                    for i, idx in enumerate(top_chunk_indices):
                        if chunk_similarities[idx] > 0.15:
                            file_index = self.chunk_metadata[idx]['file_index']
                            file_info = self.file_metadata[file_index]

                            if file_info['url'] not in seen_files:
                                seen_files.add(file_info['url'])
                                results.append({
                                    'file_info': file_info,
                                    'score': float(chunk_similarities[idx]),
                                    'rank': len(results) + 1,
                                    'match_type': 'chunk',
                                    'language': self.languages[file_index] if file_index < len(self.languages) else 'unknown',
                                    'chunk_preview': self.chunks[idx][:200] + "..." if len(self.chunks[idx]) > 200 else self.chunks[idx]
                                })

                        if len(results) >= top_k * 1.5:
                            break

            results.sort(key=lambda x: x['score'], reverse=True)

            for i, result in enumerate(results[:top_k]):
                result['rank'] = i + 1

            return results[:top_k]
        except Exception as e:
            logger.error(f"Error during search: {e}")
            return []

    def clear_cache(self):
        """Clear search cache and free memory"""
        if hasattr(self.search, 'cache_clear'):
            self.search.cache_clear()

        # Drop the index vectors and force a garbage collection pass
        self.vectors = None
        self.chunk_vectors = None

        import gc
        gc.collect()
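
# Example usage: a minimal sketch of the indexing/search flow. The file name,
# URL, and query below are hypothetical placeholders, not part of this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    searcher = EnhancedRAGSearch()
    with open("example.pdf", "rb") as f:  # hypothetical input file
        searcher.add_file(f.read(), {"filename": "example.pdf", "url": "file://example.pdf"})

    if searcher.build_index():
        for result in searcher.search("past exam papers", top_k=3):
            print(result["rank"], round(result["score"], 3), result["file_info"]["filename"])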