syedmudassir16
committed on
Update app.py
app.py
CHANGED
@@ -1,3 +1,5 @@
+# Updated Multi-agent RAG-based LLM Model
+
 import os
 import multiprocessing
 import concurrent.futures
@@ -15,12 +17,69 @@ import gradio as gr
 import re
 from threading import Thread
 
+class Agent:
+    def __init__(self, name, role, doc_retrieval_gen, tokenizer):
+        self.name = name
+        self.role = role
+        self.doc_retrieval_gen = doc_retrieval_gen
+        self.tokenizer = tokenizer
+
+    def generate_response(self, query, context):
+        # Route the query to the handler matching this agent's role.
+        if self.role == "Information Retrieval":
+            return self.retriever_logic(query, context)
+        elif self.role == "Content Analysis":
+            return self.analyzer_logic(query, context)
+        elif self.role == "Response Generation":
+            return self.generator_logic(query, context)
+        elif self.role == "Task Coordination":
+            return self.coordinator_logic(query, context)
+
+    def retriever_logic(self, query, all_splits):
+        query_embedding = self.doc_retrieval_gen.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
+        distances, indices = self.doc_retrieval_gen.gpu_index.search(np.array([query_embedding]), k=3)
+        # distances[0] is ordered by rank: pair each distance with its rank,
+        # not with the document id it points at.
+        relevant_docs = [all_splits[idx] for rank, idx in enumerate(indices[0]) if distances[0][rank] <= 1]
+        return relevant_docs
+
+    def analyzer_logic(self, query, relevant_docs):
+        analysis_prompt = f"Analyze the following documents in relation to the query: '{query}'\n\nDocuments:\n"
+        for doc in relevant_docs:
+            analysis_prompt += f"- {doc.page_content}\n"
+        analysis_prompt += "\nProvide a concise analysis of the key points relevant to the query."
+
+        input_ids = self.tokenizer.encode(analysis_prompt, return_tensors="pt").to(self.doc_retrieval_gen.model.device)
+        # max_new_tokens bounds the generated text; max_length would also count
+        # the (potentially long) prompt and could leave no room for output.
+        analysis = self.doc_retrieval_gen.model.generate(input_ids, max_new_tokens=200, num_return_sequences=1)
+        return self.tokenizer.decode(analysis[0], skip_special_tokens=True)
+
+    def generator_logic(self, query, analyzed_content):
+        generation_prompt = f"Based on the following analysis, generate a comprehensive answer to the query: '{query}'\n\nAnalysis:\n{analyzed_content}\n\nGenerate a detailed response:"
+
+        input_ids = self.tokenizer.encode(generation_prompt, return_tensors="pt").to(self.doc_retrieval_gen.model.device)
+        response = self.doc_retrieval_gen.model.generate(input_ids, max_new_tokens=300, num_return_sequences=1)
+        return self.tokenizer.decode(response[0], skip_special_tokens=True)
+
+    def coordinator_logic(self, query, final_response):
+        coordination_prompt = f"As a coordinator, review and refine the following response to the query: '{query}'\n\nResponse:\n{final_response}\n\nProvide a final, polished answer:"
+
+        input_ids = self.tokenizer.encode(coordination_prompt, return_tensors="pt").to(self.doc_retrieval_gen.model.device)
+        coordinated_response = self.doc_retrieval_gen.model.generate(input_ids, max_new_tokens=350, num_return_sequences=1)
+        return self.tokenizer.decode(coordinated_response[0], skip_special_tokens=True)
+
 class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
         self.all_splits = self.load_documents(data_folder)
         self.embeddings = SentenceTransformer(embedding_model_name)
         self.gpu_index = self.create_faiss_index()
         self.tokenizer, self.model = self.initialize_llm(lm_model_id)
+        self.agents = self.initialize_agents()
+
+    def initialize_agents(self):
+        agents = [
+            Agent("Retriever", "Information Retrieval", self, self.tokenizer),
+            Agent("Analyzer", "Content Analysis", self, self.tokenizer),
+            Agent("Generator", "Response Generation", self, self.tokenizer),
+            Agent("Coordinator", "Task Coordination", self, self.tokenizer)
+        ]
+        return agents
 
     def load_documents(self, folder_path):
         loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
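A note on the retrieval step above: FAISS's Index.search returns distance and index arrays of shape (num_queries, k), ordered by rank, which is why the distance threshold in retriever_logic must be indexed by rank rather than by the returned document id. A minimal sketch of that search contract, assuming a small exact L2 index:

import numpy as np
import faiss

# Toy corpus: four 8-dimensional vectors in an exact (flat) L2 index.
xb = np.random.rand(4, 8).astype("float32")
index = faiss.IndexFlatL2(8)
index.add(xb)

# One query vector; both outputs have shape (1, k).
distances, indices = index.search(xb[:1], 3)
for rank, doc_id in enumerate(indices[0]):
    # distances[0][rank] belongs to indices[0][rank]; doc_id indexes the corpus.
    print(rank, doc_id, distances[0][rank])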
@@ -59,6 +118,30 @@ class DocumentRetrievalAndGeneration:
         )
         return tokenizer, model
 
+    def coordinate_agents(self, query):
+        coordinator = next(agent for agent in self.agents if agent.name == "Coordinator")
+
+        # Step 1: Information Retrieval
+        retriever = next(agent for agent in self.agents if agent.name == "Retriever")
+        relevant_docs = retriever.generate_response(query, self.all_splits)
+
+        # Step 2: Content Analysis
+        analyzer = next(agent for agent in self.agents if agent.name == "Analyzer")
+        analyzed_content = analyzer.generate_response(query, relevant_docs)
+
+        # Step 3: Response Generation
+        generator = next(agent for agent in self.agents if agent.name == "Generator")
+        final_response = generator.generate_response(query, analyzed_content)
+
+        # Step 4: Coordination and Refinement
+        coordinated_response = coordinator.generate_response(query, final_response)
+
+        return coordinated_response, "\n".join([doc.page_content for doc in relevant_docs])
+
+    def query_and_generate_response(self, query):
+        return self.coordinate_agents(query)
+
+
     def generate_response_with_timeout(self, input_ids, max_new_tokens=1000):
         try:
             streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
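Taken together, coordinate_agents chains the four agents sequentially and returns both the refined answer and the retrieved context. A minimal sketch of driving the class end to end; the LLM id and data folder are placeholders, since neither appears in this diff (only the embedding model name does):

# Hypothetical driver for the multi-agent pipeline.
lm_model_id = "<llm-model-id>"   # not shown in this diff
data_folder = "data"             # assumed: a folder of .txt files for DirectoryLoader
system = DocumentRetrievalAndGeneration(
    'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12', lm_model_id, data_folder)
answer, retrieved_context = system.coordinate_agents("example query")
print(answer)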
@@ -86,6 +169,7 @@ class DocumentRetrievalAndGeneration:
             print(f"Error in generate_response_with_timeout: {str(e)}")
             return "Text generation process encountered an error"
 
+
     def query_and_generate_response(self, query):
         similarityThreshold = 1
         query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
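One thing to flag in this hunk: the commit keeps this original query_and_generate_response while also adding the agent-delegating definition earlier in the class, so the class body ends up with two def query_and_generate_response statements. Python keeps whichever def comes last in the class body, so as committed qa_infer_gradio still runs this original retrieval path, and the multi-agent path is only reachable by calling coordinate_agents directly. A minimal illustration of that rule:

class C:
    def f(self):
        return "multi-agent path"   # defined first

    def f(self):                    # the same name later in the body wins
        return "original path"

assert C().f() == "original path"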
@@ -142,8 +226,8 @@ class DocumentRetrievalAndGeneration:
         return solution_text, content
 
     def qa_infer_gradio(self, query):
-        response = self.query_and_generate_response(query)
-        return response
+        response, related_queries = self.query_and_generate_response(query)
+        return response, related_queries
 
 if __name__ == "__main__":
     embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
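Finally, qa_infer_gradio now returns two values, and despite the related_queries name the second value is the joined text of the retrieved documents. The Gradio wiring is outside this diff, but it would need two output components to match; a sketch under that assumption:

import gradio as gr

# Hypothetical interface; the real launch code is not shown in this diff.
interface = gr.Interface(
    fn=system.qa_infer_gradio,   # `system` as constructed in the driver sketch above
    inputs=gr.Textbox(label="Query"),
    outputs=[gr.Textbox(label="Response"),
             gr.Textbox(label="Retrieved context")],
)
interface.launch()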