import os
import numpy as np
import pandas as pd
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from utils.logger import setup_logger

logger = setup_logger(__name__)

class RAGSystem:
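    """Retrieval-augmented QA over product titles: embeds documents from a CSV,
    retrieves the closest matches for a query, and answers with an extractive QA model."""
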
    def __init__(self, csv_path="apparel.csv"):
        try:
            # Initialize the sentence transformer model
            self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            
            # Initialize the QA pipeline
            self.qa_pipeline = pipeline(
                "question-answering",
                model="distilbert-base-cased-distilled-squad",
                tokenizer="distilbert-base-cased-distilled-squad"
            )
            
            self.setup_system(csv_path)
            
        except Exception as e:
            logger.error(f"Failed to initialize RAGSystem: {str(e)}")
            raise

    def setup_system(self, csv_path):
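        """Load the CSV corpus and precompute sentence embeddings for every title."""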
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found at {csv_path}")
        
        try:
            # Load and preprocess documents
            self.documents = pd.read_csv(csv_path)
            self.texts = self.documents['Title'].astype(str).tolist()
            
            # Create embeddings for all documents
            self.embeddings = self.embedder.encode(self.texts)
            
            logger.info(f"Successfully loaded {len(self.texts)} documents")
            
        except Exception as e:
            logger.error(f"Failed to setup RAG system: {str(e)}")
            raise

    def get_relevant_documents(self, query, top_k=5):
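        """Return the top_k stored titles most similar to the query by cosine similarity."""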
        try:
            # Get query embedding
            query_embedding = self.embedder.encode([query])
            
            # Calculate similarities
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]
            
            # Get top k most similar documents
            top_indices = np.argsort(similarities)[-top_k:][::-1]
            
            return [self.texts[i] for i in top_indices]
            
        except Exception as e:
            logger.error(f"Error retrieving relevant documents: {str(e)}")
            return []

    def process_query(self, query):
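        """Answer a query by retrieving relevant titles and running extractive QA over them."""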
        try:
            # Get relevant documents
            relevant_docs = self.get_relevant_documents(query)
            
            if not relevant_docs:
                return "No relevant documents found."
            
            # Combine retrieved documents into context
            context = " ".join(relevant_docs)
            
            # Prepare QA input
            qa_input = {
                "question": query,
                "context": context[:512]  # Limit context length for the model
            }
            
            # Get answer using QA pipeline
            answer = self.qa_pipeline(qa_input)
            
            return answer['answer']
            
        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            return f"Failed to process query: {str(e)}"