# Spaces: Build error
import os

import numpy as np
import pandas as pd
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from utils.logger import setup_logger
from utils.model_loader import ModelLoader

# Module-level logger used by the error/info reporting in this file.
logger = setup_logger(__name__)
class RAGSystem:
    """Answer questions by retrieving similar CSV titles and running extractive QA.

    The values of the CSV's ``Title`` column are embedded once at start-up.
    For each query, the titles most similar to the query (cosine similarity
    of their embeddings) are joined into a context string, and a
    question-answering pipeline extracts the answer from that context.

    Attributes set by ``__init__`` / ``setup_system``:
        embedder: the sentence-embedding model.
        qa_pipeline: the question-answering pipeline.
        documents: the DataFrame read from the CSV file.
        texts: the ``Title`` column as a list of strings.
        embeddings: the result of encoding ``texts`` with ``embedder``.
    """

    # Model identifiers and limits, kept in one place so a subclass can
    # override them without touching the methods.
    EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
    QA_MODEL = "distilbert-base-cased-distilled-squad"
    TEXT_COLUMN = 'Title'
    MAX_CONTEXT_CHARS = 512

    def __init__(self, csv_path="apparel.csv"):
        """Load both models, then index the documents found at ``csv_path``.

        Any failure is logged and re-raised unchanged, so the caller sees
        the original exception (e.g. ``FileNotFoundError`` for a missing CSV).
        """
        try:
            # Initialize the sentence transformer model
            self.embedder = SentenceTransformer(self.EMBEDDING_MODEL)
            # Initialize the QA pipeline
            self.qa_pipeline = pipeline(
                "question-answering",
                model=self.QA_MODEL,
                tokenizer=self.QA_MODEL,
            )
            self.setup_system(csv_path)
        except Exception as e:
            logger.error(f"Failed to initialize RAGSystem: {e}")
            raise

    def setup_system(self, csv_path):
        """Read the CSV at ``csv_path`` and pre-compute title embeddings.

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
            Exception: anything raised while reading the CSV, selecting the
                title column, or encoding is logged and re-raised.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found at {csv_path}")
        try:
            # Load and preprocess documents
            self.documents = pd.read_csv(csv_path)
            # astype(str) also stringifies missing cells, so every entry is
            # a str that the embedder can consume.
            self.texts = self.documents[self.TEXT_COLUMN].astype(str).tolist()
            # Create embeddings for all documents
            self.embeddings = self.embedder.encode(self.texts)
            logger.info(f"Successfully loaded {len(self.texts)} documents")
        except Exception as e:
            logger.error(f"Failed to setup RAG system: {e}")
            raise

    def get_relevant_documents(self, query, top_k=5):
        """Return up to ``top_k`` titles most similar to ``query``, best first.

        A ``top_k`` of zero or less yields an empty list. A ``top_k`` larger
        than the number of documents returns all of them. Any error during
        retrieval is logged and an empty list is returned instead of raising.
        """
        try:
            # Without this guard, top_k == 0 would slice as [-0:] (the whole
            # array) and hand back every document; negative values would
            # slice from the wrong end.
            if top_k <= 0:
                return []
            # Get query embedding
            query_embedding = self.embedder.encode([query])
            # Calculate similarities; [0] picks the row for the single query
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]
            # argsort is ascending: keep the last top_k, then reverse so the
            # most similar document comes first
            top_indices = np.argsort(similarities)[-top_k:][::-1]
            return [self.texts[i] for i in top_indices]
        except Exception as e:
            logger.error(f"Error retrieving relevant documents: {e}")
            return []

    def process_query(self, query):
        """Answer ``query`` from the retrieved titles.

        Returns the pipeline's ``'answer'`` string,
        ``"No relevant documents found."`` when retrieval yields nothing,
        or a ``"Failed to process query: ..."`` message when an error occurs.
        This method does not raise.
        """
        try:
            # Get relevant documents
            relevant_docs = self.get_relevant_documents(query)
            if not relevant_docs:
                return "No relevant documents found."
            # Combine retrieved documents into context
            context = " ".join(relevant_docs)
            # Prepare QA input; the context is capped by character count
            # (not tokens) before it is handed to the model
            qa_input = {
                "question": query,
                "context": context[:self.MAX_CONTEXT_CHARS],
            }
            # Get answer using QA pipeline
            answer = self.qa_pipeline(qa_input)
            return answer['answer']
        except Exception as e:
            logger.error(f"Error processing query: {e}")
            return f"Failed to process query: {e}"