Update vector_store_retriever.py

vector_store_retriever.py  (+254 -191)
CHANGED

@@ -1,195 +1,258 @@
-import json
 import os
-import
-import
-import
-
-from
-from
-from
-
-#from langchain.Images import Images
-
-from langchain.llms.base import LLM
-#from langchain_core.embeddings import EmbeddingFunction, Embeddings
-
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-#from langchain import [all]
-#from langchain.Documents import Documents
-from langchain.vectorstores import Chroma
-from dotenv import load_dotenv
-from transformers import AutoTokenizer, AutoModel, Tool
-
-load_dotenv()
-
-path_work = "."
-hf_token = os.getenv("HF")
-
-class HuggingFaceInstructEmbeddings(HuggingFaceInstructEmbeddings):
-    def __init__(self, model_name: str, model_kwargs: Optional[Dict[str, Any]] = None):
-        self.model = AutoModel.from_pretrained(model_name, **(model_kwargs or {}))
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

-
-
-
-
        else:
[old lines 38-93: removed content not shown in this view]
-        return {"model_name": self.model_name}
-
-kwargs = {"max_new_tokens": 256, "temperature": 0.9, "top_p": 0.6, "repetition_penalty": 1.3, "do_sample": True}
-
-model_list = [
-    "meta-llama/Llama-2-13b-chat-hf",
-    "HuggingFaceH4/zephyr-7b-alpha",
-    "meta-llama/Llama-2-70b-chat-hf",
-    "tiiuae/falcon-180B-chat"
-]
-
-qa_chain = None
-
-def load_model(model_selected):
-    global qa_chain
-    model_name = model_selected
-    llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
-
-    from langchain.chains import RetrievalQA
-    qa_chain = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True,
-        verbose=True,
-    )
-    return qa_chain
-
-load_model("meta-llama/Llama-2-70b-chat-hf")
-
-##########
-#####
-#########
-
-
-###
-###
-###
-
-def predict(message, temperature=0.9, max_new_tokens=512, top_p=0.6, repetition_penalty=1.3):
-    temperature = float(temperature)
-    if temperature < 1e-2: temperature = 1e-2
-    top_p = float(top_p)
-
-    llm_response = qa_chain(message)
-    res_result = llm_response['result']
-
-    res_relevant_doc = [source.metadata['source'] for source in llm_response["source_documents"]]
-    response = f"{res_result}" + "\n\n" + "[Answer Source Documents (Ctrl + Click!)] :" + "\n" + f" \n {res_relevant_doc}"
-    print("response: =====> \n", response, "\n\n")
-    tokens = response.split('\n')
-    token_list = []
-    for idx, token in enumerate(tokens):
-        token_dict = {"id": idx + 1, "text": token}
-        token_list.append(token_dict)
-    response = {"data": {"token": token_list}}
-    response = json.dumps(response, indent=4)
-
-    response = json.loads(response)
-    data_dict = response.get('data', {})
-    token_list = data_dict.get('token', [])
-
-    partial_message = ""
-    for token_entry in token_list:
-        if token_entry:
            try:
[old lines 160-195: removed content not shown in this view]
 import os
+from typing import Dict, List, Optional, Union, Any
+from smolagents import Tool
+from langchain.vectorstores import FAISS, Chroma
+from langchain.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings
+from langchain.document_loaders import PyPDFLoader, TextLoader, DirectoryLoader
+from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
+from PyPDF2 import PdfReader
+import json

+class RAGTool(Tool):
+    name = "rag_retriever"
+    description = """
+    Advanced RAG (Retrieval-Augmented Generation) tool that searches in vector stores based on given prompts.
+    This tool allows you to query documents stored in vector databases using semantic similarity.
+    It supports various configurations including different embedding models, vector stores, and document types.
+    """
+    inputs = {
+        "query": {
+            "type": "string",
+            "description": "The search query to retrieve relevant information from the document store",
+        },
+        "top_k": {
+            "type": "integer",
+            "description": "Number of most relevant documents to retrieve (default: 3)",
+        }
+    }
+    output_type = "string"
+
+    def __init__(self,
+                 documents_path: str = "./documents",
+                 embedding_model: str = "BAAI/bge-small-en-v1.5",
+                 vector_store_type: str = "faiss",
+                 chunk_size: int = 1000,
+                 chunk_overlap: int = 200,
+                 persist_directory: str = "./vector_store",
+                 device: str = "cpu"):
+        """
+        Initialize the RAG Tool with configurable parameters.
+
+        Args:
+            documents_path: Path to documents or folder containing documents
+            embedding_model: HuggingFace model ID for embeddings
+            vector_store_type: Type of vector store ('faiss' or 'chroma')
+            chunk_size: Size of text chunks for splitting documents
+            chunk_overlap: Overlap between text chunks
+            persist_directory: Directory to persist vector store
+            device: Device to run embedding model on ('cpu' or 'cuda')
+        """
+        super().__init__()
+        self.documents_path = documents_path
+        self.embedding_model = embedding_model
+        self.vector_store_type = vector_store_type
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.persist_directory = persist_directory
+        self.device = device
+
+        # Create the vector store if it doesn't exist
+        os.makedirs(persist_directory, exist_ok=True)
+        self._setup_vector_store()
+
+    def _setup_vector_store(self):
+        """Set up the vector store with documents if it doesn't exist"""
+        # Check if we need to create a new vector store
+        if not os.path.exists(os.path.join(self.persist_directory, "index.faiss")) and \
+           not os.path.exists(os.path.join(self.persist_directory, "chroma")):
+            # Check if documents path exists
+            if not os.path.exists(self.documents_path):
+                print(f"Warning: Documents path {self.documents_path} does not exist.")
+                return
+
+            # Load and process documents
+            documents = self._load_documents()
+            if not documents:
+                print("No documents loaded. Vector store not created.")
+                return
+
+            # Create the vector store
+            self._create_vector_store(documents)
        else:
+            print(f"Vector store already exists at {self.persist_directory}")
+            self._load_vector_store()
+
+    def _get_embeddings(self):
+        """Get embedding model based on configuration"""
+        try:
+            if "bge" in self.embedding_model.lower():
+                encode_kwargs = {"normalize_embeddings": True}
+                return HuggingFaceBgeEmbeddings(
+                    model_name=self.embedding_model,
+                    encode_kwargs=encode_kwargs,
+                    model_kwargs={"device": self.device}
+                )
+            else:
+                return HuggingFaceEmbeddings(
+                    model_name=self.embedding_model,
+                    model_kwargs={"device": self.device}
+                )
+        except Exception as e:
+            print(f"Error loading embedding model: {e}")
+            # Fallback to a reliable model
+            print("Falling back to sentence-transformers/all-MiniLM-L6-v2")
+            return HuggingFaceEmbeddings(
+                model_name="sentence-transformers/all-MiniLM-L6-v2",
+                model_kwargs={"device": self.device}
+            )
+
+    def _load_documents(self):
+        """Load documents from the documents path"""
+        documents = []
+
+        # Check if documents_path is a file or directory
+        if os.path.isfile(self.documents_path):
+            # Load single file
+            if self.documents_path.lower().endswith('.pdf'):
+                try:
+                    loader = PyPDFLoader(self.documents_path)
+                    documents = loader.load()
+                except Exception as e:
+                    print(f"Error loading PDF: {e}")
+                    # Fallback to using PdfReader
+                    try:
+                        text = self._extract_text_from_pdf(self.documents_path)
+                        splitter = CharacterTextSplitter(
+                            separator="\n",
+                            chunk_size=self.chunk_size,
+                            chunk_overlap=self.chunk_overlap
+                        )
+                        documents = splitter.create_documents([text])
+                    except Exception as e2:
+                        print(f"Error with fallback PDF extraction: {e2}")
+            elif self.documents_path.lower().endswith(('.txt', '.md', '.html')):
+                loader = TextLoader(self.documents_path)
+                documents = loader.load()
+        elif os.path.isdir(self.documents_path):
+            # Load all supported files in directory
            try:
+                loader = DirectoryLoader(
+                    self.documents_path,
+                    glob="**/*.*",
+                    loader_cls=TextLoader,
+                    loader_kwargs={"autodetect_encoding": True}
+                )
+                documents = loader.load()
+            except Exception as e:
+                print(f"Error loading directory: {e}")
+
+        # Split documents into chunks if they exist
+        if documents:
+            splitter = RecursiveCharacterTextSplitter(
+                chunk_size=self.chunk_size,
+                chunk_overlap=self.chunk_overlap
+            )
+            return splitter.split_documents(documents)
+        return []
+
+    def _extract_text_from_pdf(self, pdf_path):
+        """Extract text from PDF using PyPDF2"""
+        text = ""
+        pdf_reader = PdfReader(pdf_path)
+        for page in pdf_reader.pages:
+            text += page.extract_text()
+        return text
+
+    def _create_vector_store(self, documents):
+        """Create a new vector store from documents"""
+        embeddings = self._get_embeddings()
+
+        if self.vector_store_type.lower() == "faiss":
+            vector_store = FAISS.from_documents(documents, embeddings)
+            vector_store.save_local(self.persist_directory)
+            print(f"Created FAISS vector store at {self.persist_directory}")
+        else: # Default to Chroma
+            vector_store = Chroma.from_documents(
+                documents,
+                embeddings,
+                persist_directory=self.persist_directory
+            )
+            vector_store.persist()
+            print(f"Created Chroma vector store at {self.persist_directory}")
+
+        self.vector_store = vector_store
+
+    def _load_vector_store(self):
+        """Load an existing vector store"""
+        embeddings = self._get_embeddings()
+
+        try:
+            if self.vector_store_type.lower() == "faiss":
+                self.vector_store = FAISS.load_local(self.persist_directory, embeddings)
+                print(f"Loaded FAISS vector store from {self.persist_directory}")
+            else: # Default to Chroma
+                self.vector_store = Chroma(
+                    persist_directory=self.persist_directory,
+                    embedding_function=embeddings
+                )
+                print(f"Loaded Chroma vector store from {self.persist_directory}")
+        except Exception as e:
+            print(f"Error loading vector store: {e}")
+            print("Creating a new vector store...")
+            documents = self._load_documents()
+            if documents:
+                self._create_vector_store(documents)
+            else:
+                print("No documents available. Cannot create vector store.")
+                self.vector_store = None
+
+    def forward(self, query: str, top_k: int = 3) -> str:
+        """
+        Retrieve relevant documents based on the query.
+
+        Args:
+            query: The search query
+            top_k: Number of results to return
+
+        Returns:
+            String with formatted search results
+        """
+        if not hasattr(self, 'vector_store') or self.vector_store is None:
+            return "Vector store is not initialized. Please check your configuration."
+
+        try:
+            # Perform similarity search
+            results = self.vector_store.similarity_search(query, k=top_k)
+
+            # Format results
+            formatted_results = []
+            for i, doc in enumerate(results):
+                content = doc.page_content
+                metadata = doc.metadata
+
+                # Format metadata nicely
+                meta_str = ""
+                if metadata:
+                    meta_str = "\nSource: "
+                    if "source" in metadata:
+                        meta_str += metadata["source"]
+                    if "page" in metadata:
+                        meta_str += f", Page: {metadata['page']}"
+
+                formatted_results.append(f"Document {i+1}:\n{content}{meta_str}\n")
+
+            if formatted_results:
+                return "Retrieved relevant information:\n\n" + "\n".join(formatted_results)
+            else:
+                return "No relevant information found for the query."
+        except Exception as e:
+            return f"Error retrieving information: {str(e)}"
+
+# Example usage:
+# rag_tool = RAGTool(
+#     documents_path="./my_docs",
+#     embedding_model="sentence-transformers/all-MiniLM-L6-v2",
+#     vector_store_type="faiss",
+#     chunk_size=1000,
+#     chunk_overlap=200
+# )
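A minimal usage sketch for the new tool, assuming the file is importable as a module named vector_store_retriever and that a local ./documents folder with a few PDF or text files exists. The direct .forward() call relies only on the method defined in the diff above; the commented agent wiring at the end assumes the CodeAgent and HfApiModel names exposed by the installed smolagents version and is illustrative only.

# Sketch: build the retriever tool and query it directly.
from vector_store_retriever import RAGTool  # assumed module name for the file above

rag_tool = RAGTool(
    documents_path="./documents",              # hypothetical folder of PDFs / text files
    embedding_model="BAAI/bge-small-en-v1.5",
    vector_store_type="faiss",
    persist_directory="./vector_store",
)

# forward() is defined in the class above and returns a formatted string of hits.
print(rag_tool.forward("What are the key findings in the indexed documents?", top_k=3))

# Optional: hand the tool to a smolagents agent (class names assumed for illustration).
# from smolagents import CodeAgent, HfApiModel
# agent = CodeAgent(tools=[rag_tool], model=HfApiModel())
# print(agent.run("Use rag_retriever to summarize what the documents say about evaluation."))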