import logging
from pathlib import Path
from typing import Dict, Optional

import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class RAGSystem:
    def __init__(self):
        """Initialize the RAG system."""
        try:
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.embeddings = None
            self.index = None
            self.df = None
            logger.info("RAG system initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing RAG system: {str(e)}")
            raise
    def load_and_process_data(self, df: pd.DataFrame, cache_dir: Optional[Path] = None):
        """Load and process the course data, with optional on-disk caching."""
        try:
            # Validate input
            if df is None or len(df) == 0:
                raise ValueError("Empty or None DataFrame provided")

            required_columns = ['title', 'description', 'curriculum', 'url']
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                raise ValueError(f"Missing required columns: {missing_columns}")

            self.df = df
            vector_dimension = 384  # embedding dimension of all-MiniLM-L6-v2

            # Try loading from cache first
            if cache_dir is not None:
                cache_dir.mkdir(parents=True, exist_ok=True)
                embeddings_path = cache_dir / 'course_embeddings.npy'
                index_path = cache_dir / 'faiss_index.bin'
                if embeddings_path.exists() and index_path.exists():
                    logger.info("Loading cached embeddings and index...")
                    try:
                        self.embeddings = np.load(str(embeddings_path))
                        self.index = faiss.read_index(str(index_path))
                        logger.info("Successfully loaded cached data")
                        return
                    except Exception as e:
                        logger.warning(f"Failed to load cache: {e}. Computing new embeddings...")

            # Compute new embeddings from each course's title and description
            logger.info("Computing course embeddings...")
            texts = [
                f"{row['title']}. {row['description']}"
                for _, row in df.iterrows()
            ]
            if not texts:
                raise ValueError("No texts to encode")

            self.embeddings = self.model.encode(
                texts,
                show_progress_bar=True,
                convert_to_numpy=True
            )
            if self.embeddings.size == 0:
                raise ValueError("Failed to generate embeddings")

            # Create and populate the FAISS index (exact L2 search)
            self.index = faiss.IndexFlatL2(vector_dimension)
            self.index.add(self.embeddings.astype('float32'))

            # Save to cache if a directory was provided
            if cache_dir is not None:
                logger.info("Saving embeddings and index to cache...")
                np.save(str(embeddings_path), self.embeddings)
                faiss.write_index(self.index, str(index_path))

            logger.info(f"Successfully processed {len(df)} courses")
        except Exception as e:
            logger.error(f"Error processing data: {str(e)}")
            raise
    def search_courses(self, query: str, top_k: int = 5) -> Dict:
        """Search for courses using semantic search with re-ranked results."""
        try:
            # Ensure the FAISS index is initialized
            if self.index is None:
                raise ValueError("FAISS index not initialized. Please load data first.")

            # Embed the query (float32, as FAISS expects)
            query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')

            # Retrieve extra candidates for re-ranking, capped at the index size
            num_candidates = min(top_k * 2, self.index.ntotal)
            D, I = self.index.search(query_embedding, num_candidates)
            distances = D[0]
            indices = I[0]

            # Build results with additional relevance metadata
            results = []
            for dist, idx in zip(distances, indices):
                course = self.df.iloc[idx].to_dict()

                # Lexical relevance score components
                title_similarity = self.calculate_text_similarity(query, course['title'])
                desc_similarity = self.calculate_text_similarity(query, course['description'])

                # Map the unbounded L2 distance to a similarity in (0, 1]
                # (the raw distance can exceed 1, so "1 - dist" would go
                # negative), then combine the semantic and lexical scores
                semantic_similarity = 1.0 / (1.0 + float(dist))
                final_score = (
                    0.4 * semantic_similarity +
                    0.4 * title_similarity +
                    0.2 * desc_similarity
                )
                results.append({
                    **course,
                    'relevance_score': final_score
                })

            # Sort by final relevance score and keep the top_k results
            results.sort(key=lambda x: x['relevance_score'], reverse=True)
            results = results[:top_k]

            return {
                'query': query,
                'results': results
            }
        except Exception as e:
            logger.error(f"Error in search_courses: {str(e)}")
            raise
    def calculate_text_similarity(self, text1: str, text2: str) -> float:
        """
        Calculate the similarity of two strings using word overlap.

        Args:
            text1 (str): First text string
            text2 (str): Second text string

        Returns:
            float: Jaccard similarity score between 0 and 1
        """
        try:
            # Lowercase and split into word sets
            words1 = set(str(text1).lower().split())
            words2 = set(str(text2).lower().split())
            if not words1 or not words2:
                return 0.0

            # Jaccard similarity: |intersection| / |union|
            # (both sets are non-empty here, so the union cannot be empty)
            intersection = len(words1.intersection(words2))
            union = len(words1.union(words2))
            return intersection / union
        except Exception as e:
            logger.error(f"Error calculating text similarity: {str(e)}")
            return 0.0
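

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original system): builds a tiny
    # hypothetical course catalog, indexes it, and runs one query. The rows
    # below are illustrative placeholders, not real course data.
    sample_df = pd.DataFrame([
        {
            'title': 'Introduction to Machine Learning',
            'description': 'Supervised learning, model evaluation, and scikit-learn basics.',
            'curriculum': 'Regression; classification; validation',
            'url': 'https://example.com/courses/intro-ml',
        },
        {
            'title': 'Deep Learning with PyTorch',
            'description': 'Neural networks, backpropagation, and training loops in PyTorch.',
            'curriculum': 'Tensors; autograd; CNNs',
            'url': 'https://example.com/courses/deep-learning',
        },
    ])

    rag = RAGSystem()
    rag.load_and_process_data(sample_df, cache_dir=Path('cache'))

    response = rag.search_courses("beginner machine learning course", top_k=2)
    for result in response['results']:
        print(f"{result['relevance_score']:.3f}  {result['title']}")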