Chris4K committed · verified
Commit e96852d · 1 Parent(s): 27b47d4

Update vector_store_retriever.py

Files changed (1):
  1. vector_store_retriever.py  +254 -191
vector_store_retriever.py CHANGED
@@ -1,195 +1,258 @@
-import json
 import os
-import gradio as gr
-import time
-import langchain
-
-from pydantic import BaseModel, Field
-from typing import Any, Optional, Dict, List, Union
-from huggingface_hub import InferenceClient
-from langchain.llms.base import LLM
-#from langchain.Images import Images
-
-from langchain.llms.base import LLM
-#from langchain_core.embeddings import EmbeddingFunction, Embeddings
-
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-#from langchain import [all]
-#from langchain.Documents import Documents
-from langchain.vectorstores import Chroma
-from dotenv import load_dotenv
-from transformers import AutoTokenizer, AutoModel, Tool
-
-load_dotenv()
-
-path_work = "."
-hf_token = os.getenv("HF")
-
-class HuggingFaceInstructEmbeddings(HuggingFaceInstructEmbeddings):
-    def __init__(self, model_name: str, model_kwargs: Optional[Dict[str, Any]] = None):
-        self.model = AutoModel.from_pretrained(model_name, **(model_kwargs or {}))
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-    def __call__(self, input: Union[Documents]) -> HuggingFaceInstructEmbeddings:
-        if isinstance(input, Documents):
-            texts = [doc.text for doc in input]
-            embeddings = self._embed_text(texts)
         else:
-            # Handle image embeddings if needed
-            pass
-
-        return embeddings
-
-    def _embed_text(self, texts: List[str]) -> Embeddings:
-        # Your existing logic for text embeddings using Hugging Face models...
-        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-        embeddings = outputs.last_hidden_state.mean(dim=1) # Adjust this based on your specific model
-
-        return embeddings
-
-
-vectordb = Chroma(
-    persist_directory=path_work + '/new_papers',
-    embedding_function=HuggingFaceInstructEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
-)
-
-retriever = vectordb.as_retriever(search_kwargs={"k": 2})#5
-
-
-class KwArgsModel(BaseModel):
-    kwargs: Dict[str, Any] = Field(default_factory=dict)
-
-class CustomInferenceClient(LLM, KwArgsModel):
-    model_name: str
-    inference_client: InferenceClient
-
-    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
-        inference_client = InferenceClient(model=model_name, token=hf_token)
-        super().__init__(
-            model_name=model_name,
-            hf_token=hf_token,
-            kwargs=kwargs,
-            inference_client=inference_client
-        )
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None
-    ) -> str:
-        if stop is not None:
-            raise ValueError("stop kwargs are not permitted.")
-        response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
-        response = ''.join(response_gen)
-        return response
-
-    @property
-    def _llm_type(self) -> str:
-        return "custom"
-
-    @property
-    def _identifying_params(self) -> dict:
-        return {"model_name": self.model_name}
-
-kwargs = {"max_new_tokens": 256, "temperature": 0.9, "top_p": 0.6, "repetition_penalty": 1.3, "do_sample": True}
-
-model_list = [
-    "meta-llama/Llama-2-13b-chat-hf",
-    "HuggingFaceH4/zephyr-7b-alpha",
-    "meta-llama/Llama-2-70b-chat-hf",
-    "tiiuae/falcon-180B-chat"
-]
-
-qa_chain = None
-
-def load_model(model_selected):
-    global qa_chain
-    model_name = model_selected
-    llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
-
-    from langchain.chains import RetrievalQA
-    qa_chain = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True,
-        verbose=True,
-    )
-    return qa_chain
-
-load_model("meta-llama/Llama-2-70b-chat-hf")
-
-##########
-#####
-#########
-
-
-###
-###
-###
-
-def predict(message, temperature=0.9, max_new_tokens=512, top_p=0.6, repetition_penalty=1.3):
-    temperature = float(temperature)
-    if temperature < 1e-2: temperature = 1e-2
-    top_p = float(top_p)
-
-    llm_response = qa_chain(message)
-    res_result = llm_response['result']
-
-    res_relevant_doc = [source.metadata['source'] for source in llm_response["source_documents"]]
-    response = f"{res_result}" + "\n\n" + "[Answer Source Documents (Ctrl + Click!)] :" + "\n" + f" \n {res_relevant_doc}"
-    print("response: =====> \n", response, "\n\n")
-    tokens = response.split('\n')
-    token_list = []
-    for idx, token in enumerate(tokens):
-        token_dict = {"id": idx + 1, "text": token}
-        token_list.append(token_dict)
-    response = {"data": {"token": token_list}}
-    response = json.dumps(response, indent=4)
-
-    response = json.loads(response)
-    data_dict = response.get('data', {})
-    token_list = data_dict.get('token', [])
-
-    partial_message = ""
-    for token_entry in token_list:
-        if token_entry:
             try:
-                # Handle missing 'id' key gracefully
-                token_id = token_entry.get('id', None)
-                token_text = token_entry.get('text', None)
-
-                if token_text:
-                    for char in token_text:
-                        partial_message += char
-                        yield partial_message
-                        time.sleep(0.01)
-                else:
-                    print(f"Warning ==> The key 'text' does not exist or is None in this token entry: {token_entry}")
-                    pass
-
-            except KeyError as e:
-                print(f"KeyError: {e} occurred for token entry: {token_entry}")
-                continue
-
-
-class TextGeneratorTool(Tool):
-    name = "vector_retriever"
-    description = "This tool searches in a vector store based on a given prompt."
-    inputs = ["prompt"]
-    outputs = ["text"]
-
-
-    def __init__(self):
-        #self.retriever = db.as_retriever(search_kwargs={"k": 1})
-        pass # You might want to add some initialization logic here
-
-    def __call__(self, prompt: str):
-        result = predict(prompt, 0.9, 512, 0.6, 1.4)
-        return result
-
-
-
-
 import os
+from typing import Dict, List, Optional, Union, Any
+from smolagents import Tool
+from langchain.vectorstores import FAISS, Chroma
+from langchain.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings
+from langchain.document_loaders import PyPDFLoader, TextLoader, DirectoryLoader
+from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
+from PyPDF2 import PdfReader
+import json
 
+class RAGTool(Tool):
+    name = "rag_retriever"
+    description = """
+    Advanced RAG (Retrieval-Augmented Generation) tool that searches in vector stores based on given prompts.
+    This tool allows you to query documents stored in vector databases using semantic similarity.
+    It supports various configurations including different embedding models, vector stores, and document types.
+    """
+    inputs = {
+        "query": {
+            "type": "string",
+            "description": "The search query to retrieve relevant information from the document store",
+        },
+        "top_k": {
+            "type": "integer",
+            "description": "Number of most relevant documents to retrieve (default: 3)",
+        }
+    }
+    output_type = "string"
+
+    def __init__(self,
+                 documents_path: str = "./documents",
+                 embedding_model: str = "BAAI/bge-small-en-v1.5",
+                 vector_store_type: str = "faiss",
+                 chunk_size: int = 1000,
+                 chunk_overlap: int = 200,
+                 persist_directory: str = "./vector_store",
+                 device: str = "cpu"):
+        """
+        Initialize the RAG Tool with configurable parameters.
+
+        Args:
+            documents_path: Path to documents or folder containing documents
+            embedding_model: HuggingFace model ID for embeddings
+            vector_store_type: Type of vector store ('faiss' or 'chroma')
+            chunk_size: Size of text chunks for splitting documents
+            chunk_overlap: Overlap between text chunks
+            persist_directory: Directory to persist vector store
+            device: Device to run embedding model on ('cpu' or 'cuda')
+        """
+        super().__init__()
+        self.documents_path = documents_path
+        self.embedding_model = embedding_model
+        self.vector_store_type = vector_store_type
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.persist_directory = persist_directory
+        self.device = device
+
+        # Create the vector store if it doesn't exist
+        os.makedirs(persist_directory, exist_ok=True)
+        self._setup_vector_store()
+
+    def _setup_vector_store(self):
+        """Set up the vector store with documents if it doesn't exist"""
+        # Check if we need to create a new vector store
+        if not os.path.exists(os.path.join(self.persist_directory, "index.faiss")) and \
+           not os.path.exists(os.path.join(self.persist_directory, "chroma")):
+            # Check if documents path exists
+            if not os.path.exists(self.documents_path):
+                print(f"Warning: Documents path {self.documents_path} does not exist.")
+                return
+
+            # Load and process documents
+            documents = self._load_documents()
+            if not documents:
+                print("No documents loaded. Vector store not created.")
+                return
+
+            # Create the vector store
+            self._create_vector_store(documents)
         else:
+            print(f"Vector store already exists at {self.persist_directory}")
+            self._load_vector_store()
+
+    def _get_embeddings(self):
+        """Get embedding model based on configuration"""
+        try:
+            if "bge" in self.embedding_model.lower():
+                encode_kwargs = {"normalize_embeddings": True}
+                return HuggingFaceBgeEmbeddings(
+                    model_name=self.embedding_model,
+                    encode_kwargs=encode_kwargs,
+                    model_kwargs={"device": self.device}
+                )
+            else:
+                return HuggingFaceEmbeddings(
+                    model_name=self.embedding_model,
+                    model_kwargs={"device": self.device}
+                )
+        except Exception as e:
+            print(f"Error loading embedding model: {e}")
+            # Fallback to a reliable model
+            print("Falling back to sentence-transformers/all-MiniLM-L6-v2")
+            return HuggingFaceEmbeddings(
+                model_name="sentence-transformers/all-MiniLM-L6-v2",
+                model_kwargs={"device": self.device}
+            )
+
+    def _load_documents(self):
+        """Load documents from the documents path"""
+        documents = []
+
+        # Check if documents_path is a file or directory
+        if os.path.isfile(self.documents_path):
+            # Load single file
+            if self.documents_path.lower().endswith('.pdf'):
+                try:
+                    loader = PyPDFLoader(self.documents_path)
+                    documents = loader.load()
+                except Exception as e:
+                    print(f"Error loading PDF: {e}")
+                    # Fallback to using PdfReader
+                    try:
+                        text = self._extract_text_from_pdf(self.documents_path)
+                        splitter = CharacterTextSplitter(
+                            separator="\n",
+                            chunk_size=self.chunk_size,
+                            chunk_overlap=self.chunk_overlap
+                        )
+                        documents = splitter.create_documents([text])
+                    except Exception as e2:
+                        print(f"Error with fallback PDF extraction: {e2}")
+            elif self.documents_path.lower().endswith(('.txt', '.md', '.html')):
+                loader = TextLoader(self.documents_path)
+                documents = loader.load()
+        elif os.path.isdir(self.documents_path):
+            # Load all supported files in directory
             try:
+                loader = DirectoryLoader(
+                    self.documents_path,
+                    glob="**/*.*",
+                    loader_cls=TextLoader,
+                    loader_kwargs={"autodetect_encoding": True}
+                )
+                documents = loader.load()
+            except Exception as e:
+                print(f"Error loading directory: {e}")
+
+        # Split documents into chunks if they exist
+        if documents:
+            splitter = RecursiveCharacterTextSplitter(
+                chunk_size=self.chunk_size,
+                chunk_overlap=self.chunk_overlap
+            )
+            return splitter.split_documents(documents)
+        return []
+
+    def _extract_text_from_pdf(self, pdf_path):
+        """Extract text from PDF using PyPDF2"""
+        text = ""
+        pdf_reader = PdfReader(pdf_path)
+        for page in pdf_reader.pages:
+            text += page.extract_text()
+        return text
+
+    def _create_vector_store(self, documents):
+        """Create a new vector store from documents"""
+        embeddings = self._get_embeddings()
+
+        if self.vector_store_type.lower() == "faiss":
+            vector_store = FAISS.from_documents(documents, embeddings)
+            vector_store.save_local(self.persist_directory)
+            print(f"Created FAISS vector store at {self.persist_directory}")
+        else: # Default to Chroma
+            vector_store = Chroma.from_documents(
+                documents,
+                embeddings,
+                persist_directory=self.persist_directory
+            )
+            vector_store.persist()
+            print(f"Created Chroma vector store at {self.persist_directory}")
+
+        self.vector_store = vector_store
+
+    def _load_vector_store(self):
+        """Load an existing vector store"""
+        embeddings = self._get_embeddings()
+
+        try:
+            if self.vector_store_type.lower() == "faiss":
+                self.vector_store = FAISS.load_local(self.persist_directory, embeddings)
+                print(f"Loaded FAISS vector store from {self.persist_directory}")
+            else: # Default to Chroma
+                self.vector_store = Chroma(
+                    persist_directory=self.persist_directory,
+                    embedding_function=embeddings
+                )
+                print(f"Loaded Chroma vector store from {self.persist_directory}")
+        except Exception as e:
+            print(f"Error loading vector store: {e}")
+            print("Creating a new vector store...")
+            documents = self._load_documents()
+            if documents:
+                self._create_vector_store(documents)
+            else:
+                print("No documents available. Cannot create vector store.")
+                self.vector_store = None
+
+    def forward(self, query: str, top_k: int = 3) -> str:
+        """
+        Retrieve relevant documents based on the query.
+
+        Args:
+            query: The search query
+            top_k: Number of results to return
+
+        Returns:
+            String with formatted search results
+        """
+        if not hasattr(self, 'vector_store') or self.vector_store is None:
+            return "Vector store is not initialized. Please check your configuration."
+
+        try:
+            # Perform similarity search
+            results = self.vector_store.similarity_search(query, k=top_k)
+
+            # Format results
+            formatted_results = []
+            for i, doc in enumerate(results):
+                content = doc.page_content
+                metadata = doc.metadata
+
+                # Format metadata nicely
+                meta_str = ""
+                if metadata:
+                    meta_str = "\nSource: "
+                    if "source" in metadata:
+                        meta_str += metadata["source"]
+                    if "page" in metadata:
+                        meta_str += f", Page: {metadata['page']}"
+
+                formatted_results.append(f"Document {i+1}:\n{content}{meta_str}\n")
+
+            if formatted_results:
+                return "Retrieved relevant information:\n\n" + "\n".join(formatted_results)
+            else:
+                return "No relevant information found for the query."
+        except Exception as e:
+            return f"Error retrieving information: {str(e)}"
+
+# Example usage:
+# rag_tool = RAGTool(
+#     documents_path="./my_docs",
+#     embedding_model="sentence-transformers/all-MiniLM-L6-v2",
+#     vector_store_type="faiss",
+#     chunk_size=1000,
+#     chunk_overlap=200
+# )
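
For reference, a minimal usage sketch of the added RAGTool. Assumptions not in the commit: the module is imported as vector_store_retriever (matching this file name), a ./my_docs folder with .pdf/.txt files exists, and smolagents, langchain, faiss-cpu and PyPDF2 are installed.

# Usage sketch (assumptions above); builds or loads a FAISS store, then queries it.
from vector_store_retriever import RAGTool  # assumed import path for this file

rag_tool = RAGTool(
    documents_path="./my_docs",  # hypothetical folder of documents
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    vector_store_type="faiss",
)

# forward() is the tool's retrieval entry point defined in this commit.
print(rag_tool.forward("What do the indexed documents say about retrieval?", top_k=3))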