craw_web / rag_search.py
import os
import re
import logging
import nltk
from io import BytesIO
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import PyPDF2
import docx2txt
from functools import lru_cache
logger = logging.getLogger(__name__)

# Try to import sentence-transformers
try:
    from sentence_transformers import SentenceTransformer
    HAVE_TRANSFORMERS = True
except ImportError:
    HAVE_TRANSFORMERS = False

# Try to download NLTK data if not already present
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    try:
        nltk.download('punkt', quiet=True)
    except Exception:
        pass

# Load English stopwords; fall back to a small built-in set if NLTK data is unavailable
try:
    try:
        nltk.data.find('corpora/stopwords')
    except LookupError:
        nltk.download('stopwords', quiet=True)
    from nltk.corpus import stopwords
    STOPWORDS = set(stopwords.words('english'))
except Exception:
    STOPWORDS = set(['the', 'and', 'a', 'in', 'to', 'of', 'is', 'it', 'that', 'for', 'with', 'as', 'on', 'by'])

class EnhancedRAGSearch:
    def __init__(self):
        self.file_texts = []
        self.chunks = []          # Document chunks for more targeted search
        self.chunk_metadata = []  # Metadata for each chunk
        self.file_metadata = []
        self.languages = []
        self.model = None

        # Try to load the sentence transformer model if available
        if HAVE_TRANSFORMERS:
            try:
                # Use a small, efficient model
                self.model = SentenceTransformer('all-MiniLM-L6-v2')
                self.use_transformer = True
                logger.info("Using sentence-transformers for RAG")
            except Exception as e:
                logger.warning(f"Error loading sentence-transformer: {e}")
                self.use_transformer = False
        else:
            self.use_transformer = False

        # Fallback to TF-IDF if transformers not available
        if not self.use_transformer:
            self.vectorizer = TfidfVectorizer(
                stop_words='english',
                ngram_range=(1, 2),  # Use bigrams for better context
                max_features=15000,  # Use more features for better representation
                min_df=1             # Include rare terms
            )

        self.vectors = None
        self.chunk_vectors = None

    def add_file(self, file_data, file_info):
        """Add a file to the search index with improved processing."""
        file_ext = os.path.splitext(file_info['filename'])[1].lower()
        text = self.extract_text(file_data, file_ext)
        if text:
            # Store the whole document text
            self.file_texts.append(text)
            self.file_metadata.append(file_info)

            # Try to detect language
            try:
                # Simple language detection based on stopwords
                words = re.findall(r'\b\w+\b', text.lower())
                english_stopwords_ratio = len([w for w in words[:100] if w in STOPWORDS]) / max(1, len(words[:100]))
                lang = 'en' if english_stopwords_ratio > 0.2 else 'unknown'
                self.languages.append(lang)
            except Exception:
                self.languages.append('en')  # Default to English

            # Create chunks for more granular search
            chunks = self.create_chunks(text)
            for chunk in chunks:
                self.chunks.append(chunk)
                self.chunk_metadata.append({
                    'file_info': file_info,
                    'chunk_size': len(chunk),
                    'file_index': len(self.file_texts) - 1
                })
            return True
        return False

    def create_chunks(self, text, chunk_size=1000, overlap=200):
        """Split text into overlapping chunks for better search precision."""
        try:
            sentences = nltk.sent_tokenize(text)
            chunks = []
            current_chunk = ""
            for sentence in sentences:
                if len(current_chunk) + len(sentence) <= chunk_size:
                    current_chunk += sentence + " "
                else:
                    # Add current chunk if it has content
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    # Start new chunk with overlap from previous chunk
                    if len(current_chunk) > overlap:
                        # Find the last space within the overlap region so the
                        # carried-over text starts near a word boundary
                        overlap_text = current_chunk[-overlap:]
                        last_space = overlap_text.rfind(' ')
                        if last_space != -1:
                            current_chunk = current_chunk[-(overlap - last_space):] + sentence + " "
                        else:
                            current_chunk = sentence + " "
                    else:
                        current_chunk = sentence + " "
            # Add the last chunk if it has content
            if current_chunk:
                chunks.append(current_chunk.strip())
            return chunks
        except Exception:
            # Fallback to simpler fixed-size chunking approach
            chunks = []
            for i in range(0, len(text), chunk_size - overlap):
                chunk = text[i:i + chunk_size]
                if chunk:
                    chunks.append(chunk)
            return chunks

    def extract_text(self, file_data, file_ext):
        """Extract text from different file types with enhanced support."""
        try:
            if file_ext.lower() == '.pdf':
                reader = PyPDF2.PdfReader(BytesIO(file_data))
                text = ""
                for page in reader.pages:
                    extracted = page.extract_text()
                    if extracted:
                        text += extracted + "\n"
                return text
            elif file_ext.lower() in ['.docx', '.doc']:
                return docx2txt.process(BytesIO(file_data))
            elif file_ext.lower() in ['.txt', '.csv', '.json', '.html', '.htm']:
                # Handle both UTF-8 and other common encodings: try strict
                # decodes in order of likelihood first
                for enc in ['utf-8', 'latin-1', 'iso-8859-1', 'windows-1252']:
                    try:
                        return file_data.decode(enc)
                    except UnicodeDecodeError:
                        continue
                # Last resort fallback
                return file_data.decode('utf-8', errors='ignore')
            elif file_ext.lower() in ['.pptx', '.ppt', '.xlsx', '.xls']:
                return f"[Content of {file_ext} file - install additional libraries for full text extraction]"
            else:
                return ""
        except Exception as e:
            logger.error(f"Error extracting text: {e}")
            return ""

    def build_index(self):
        """Build both document and chunk search indices"""
        if not self.file_texts:
            return False
        try:
            if self.use_transformer:
                # Use sentence transformer models for embeddings
                logger.info("Building document and chunk embeddings with transformer model...")
                self.vectors = self.model.encode(self.file_texts, show_progress_bar=False)

                # Build chunk-level index if we have chunks
                if self.chunks:
                    # Process in batches to avoid memory issues
                    batch_size = 32
                    chunk_vectors = []
                    for i in range(0, len(self.chunks), batch_size):
                        batch = self.chunks[i:i + batch_size]
                        batch_vectors = self.model.encode(batch, show_progress_bar=False)
                        chunk_vectors.append(batch_vectors)
                    self.chunk_vectors = np.vstack(chunk_vectors)
            else:
                # Build document-level index
                self.vectors = self.vectorizer.fit_transform(self.file_texts)

                # Build chunk-level index if we have chunks
                if self.chunks:
                    self.chunk_vectors = self.vectorizer.transform(self.chunks)
            return True
        except Exception as e:
            logger.error(f"Error building search index: {e}")
            return False

    def expand_query(self, query):
        """Add related terms to query for better recall - mini LLM function"""
        # Dictionary of related terms for common keywords
        expansions = {
            "exam": ["test", "assessment", "quiz", "paper", "exam paper", "past paper", "past exam"],
            "test": ["exam", "quiz", "assessment", "paper"],
            "document": ["file", "paper", "report", "doc", "documentation"],
            "manual": ["guide", "instruction", "documentation", "handbook"],
            "tutorial": ["guide", "instructions", "how-to", "lesson"],
            "article": ["paper", "publication", "journal", "research"],
            "research": ["study", "investigation", "paper", "analysis"],
            "book": ["textbook", "publication", "volume", "edition"],
            "thesis": ["dissertation", "paper", "research", "study"],
            "report": ["document", "paper", "analysis", "summary"],
            "assignment": ["homework", "task", "project", "work"],
            "lecture": ["class", "presentation", "talk", "lesson"],
            "notes": ["annotations", "summary", "outline", "study material"],
            "syllabus": ["curriculum", "course outline", "program", "plan"],
            "paper": ["document", "article", "publication", "exam", "test"],
            "question": ["problem", "query", "exercise", "inquiry"],
            "solution": ["answer", "resolution", "explanation", "result"],
            "reference": ["source", "citation", "bibliography", "resource"],
            "analysis": ["examination", "study", "evaluation", "assessment"],
            "guide": ["manual", "instruction", "handbook", "tutorial"],
            "worksheet": ["exercise", "activity", "handout", "practice"],
            "review": ["evaluation", "assessment", "critique", "feedback"],
            "material": ["resource", "content", "document", "information"],
            "data": ["information", "statistics", "figures", "numbers"]
        }

        # Enhanced query expansion simulating a mini-LLM
        query_words = re.findall(r'\b\w+\b', query.lower())
        expanded_terms = set()

        # Directly add expansions from our dictionary
        for word in query_words:
            if word in expansions:
                expanded_terms.update(expansions[word])

        # Add common academic file formats if not already included
        if any(term in query.lower() for term in ["file", "document", "download", "paper"]):
            if not any(ext in query.lower() for ext in ["pdf", "docx", "ppt", "excel"]):
                expanded_terms.update(["pdf", "docx", "pptx", "xlsx"])

        # Add special academic terms when the query seems related to education
        if any(term in query.lower() for term in ["course", "university", "college", "school", "class"]):
            expanded_terms.update(["syllabus", "lecture", "notes", "textbook"])

        # Return original query plus expanded terms
        if expanded_terms:
            expanded_query = f"{query} {' '.join(expanded_terms)}"
            logger.info(f"Expanded query: '{query}' -> '{expanded_query}'")
            return expanded_query
        return query

    @lru_cache(maxsize=8)
    def search(self, query, top_k=5, search_chunks=True):
        """Enhanced search with both document and chunk-level search"""
        if self.vectors is None:
            return []

        # Simulate a small LLM by expanding the query with related terms
        expanded_query = self.expand_query(query)
        try:
            results = []
            if self.use_transformer:
                # Transform the query to embedding
                query_vector = self.model.encode([expanded_query])[0]

                # First search at document level for higher-level matches
                if self.vectors is not None:
                    # Compute similarities between query and documents
                    doc_similarities = cosine_similarity(
                        query_vector.reshape(1, -1),
                        self.vectors
                    ).flatten()
                    top_doc_indices = doc_similarities.argsort()[-top_k:][::-1]
                    for i, idx in enumerate(top_doc_indices):
                        if doc_similarities[idx] > 0.2:  # Threshold to exclude irrelevant results
                            results.append({
                                'file_info': self.file_metadata[idx],
                                'score': float(doc_similarities[idx]),
                                'rank': i + 1,
                                'match_type': 'document',
                                'language': self.languages[idx] if idx < len(self.languages) else 'unknown'
                            })

                # Then search at chunk level for more specific matches if enabled
                if search_chunks and self.chunk_vectors is not None:
                    # Compute similarities between query and chunks
                    chunk_similarities = cosine_similarity(
                        query_vector.reshape(1, -1),
                        self.chunk_vectors
                    ).flatten()
                    top_chunk_indices = chunk_similarities.argsort()[-top_k * 2:][::-1]  # Get more chunk results

                    # Use a set to avoid duplicate file results
                    seen_files = set(r['file_info']['url'] for r in results)
                    for i, idx in enumerate(top_chunk_indices):
                        if chunk_similarities[idx] > 0.25:  # Higher threshold for chunks
                            file_index = self.chunk_metadata[idx]['file_index']
                            file_info = self.file_metadata[file_index]
                            # Only add if we haven't already included this file
                            if file_info['url'] not in seen_files:
                                seen_files.add(file_info['url'])
                                results.append({
                                    'file_info': file_info,
                                    'score': float(chunk_similarities[idx]),
                                    'rank': len(results) + 1,
                                    'match_type': 'chunk',
                                    'language': self.languages[file_index] if file_index < len(self.languages) else 'unknown',
                                    'chunk_preview': self.chunks[idx][:200] + "..." if len(self.chunks[idx]) > 200 else self.chunks[idx]
                                })
                            # Stop after we've found enough results
                            if len(results) >= top_k * 1.5:
                                break
            else:
                # Fallback to TF-IDF if transformers not available
                query_vector = self.vectorizer.transform([expanded_query])

                # First search at document level
                if self.vectors is not None:
                    doc_similarities = cosine_similarity(query_vector, self.vectors).flatten()
                    top_doc_indices = doc_similarities.argsort()[-top_k:][::-1]
                    for i, idx in enumerate(top_doc_indices):
                        if doc_similarities[idx] > 0.1:  # Threshold to exclude irrelevant results
                            results.append({
                                'file_info': self.file_metadata[idx],
                                'score': float(doc_similarities[idx]),
                                'rank': i + 1,
                                'match_type': 'document',
                                'language': self.languages[idx] if idx < len(self.languages) else 'unknown'
                            })

                # Then search at chunk level if enabled
                if search_chunks and self.chunk_vectors is not None:
                    chunk_similarities = cosine_similarity(query_vector, self.chunk_vectors).flatten()
                    top_chunk_indices = chunk_similarities.argsort()[-top_k * 2:][::-1]

                    # Avoid duplicates
                    seen_files = set(r['file_info']['url'] for r in results)
                    for i, idx in enumerate(top_chunk_indices):
                        if chunk_similarities[idx] > 0.15:
                            file_index = self.chunk_metadata[idx]['file_index']
                            file_info = self.file_metadata[file_index]
                            if file_info['url'] not in seen_files:
                                seen_files.add(file_info['url'])
                                results.append({
                                    'file_info': file_info,
                                    'score': float(chunk_similarities[idx]),
                                    'rank': len(results) + 1,
                                    'match_type': 'chunk',
                                    'language': self.languages[file_index] if file_index < len(self.languages) else 'unknown',
                                    'chunk_preview': self.chunks[idx][:200] + "..." if len(self.chunks[idx]) > 200 else self.chunks[idx]
                                })
                            if len(results) >= top_k * 1.5:
                                break

            # Sort combined results by score
            results.sort(key=lambda x: x['score'], reverse=True)
            # Re-rank and truncate
            for i, result in enumerate(results[:top_k]):
                result['rank'] = i + 1
            return results[:top_k]
        except Exception as e:
            logger.error(f"Error during search: {e}")
            return []

    def clear_cache(self):
        """Clear search cache and free memory"""
        if hasattr(self.search, 'cache_clear'):
            self.search.cache_clear()
        # Clear vectors to free memory
        self.vectors = None
        self.chunk_vectors = None
        # Force garbage collection
        import gc
        gc.collect()
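

# --- Usage sketch (not part of the original module) --------------------------
# A minimal example of how this class is typically driven: add raw file bytes
# plus metadata, build the index, then run a query. The file path below is
# hypothetical; `file_info` needs at least the 'filename' and 'url' keys,
# since add_file() and search() read them.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    searcher = EnhancedRAGSearch()

    sample_path = "sample.pdf"  # hypothetical local file for demonstration
    if os.path.exists(sample_path):
        with open(sample_path, "rb") as f:
            searcher.add_file(f.read(), {
                "filename": sample_path,
                "url": f"file://{sample_path}",
            })

    if searcher.build_index():
        for result in searcher.search("past exam papers", top_k=3):
            print(result["rank"], result["match_type"],
                  round(result["score"], 3), result["file_info"]["filename"])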