File size: 11,312 Bytes
12c661c
e96852d
 
103a876
 
 
 
e96852d
 
8ce796d
e96852d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103a876
e96852d
 
 
 
103a876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e96852d
 
 
 
 
 
 
 
103a876
e96852d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103a876
 
e96852d
 
fbe6a2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e96852d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12c661c
e96852d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbe6a2e
e96852d
 
 
 
 
fbe6a2e
e96852d
 
 
 
fbe6a2e
 
 
e96852d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import os
from typing import Dict, List, Optional, Union, Any
from smolagents import Tool
from langchain_community.vectorstores import FAISS, Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader, TextLoader, DirectoryLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
import json

class RAGTool(Tool):
    """smolagents Tool wrapping a local RAG pipeline.

    Loads documents (PDF / text / directory), chunks them, embeds them with a
    HuggingFace model, stores them in a FAISS or Chroma vector store, and
    answers queries via similarity search.  Configuration happens through
    ``configure(...)`` (fluent, returns self) rather than constructor args.
    """
    name = "rag_retriever"
    description = """
    Advanced RAG (Retrieval-Augmented Generation) tool that searches in vector stores based on given prompts.
    This tool allows you to query documents stored in vector databases using semantic similarity.
    It supports various configurations including different embedding models, vector stores, and document types.
    """
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to retrieve relevant information from the document store",
        },
        "top_k": {
            "type": "integer",
            "description": "Number of most relevant documents to retrieve (default: 3)",
            "nullable": True
        }
    }
    output_type = "string"

    def __init__(self):
        """
        Initialize the RAG Tool with default settings.
        All configuration is done via class attributes or through the configure method.
        """
        super().__init__()
        # Defaults; all overridable via configure().
        self.documents_path = "./documents"
        self.embedding_model = "BAAI/bge-small-en-v1.5"
        self.vector_store_type = "faiss"
        self.chunk_size = 1000
        self.chunk_overlap = 200
        self.persist_directory = "./vector_store"
        self.device = "cpu"

        # Don't automatically create storage initially, wait for explicit setup
        # via configure(); forward() reports an error until then.
        self.vector_store = None

    def configure(self,
                 documents_path: str = "./documents",
                 embedding_model: str = "BAAI/bge-small-en-v1.5",
                 vector_store_type: str = "faiss",
                 chunk_size: int = 1000,
                 chunk_overlap: int = 200,
                 persist_directory: str = "./vector_store",
                 device: str = "cpu"):
        """
        Configure the RAG Tool with custom parameters and build the vector store.

        Args:
            documents_path: Path to a single document or a folder of documents
            embedding_model: HuggingFace model ID for embeddings
            vector_store_type: Type of vector store ('faiss' or 'chroma')
            chunk_size: Size of text chunks for splitting documents
            chunk_overlap: Overlap between text chunks
            persist_directory: Directory to persist vector store
            device: Device to run embedding model on ('cpu' or 'cuda')

        Returns:
            self, so calls can be chained (fluent style).
        """
        self.documents_path = documents_path
        self.embedding_model = embedding_model
        self.vector_store_type = vector_store_type
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.persist_directory = persist_directory
        self.device = device

        # Create the vector store if it doesn't exist
        os.makedirs(persist_directory, exist_ok=True)
        self._setup_vector_store()

        return self

    def _setup_vector_store(self):
        """Build the vector store from the configured documents path.

        Silently returns (leaving ``self.vector_store`` as-is) when the
        documents path is missing or yields no documents.
        """
        # Always try to create directories if they don't exist
        os.makedirs(self.persist_directory, exist_ok=True)

        # Check if documents path exists
        if not os.path.exists(self.documents_path):
            print(f"Warning: Documents path {self.documents_path} does not exist.")
            return

        # Force creation of vector store from documents
        documents = self._load_documents()
        if not documents:
            print("No documents loaded. Vector store not created.")
            return

        # Create the vector store
        self._create_vector_store(documents)

    def _get_embeddings(self):
        """Instantiate the embedding model chosen in configuration.

        BGE-family models get normalized embeddings (recommended for cosine
        similarity); anything else uses the generic HuggingFace wrapper.
        Falls back to all-MiniLM-L6-v2 if the configured model fails to load.
        """
        try:
            if "bge" in self.embedding_model.lower():
                encode_kwargs = {"normalize_embeddings": True}
                return HuggingFaceBgeEmbeddings(
                    model_name=self.embedding_model,
                    encode_kwargs=encode_kwargs,
                    model_kwargs={"device": self.device}
                )
            else:
                return HuggingFaceEmbeddings(
                    model_name=self.embedding_model,
                    model_kwargs={"device": self.device}
                )
        except Exception as e:
            print(f"Error loading embedding model: {e}")
            # Fallback to a reliable model
            print("Falling back to sentence-transformers/all-MiniLM-L6-v2")
            return HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={"device": self.device}
            )

    def _load_documents(self):
        """Load and chunk documents from ``self.documents_path``.

        Supports a single PDF (with a PyPDF2 fallback when PyPDFLoader fails),
        a single text-like file (.txt/.md/.html), or a directory walked with
        TextLoader.  Returns a list of chunked Documents, or [] on failure or
        for unsupported single-file extensions.
        """
        documents = []

        # Check if documents_path is a file or directory
        if os.path.isfile(self.documents_path):
            # Load single file
            if self.documents_path.lower().endswith('.pdf'):
                try:
                    loader = PyPDFLoader(self.documents_path)
                    documents = loader.load()
                except Exception as e:
                    print(f"Error loading PDF: {e}")
                    # Fallback to using PdfReader
                    try:
                        text = self._extract_text_from_pdf(self.documents_path)
                        splitter = CharacterTextSplitter(
                            separator="\n",
                            chunk_size=self.chunk_size,
                            chunk_overlap=self.chunk_overlap
                        )
                        documents = splitter.create_documents([text])
                    except Exception as e2:
                        print(f"Error with fallback PDF extraction: {e2}")
            elif self.documents_path.lower().endswith(('.txt', '.md', '.html')):
                loader = TextLoader(self.documents_path)
                documents = loader.load()
        elif os.path.isdir(self.documents_path):
            # Load all supported files in directory
            # NOTE(review): TextLoader is used for every file matched by
            # "**/*.*", so binary files (e.g. PDFs) in the directory will
            # raise and abort the whole directory load — caught below.
            try:
                loader = DirectoryLoader(
                    self.documents_path,
                    glob="**/*.*",
                    loader_cls=TextLoader,
                    loader_kwargs={"autodetect_encoding": True}
                )
                documents = loader.load()
            except Exception as e:
                print(f"Error loading directory: {e}")

        # Split documents into chunks if they exist
        if documents:
            splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
            return splitter.split_documents(documents)
        return []

    def _extract_text_from_pdf(self, pdf_path):
        """Extract text from a PDF using PyPDF2.

        Bug fix: ``page.extract_text()`` may return None for pages with no
        extractable text (e.g. scanned images); coerce to "" so concatenation
        never raises TypeError.  Uses join instead of repeated += .
        """
        pdf_reader = PdfReader(pdf_path)
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)

    def _create_vector_store(self, documents):
        """Create a new vector store from chunked documents and persist it."""
        embeddings = self._get_embeddings()

        if self.vector_store_type.lower() == "faiss":
            vector_store = FAISS.from_documents(documents, embeddings)
            vector_store.save_local(self.persist_directory)
            print(f"Created FAISS vector store at {self.persist_directory}")
        else:  # Default to Chroma
            vector_store = Chroma.from_documents(
                documents,
                embeddings,
                persist_directory=self.persist_directory
            )
            vector_store.persist()
            print(f"Created Chroma vector store at {self.persist_directory}")

        self.vector_store = vector_store

    def _load_vector_store(self):
        """Load a previously persisted vector store.

        On failure (missing/corrupt store), falls back to rebuilding from the
        documents path; sets ``self.vector_store`` to None when nothing can be
        loaded or built.
        """
        embeddings = self._get_embeddings()

        try:
            if self.vector_store_type.lower() == "faiss":
                # NOTE(review): newer langchain versions require
                # allow_dangerous_deserialization=True here — confirm against
                # the pinned langchain_community version before upgrading.
                self.vector_store = FAISS.load_local(self.persist_directory, embeddings)
                print(f"Loaded FAISS vector store from {self.persist_directory}")
            else:  # Default to Chroma
                self.vector_store = Chroma(
                    persist_directory=self.persist_directory,
                    embedding_function=embeddings
                )
                print(f"Loaded Chroma vector store from {self.persist_directory}")
        except Exception as e:
            print(f"Error loading vector store: {e}")
            print("Creating a new vector store...")
            documents = self._load_documents()
            if documents:
                self._create_vector_store(documents)
            else:
                print("No documents available. Cannot create vector store.")
                self.vector_store = None

    def forward(self, query: str, top_k: Optional[int] = None) -> str:
        """
        Retrieve relevant documents based on the query.

        Args:
            query: The search query
            top_k: Number of results to return (default: 3)

        Returns:
            String with formatted search results, or an error message.
        """
        # Set default value if None (smolagents passes None for nullable inputs)
        if top_k is None:
            top_k = 3
        # vector_store is always set in __init__, so a plain None check suffices.
        if self.vector_store is None:
            return "Vector store is not initialized. Please check your configuration."

        try:
            # Perform similarity search
            results = self.vector_store.similarity_search(query, k=top_k)

            # Format results
            formatted_results = []
            for i, doc in enumerate(results):
                content = doc.page_content
                metadata = doc.metadata

                # Format metadata nicely
                meta_str = ""
                if metadata:
                    meta_str = "\nSource: "
                    if "source" in metadata:
                        meta_str += metadata["source"]
                    if "page" in metadata:
                        meta_str += f", Page: {metadata['page']}"

                formatted_results.append(f"Document {i+1}:\n{content}{meta_str}\n")

            if formatted_results:
                return "Retrieved relevant information:\n\n" + "\n".join(formatted_results)
            else:
                return "No relevant information found for the query."
        except Exception as e:
            return f"Error retrieving information: {str(e)}"

# Example usage (the constructor takes no arguments; use configure()):
# rag_tool = RAGTool().configure(
#     documents_path="./my_docs",
#     embedding_model="sentence-transformers/all-MiniLM-L6-v2",
#     vector_store_type="faiss",
#     chunk_size=1000,
#     chunk_overlap=200
# )