# models/rag_system.py
import os
import numpy as np
import pandas as pd
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from utils.logger import setup_logger
from utils.model_loader import ModelLoader
logger = setup_logger(__name__)

class RAGSystem:
    def __init__(self, csv_path="apparel.csv"):
        try:
            # Initialize the sentence transformer used to embed documents and queries
            self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            # Initialize the extractive question-answering pipeline
            self.qa_pipeline = pipeline(
                "question-answering",
                model="distilbert-base-cased-distilled-squad",
                tokenizer="distilbert-base-cased-distilled-squad"
            )
            self.setup_system(csv_path)
        except Exception as e:
            logger.error(f"Failed to initialize RAGSystem: {str(e)}")
            raise

    def setup_system(self, csv_path):
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found at {csv_path}")
        try:
            # Load the catalogue and use the 'Title' column as the document texts
            self.documents = pd.read_csv(csv_path)
            self.texts = self.documents['Title'].astype(str).tolist()
            # Precompute embeddings for all documents
            self.embeddings = self.embedder.encode(self.texts)
            logger.info(f"Successfully loaded {len(self.texts)} documents")
        except Exception as e:
            logger.error(f"Failed to set up RAG system: {str(e)}")
            raise

    def get_relevant_documents(self, query, top_k=5):
        try:
            # Embed the query with the same model used for the documents
            query_embedding = self.embedder.encode([query])
            # Cosine similarity between the query and every document embedding
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]
            # Indices of the top-k most similar documents, highest similarity first
            top_indices = np.argsort(similarities)[-top_k:][::-1]
            return [self.texts[i] for i in top_indices]
        except Exception as e:
            logger.error(f"Error retrieving relevant documents: {str(e)}")
            return []

    def process_query(self, query):
        try:
            # Retrieve the documents most relevant to the query
            relevant_docs = self.get_relevant_documents(query)
            if not relevant_docs:
                return "No relevant documents found."
            # Combine the retrieved documents into a single context string
            context = " ".join(relevant_docs)
            # Truncate to 512 characters (not tokens) to keep the context short
            # enough for the QA model
            qa_input = {
                "question": query,
                "context": context[:512]
            }
            # Extract the answer span from the retrieved context
            answer = self.qa_pipeline(qa_input)
            return answer['answer']
        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            return f"Failed to process query: {str(e)}"