Update app.py
app.py CHANGED
@@ -14,6 +14,8 @@ from transformers.agents import Tool, HfEngine, ReactJsonAgent
 from huggingface_hub import InferenceClient
 import logging
 import torch
+import numpy as np
+import faiss

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -29,11 +31,7 @@ class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
         self.all_splits = self.load_documents(data_folder)
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
-        if FAISS is not None:
-            self.vectordb = self.create_faiss_index()
-        else:
-            logger.warning("FAISS is not available. Vector search functionality will be limited.")
-            self.vectordb = None
+        self.gpu_index = self.create_faiss_index()
         self.tokenizer, self.model = self.initialize_llm(lm_model_id)
         self.retriever_tool = self.create_retriever_tool()
         self.agent = self.create_agent()
@@ -41,17 +39,20 @@
     def load_documents(self, folder_path):
         loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
         documents = loader.load()
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
         all_splits = text_splitter.split_documents(documents)
         logger.info(f'Loaded {len(documents)} documents')
         logger.info(f"Split into {len(all_splits)} chunks")
         return all_splits

     def create_faiss_index(self):
-
-
-
-
+        all_texts = [split.page_content for split in self.all_splits]
+        embeddings = self.embeddings.embed_documents(all_texts)
+        index = faiss.IndexFlatL2(len(embeddings[0]))
+        index.add(np.array(embeddings))
+        gpu_resource = faiss.StandardGpuResources()
+        gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
+        return gpu_index

     def initialize_llm(self, model_id):
         quantization_config = BitsAndBytesConfig(
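Aside on the new create_faiss_index (not part of this commit): faiss.StandardGpuResources and faiss.index_cpu_to_gpu exist only in the GPU build of FAISS and need a visible CUDA device, so the code above would fail at startup on CPU-only hardware. A minimal sketch of a guarded variant, assuming the same self.all_splits and self.embeddings attributes:

import numpy as np
import faiss
import torch

def build_faiss_index(splits, embedder):
    # Embed every chunk and build an exact L2 index; stay on CPU when the
    # GPU build of FAISS (or a CUDA device) is unavailable.
    texts = [s.page_content for s in splits]
    vectors = np.asarray(embedder.embed_documents(texts), dtype="float32")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    if hasattr(faiss, "StandardGpuResources") and torch.cuda.is_available():
        index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index)
    return index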
@@ -81,24 +82,56 @@ class DocumentRetrievalAndGeneration:
             }
             output_type = "text"

-            def __init__(self,
+            def __init__(self, parent, **kwargs):
                 super().__init__(**kwargs)
-                self.
+                self.parent = parent

             def forward(self, query: str) -> str:
-
-
-
-
-
-                )
-
-
+                similarityThreshold = 1
+                query_embedding = self.parent.embeddings.embed_query(query)
+                distances, indices = self.parent.gpu_index.search(np.array([query_embedding]), k=3)
+                content = ""
+                filtered_results = []
+                for idx, distance in zip(indices[0], distances[0]):
+                    if distance <= similarityThreshold:
+                        filtered_results.append(idx)
+                        content += "-" * 50 + "\n"
+                        content += self.parent.all_splits[idx].page_content + "\n"
+                return content
+
+        return RetrieverTool(self)

     def create_agent(self):
         llm_engine = HfEngine("meta-llama/Meta-Llama-3.1-8B-Instruct")
         return ReactJsonAgent(tools=[self.retriever_tool], llm_engine=llm_engine, max_iterations=4, verbose=2)

+    def generate_response_with_timeout(self, input_ids, max_new_tokens=1000):
+        try:
+            streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+            generate_kwargs = dict(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                do_sample=True,
+                top_p=1.0,
+                top_k=20,
+                temperature=0.8,
+                repetition_penalty=1.2,
+                eos_token_id=[128001, 128008, 128009],
+                streamer=streamer,
+            )
+
+            thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
+            thread.start()
+
+            generated_text = ""
+            for new_text in streamer:
+                generated_text += new_text
+
+            return generated_text
+        except Exception as e:
+            logger.error(f"Error in generate_response_with_timeout: {str(e)}")
+            return "Text generation process encountered an error"
+
     def run_agentic_rag(self, question: str) -> str:
         enhanced_question = f"""Using the information in your knowledge base, accessible with the 'retriever' tool,
 give a comprehensive answer to the question below.
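For reference on the forward method in the hunk above: IndexFlatL2.search returns squared L2 distances in ascending order, so the similarityThreshold = 1 check keeps only chunks whose embeddings sit very close to the query embedding. A self-contained toy run of the same k-NN-plus-threshold filtering, on hypothetical data rather than the app's documents:

import numpy as np
import faiss

# Ten random "document" embeddings (hypothetical data, 384-dim like MiniLM).
rng = np.random.default_rng(0)
docs = rng.random((10, 384), dtype="float32")

index = faiss.IndexFlatL2(docs.shape[1])
index.add(docs)

query = docs[0:1]                             # a query identical to document 0
distances, indices = index.search(query, 3)   # squared L2 distances, ascending
kept = [int(i) for i, d in zip(indices[0], distances[0]) if d <= 1.0]
print(kept)                                   # document 0 survives (distance 0.0)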
@@ -115,20 +148,23 @@ Question:
     def run_standard_rag(self, question: str) -> str:
         context = self.retriever_tool(query=question)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        conversation = [
+            {"role": "system", "content": "You are a knowledgeable assistant with access to a comprehensive database."},
+            {"role": "user", "content": f"""
+I need you to answer my question and provide related information in a specific format.
+I have provided five relatable json files {context}, choose the most suitable chunks for answering the query.
+RETURN ONLY SOLUTION without additional comments, sign-offs, retrived chunks, refrence to any Ticket or extra phrases. Be direct and to the point.
+IF THERE IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE".
+DO NOT GIVE REFRENCE TO ANY CHUNKS OR TICKETS,BE ON POINT.
+
+Here's my question:
+Query: {question}
+Solution==>
+"""}
+        ]
+        input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt").to(self.model.device)
+
+        return self.generate_response_with_timeout(input_ids)

     def query_and_generate_response(self, query):
         agentic_answer = self.run_agentic_rag(query)
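A note on the apply_chat_template call above: by default it does not append the assistant-turn header, so the model may continue the user turn instead of answering. A variant with add_generation_prompt=True, shown only as an illustration (the commit does not include it):

input_ids = self.tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,  # end the prompt with the assistant header
    return_tensors="pt",
).to(self.model.device)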
@@ -141,29 +177,17 @@ Question:
         response = self.query_and_generate_response(query)
         return response

-    def save_index(self, path):
-        if self.vectordb is not None:
-            self.vectordb.save_local(path)
-        else:
-            logger.warning("Vector database is not available. Cannot save index.")
-
-    def load_index(self, path):
-        if FAISS is not None:
-            self.vectordb = FAISS.load_local(path, self.embeddings)
-        else:
-            logger.warning("FAISS is not available. Cannot load index.")
-
 if __name__ == "__main__":
-    embedding_model_name = '
+    embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
     lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
     data_folder = 'sample_embedding_folder2'

+    # Set your HuggingFace token here
+    os.environ["HUGGINGFACE_TOKEN"] = "your_huggingface_token_here"
+
     try:
         doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)

-        # Save the index for future use
-        doc_retrieval_gen.save_index("faiss_index")
-
         def launch_interface():
             css_code = """
             .gradio-container {
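The hunk above drops the save_index/load_index helpers along with the LangChain FAISS wrapper. If persistence of the new raw index is still wanted, a hedged sketch using plain FAISS I/O (helper names here are illustrative, not from app.py):

import faiss

def save_faiss_index(gpu_index, path):
    # A GPU index cannot be written directly; copy it back to the CPU first.
    faiss.write_index(faiss.index_gpu_to_cpu(gpu_index), path)

def load_faiss_index(path, gpu_device=0):
    index = faiss.read_index(path)
    if hasattr(faiss, "StandardGpuResources"):
        index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), gpu_device, index)
    return index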
|