# Author: Rohil Bansal
# Commit 821284f: search improved
from sentence_transformers import SentenceTransformer
import numpy as np
import pandas as pd
import faiss
import logging
from typing import List, Dict
from pathlib import Path
# Module-level logging: root config at INFO, logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RAGSystem:
    def __init__(self):
        """Set up the retrieval system.

        Loads the sentence-encoder model and resets the state that
        load_and_process_data() will later populate (embeddings, FAISS
        index, course DataFrame). Initialization failures are logged and
        re-raised to the caller.
        """
        try:
            # all-MiniLM-L6-v2: compact general-purpose sentence encoder.
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            # Filled in by load_and_process_data(); None until then.
            self.embeddings = None
            self.index = None
            self.df = None
            logger.info("RAG system initialized successfully")
        except Exception as exc:
            logger.error(f"Error initializing RAG system: {str(exc)}")
            raise
def load_and_process_data(self, df: pd.DataFrame, cache_dir: Path = None):
"""Load and process the course data with caching support"""
try:
# Validate input
if df is None or len(df) == 0:
raise ValueError("Empty or None DataFrame provided")
required_columns = ['title', 'description', 'curriculum', 'url']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
self.df = df
vector_dimension = 384 # dimension for all-MiniLM-L6-v2
# Try loading from cache first
if cache_dir is not None:
cache_dir.mkdir(exist_ok=True)
embeddings_path = cache_dir / 'course_embeddings.npy'
index_path = cache_dir / 'faiss_index.bin'
if embeddings_path.exists() and index_path.exists():
logger.info("Loading cached embeddings and index...")
try:
self.embeddings = np.load(str(embeddings_path))
self.index = faiss.read_index(str(index_path))
logger.info("Successfully loaded cached data")
return
except Exception as e:
logger.warning(f"Failed to load cache: {e}. Computing new embeddings...")
# Compute new embeddings
logger.info("Computing course embeddings...")
texts = [
f"{row['title']}. {row['description']}"
for _, row in df.iterrows()
]
if not texts:
raise ValueError("No texts to encode")
self.embeddings = self.model.encode(
texts,
show_progress_bar=True,
convert_to_numpy=True
)
if self.embeddings.size == 0:
raise ValueError("Failed to generate embeddings")
# Create and populate FAISS index
self.index = faiss.IndexFlatL2(vector_dimension)
self.index.add(self.embeddings.astype('float32'))
# Save to cache if directory provided
if cache_dir is not None:
logger.info("Saving embeddings and index to cache...")
np.save(str(embeddings_path), self.embeddings)
faiss.write_index(self.index, str(index_path))
logger.info(f"Successfully processed {len(df)} courses")
except Exception as e:
logger.error(f"Error processing data: {str(e)}")
raise
def search_courses(self, query: str, top_k: int = 5) -> Dict:
"""Search for courses using semantic search with improved ranking"""
try:
# Ensure the FAISS index is initialized
if self.index is None:
raise ValueError("FAISS index not initialized. Please load data first.")
# Get query embedding
query_embedding = self.model.encode([query], convert_to_numpy=True)
# Get initial similarity scores
D, I = self.index.search(query_embedding.reshape(1, -1), top_k * 2)
distances = D[0]
indices = I[0]
# Get results with additional metadata
results = []
for dist, idx in zip(distances, indices):
course = self.df.iloc[idx].to_dict()
# Calculate relevance score components
title_similarity = self.calculate_text_similarity(query, course['title'])
desc_similarity = self.calculate_text_similarity(query, course['description'])
# Combine scores with weights
final_score = (
0.4 * (1 - dist) +
0.4 * title_similarity +
0.2 * desc_similarity
)
results.append({
**course,
'relevance_score': final_score
})
# Sort by final relevance score and take top_k
results.sort(key=lambda x: x['relevance_score'], reverse=True)
results = results[:top_k]
return {
'query': query,
'results': results
}
except Exception as e:
logger.error(f"Error in search_courses: {str(e)}")
raise
def calculate_text_similarity(self, text1: str, text2: str) -> float:
"""
Calculate text similarity between two strings using word overlap
Args:
text1 (str): First text string
text2 (str): Second text string
Returns:
float: Similarity score between 0 and 1
"""
try:
# Convert to lowercase and split into words
text1 = str(text1).lower()
text2 = str(text2).lower()
words1 = set(text1.split())
words2 = set(text2.split())
if not words1 or not words2:
return 0.0
# Calculate Jaccard similarity
intersection = len(words1.intersection(words2))
union = len(words1.union(words2))
if union == 0:
return 0.0
similarity = intersection / union
return similarity
except Exception as e:
logger.error(f"Error calculating text similarity: {str(e)}")
return 0.0