Spaces:

syedmudassir16 committed on
Commit
75b7c0b
1 Parent(s): 79aa22e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -24
app.py CHANGED
@@ -3,11 +3,8 @@ import multiprocessing
3
  import concurrent.futures
4
  from langchain.document_loaders import TextLoader, DirectoryLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain.vectorstores import FAISS
7
- from sentence_transformers import SentenceTransformer
8
- import faiss
9
- import torch
10
- import numpy as np
11
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
12
  from datetime import datetime
13
  import json
@@ -24,7 +21,7 @@ logger = logging.getLogger(__name__)
24
  class DocumentRetrievalAndGeneration:
25
  def __init__(self, embedding_model_name, lm_model_id, data_folder):
26
  self.all_splits = self.load_documents(data_folder)
27
- self.embeddings = SentenceTransformer(embedding_model_name)
28
  self.vectordb = self.create_faiss_index()
29
  self.tokenizer, self.model = self.initialize_llm(lm_model_id)
30
  self.retriever_tool = self.create_retriever_tool()
@@ -40,24 +37,7 @@ class DocumentRetrievalAndGeneration:
40
  return all_splits
41
 
42
  def create_faiss_index(self):
43
- all_texts = [split.page_content for split in self.all_splits]
44
- embeddings = self.embeddings.encode(all_texts)
45
-
46
- # Create FAISS index
47
- vector_dimension = embeddings.shape[1]
48
- index = faiss.IndexFlatL2(vector_dimension)
49
- index.add(embeddings)
50
-
51
- # Create docstore
52
- docstore = {i: doc for i, doc in enumerate(self.all_splits)}
53
-
54
- # Create and return FAISS object
55
- return FAISS(
56
- embedding_function=self.embeddings.encode,
57
- index=index,
58
- docstore=docstore,
59
- index_to_docstore_id={i: i for i in range(len(self.all_splits))}
60
- )
61
 
62
  def initialize_llm(self, model_id):
63
  quantization_config = BitsAndBytesConfig(
@@ -145,6 +125,12 @@ Question:
145
  response = self.query_and_generate_response(query)
146
  return response
147
 
 
 
 
 
 
 
148
  if __name__ == "__main__":
149
  embedding_model_name = 'thenlper/gte-small'
150
  lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
@@ -152,6 +138,9 @@ if __name__ == "__main__":
152
 
153
  doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
154
 
 
 
 
155
  def launch_interface():
156
  css_code = """
157
  .gradio-container {
 
3
  import concurrent.futures
4
  from langchain.document_loaders import TextLoader, DirectoryLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_community.vectorstores import FAISS
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
 
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
9
  from datetime import datetime
10
  import json
 
21
  class DocumentRetrievalAndGeneration:
22
  def __init__(self, embedding_model_name, lm_model_id, data_folder):
23
  self.all_splits = self.load_documents(data_folder)
24
+ self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
25
  self.vectordb = self.create_faiss_index()
26
  self.tokenizer, self.model = self.initialize_llm(lm_model_id)
27
  self.retriever_tool = self.create_retriever_tool()
 
37
  return all_splits
38
 
39
  def create_faiss_index(self):
40
+ return FAISS.from_documents(self.all_splits, self.embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def initialize_llm(self, model_id):
43
  quantization_config = BitsAndBytesConfig(
 
125
  response = self.query_and_generate_response(query)
126
  return response
127
 
128
+ def save_index(self, path):
129
+ self.vectordb.save_local(path)
130
+
131
+ def load_index(self, path):
132
+ self.vectordb = FAISS.load_local(path, self.embeddings)
133
+
134
  if __name__ == "__main__":
135
  embedding_model_name = 'thenlper/gte-small'
136
  lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
138
 
139
  doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
140
 
141
+ # Save the index for future use
142
+ doc_retrieval_gen.save_index("faiss_index")
143
+
144
  def launch_interface():
145
  css_code = """
146
  .gradio-container {