File size: 10,780 Bytes
12c661c
e96852d
 
 
 
 
 
 
 
8ce796d
e96852d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ce796d
e96852d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12c661c
e96852d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import os
from typing import Dict, List, Optional, Union, Any
from smolagents import Tool
from langchain.vectorstores import FAISS, Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader, TextLoader, DirectoryLoader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
import json

class RAGTool(Tool):
    """smolagents tool that answers queries by similarity search over a local vector store.

    On construction the tool either loads an existing FAISS/Chroma store from
    ``persist_directory`` or, when none exists for the configured backend,
    builds one from the files found under ``documents_path``.
    """

    name = "rag_retriever"
    description = """
    Advanced RAG (Retrieval-Augmented Generation) tool that searches in vector stores based on given prompts.
    This tool allows you to query documents stored in vector databases using semantic similarity.
    It supports various configurations including different embedding models, vector stores, and document types.
    """
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to retrieve relevant information from the document store",
        },
        "top_k": {
            "type": "integer",
            "description": "Number of most relevant documents to retrieve (default: 3)",
            # forward() gives top_k a default value, so smolagents requires the
            # input to be declared nullable when it validates the tool.
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(self,
                 documents_path: str = "./documents",
                 embedding_model: str = "BAAI/bge-small-en-v1.5",
                 vector_store_type: str = "faiss",
                 chunk_size: int = 1000,
                 chunk_overlap: int = 200,
                 persist_directory: str = "./vector_store",
                 device: str = "cpu"):
        """
        Initialize the RAG Tool with configurable parameters.

        Args:
            documents_path: Path to a single document or to a folder of documents
            embedding_model: HuggingFace model ID for embeddings
            vector_store_type: Type of vector store ('faiss'; any other value selects Chroma)
            chunk_size: Size of text chunks for splitting documents
            chunk_overlap: Overlap between text chunks
            persist_directory: Directory to persist the vector store in (created if missing)
            device: Device to run embedding model on ('cpu' or 'cuda')
        """
        super().__init__()
        self.documents_path = documents_path
        self.embedding_model = embedding_model
        self.vector_store_type = vector_store_type
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.persist_directory = persist_directory
        self.device = device
        # Defined up front so forward() sees None, not a missing attribute,
        # when setup bails out early (no documents, bad path, ...).
        self.vector_store = None

        os.makedirs(persist_directory, exist_ok=True)
        self._setup_vector_store()

    # ------------------------------------------------------------------ setup

    def _vector_store_exists(self):
        """Return True if persist_directory already holds a store for the configured backend."""
        if self.vector_store_type.lower() == "faiss":
            markers = ("index.faiss",)
        else:
            # chromadb >= 0.4 persists to chroma.sqlite3; older releases wrote
            # parquet files. The bare "chroma" name is what this check used to
            # look for and is kept so setups relying on it behave the same.
            markers = ("chroma.sqlite3", "chroma-collections.parquet", "chroma")
        return any(
            os.path.exists(os.path.join(self.persist_directory, marker))
            for marker in markers
        )

    def _setup_vector_store(self):
        """Load the persisted vector store, or build it from documents if none exists yet."""
        if self._vector_store_exists():
            print(f"Vector store already exists at {self.persist_directory}")
            self._load_vector_store()
            return

        if not os.path.exists(self.documents_path):
            print(f"Warning: Documents path {self.documents_path} does not exist.")
            return

        documents = self._load_documents()
        if not documents:
            print("No documents loaded. Vector store not created.")
            return

        self._create_vector_store(documents)

    def _get_embeddings(self):
        """Build the embedding model; fall back to all-MiniLM-L6-v2 if it cannot be loaded."""
        try:
            if "bge" in self.embedding_model.lower():
                # BGE models are used with normalized embeddings.
                return HuggingFaceBgeEmbeddings(
                    model_name=self.embedding_model,
                    encode_kwargs={"normalize_embeddings": True},
                    model_kwargs={"device": self.device}
                )
            return HuggingFaceEmbeddings(
                model_name=self.embedding_model,
                model_kwargs={"device": self.device}
            )
        except Exception as e:
            print(f"Error loading embedding model: {e}")
            # Fallback to a reliable model
            print("Falling back to sentence-transformers/all-MiniLM-L6-v2")
            return HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={"device": self.device}
            )

    # -------------------------------------------------------- document loading

    def _load_documents(self):
        """Load documents from documents_path and split them into chunks.

        Returns:
            A list of chunked documents; empty if the path is missing, holds no
            loadable files, or every file failed to load.
        """
        if os.path.isfile(self.documents_path):
            documents = self._load_file(self.documents_path)
        elif os.path.isdir(self.documents_path):
            documents = self._load_directory(self.documents_path)
        else:
            documents = []

        if not documents:
            return []

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap
        )
        return splitter.split_documents(documents)

    def _load_directory(self, directory):
        """Load every visible ``*.*`` file under directory (recursively), one file at a time.

        Hidden files and directories are skipped. A file that fails to load is
        reported and skipped instead of aborting the whole directory.
        """
        documents = []
        for root, dirnames, filenames in os.walk(directory):
            # Prune hidden directories in place so os.walk does not descend into them;
            # sorting keeps the load order deterministic.
            dirnames[:] = sorted(d for d in dirnames if not d.startswith("."))
            for filename in sorted(filenames):
                if filename.startswith(".") or "." not in filename:
                    continue
                documents.extend(
                    self._load_file(os.path.join(root, filename), load_any_as_text=True)
                )
        return documents

    def _load_file(self, path, load_any_as_text=False):
        """Load one file: PDFs via the PDF loader, text-like files via TextLoader.

        Args:
            path: File to load.
            load_any_as_text: When True, any non-PDF file is read as text (used
                for directory scans); otherwise only .txt/.md/.html are accepted.
        """
        lowered = path.lower()
        if lowered.endswith(".pdf"):
            return self._load_pdf(path)
        if load_any_as_text or lowered.endswith(('.txt', '.md', '.html')):
            return self._load_text(path)
        return []

    def _load_text(self, path):
        """Load a text file, autodetecting its encoding; return [] on failure."""
        try:
            return TextLoader(path, autodetect_encoding=True).load()
        except Exception as e:
            print(f"Error loading {path}: {e}")
            return []

    def _load_pdf(self, path):
        """Load a PDF with PyPDFLoader, falling back to raw PyPDF2 text extraction."""
        try:
            return PyPDFLoader(path).load()
        except Exception as e:
            print(f"Error loading PDF: {e}")

        try:
            text = self._extract_text_from_pdf(path)
            splitter = CharacterTextSplitter(
                separator="\n",
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
            # Attach the source so forward() can still cite where a chunk came from.
            return splitter.create_documents([text], metadatas=[{"source": path}])
        except Exception as e2:
            print(f"Error with fallback PDF extraction: {e2}")
            return []

    def _extract_text_from_pdf(self, pdf_path):
        """Extract and concatenate the text of every page using PyPDF2."""
        reader = PdfReader(pdf_path)
        # Guard against pages yielding no text (e.g. image-only pages).
        return "".join(page.extract_text() or "" for page in reader.pages)

    # ----------------------------------------------------------- vector store

    def _create_vector_store(self, documents):
        """Embed documents into a new store, persist it, and keep it on self.vector_store."""
        embeddings = self._get_embeddings()

        if self.vector_store_type.lower() == "faiss":
            vector_store = FAISS.from_documents(documents, embeddings)
            vector_store.save_local(self.persist_directory)
            print(f"Created FAISS vector store at {self.persist_directory}")
        else:  # Default to Chroma
            vector_store = Chroma.from_documents(
                documents,
                embeddings,
                persist_directory=self.persist_directory
            )
            # Newer Chroma integrations persist automatically and may not expose persist().
            if hasattr(vector_store, "persist"):
                vector_store.persist()
            print(f"Created Chroma vector store at {self.persist_directory}")

        self.vector_store = vector_store

    def _load_faiss(self, embeddings):
        """Load the persisted FAISS index from persist_directory."""
        try:
            # Recent LangChain releases refuse to unpickle the FAISS docstore
            # unless this flag is set. The index is the one written by
            # _create_vector_store(); NOTE(security): only point persist_directory
            # at locations you trust, since loading it unpickles data.
            return FAISS.load_local(
                self.persist_directory, embeddings, allow_dangerous_deserialization=True
            )
        except TypeError:
            # Older releases do not accept the keyword.
            return FAISS.load_local(self.persist_directory, embeddings)

    def _load_vector_store(self):
        """Load an existing vector store; rebuild it from documents if loading fails."""
        embeddings = self._get_embeddings()

        try:
            if self.vector_store_type.lower() == "faiss":
                self.vector_store = self._load_faiss(embeddings)
                print(f"Loaded FAISS vector store from {self.persist_directory}")
            else:  # Default to Chroma
                self.vector_store = Chroma(
                    persist_directory=self.persist_directory,
                    embedding_function=embeddings
                )
                print(f"Loaded Chroma vector store from {self.persist_directory}")
        except Exception as e:
            print(f"Error loading vector store: {e}")
            print("Creating a new vector store...")
            documents = self._load_documents()
            if documents:
                self._create_vector_store(documents)
            else:
                print("No documents available. Cannot create vector store.")
                self.vector_store = None

    # -------------------------------------------------------------- retrieval

    @staticmethod
    def _format_source(metadata):
        """Render the citation suffix for a result from its metadata ('' when nothing to cite)."""
        if not metadata:
            return ""
        parts = []
        if "source" in metadata:
            parts.append(f"Source: {metadata['source']}")
        if "page" in metadata:
            parts.append(f"Page: {metadata['page']}")
        return ("\n" + ", ".join(parts)) if parts else ""

    def forward(self, query: str, top_k: int = 3) -> str:
        """
        Retrieve relevant documents based on the query.

        Args:
            query: The search query
            top_k: Number of results to return; None means the default of 3

        Returns:
            String with formatted search results, or a human-readable message
            if the store is unavailable, nothing matched, or the search failed.
        """
        if getattr(self, "vector_store", None) is None:
            return "Vector store is not initialized. Please check your configuration."

        # The input is declared nullable, so the agent may pass None explicitly.
        if top_k is None:
            top_k = 3

        try:
            results = self.vector_store.similarity_search(query, k=top_k)
            if not results:
                return "No relevant information found for the query."

            formatted_results = [
                f"Document {position}:\n{doc.page_content}{self._format_source(doc.metadata)}\n"
                for position, doc in enumerate(results, start=1)
            ]
            return "Retrieved relevant information:\n\n" + "\n".join(formatted_results)
        except Exception as e:
            # The tool contract is to always return a string to the agent.
            return f"Error retrieving information: {str(e)}"

# Example usage:
# rag_tool = RAGTool(
#     documents_path="./my_docs",
#     embedding_model="sentence-transformers/all-MiniLM-L6-v2",
#     vector_store_type="faiss",
#     chunk_size=1000,
#     chunk_overlap=200
# )