syedmudassir16 committed
Commit 7eccbd5 · verified · 1 Parent(s): 63c43cd

Update app.py

Files changed (1)
  1. app.py +86 -2
app.py CHANGED
@@ -1,3 +1,5 @@
+# Updated Multi-agent RAG-based LLM Model
+
 import os
 import multiprocessing
 import concurrent.futures
@@ -15,12 +17,69 @@ import gradio as gr
 import re
 from threading import Thread
 
+class Agent:
+    def __init__(self, name, role, doc_retrieval_gen, tokenizer):
+        self.name = name
+        self.role = role
+        self.doc_retrieval_gen = doc_retrieval_gen
+        self.tokenizer = tokenizer
+
+    def generate_response(self, query, context):
+        # Dispatch to the logic that matches this agent's role.
+        if self.role == "Information Retrieval":
+            return self.retriever_logic(query, context)
+        elif self.role == "Content Analysis":
+            return self.analyzer_logic(query, context)
+        elif self.role == "Response Generation":
+            return self.generator_logic(query, context)
+        elif self.role == "Task Coordination":
+            return self.coordinator_logic(query, context)
+
+    def retriever_logic(self, query, all_splits):
+        # Embed the query and fetch the three nearest chunks. FAISS pairs
+        # distances and indices by result position, not by chunk id, so the
+        # threshold check must use the position, not the chunk index.
+        query_embedding = self.doc_retrieval_gen.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
+        distances, indices = self.doc_retrieval_gen.gpu_index.search(np.array([query_embedding]), k=3)
+        relevant_docs = [all_splits[idx] for pos, idx in enumerate(indices[0]) if distances[0][pos] <= 1]
+        return relevant_docs
+
+    def analyzer_logic(self, query, relevant_docs):
+        analysis_prompt = f"Analyze the following documents in relation to the query: '{query}'\n\nDocuments:\n"
+        for doc in relevant_docs:
+            analysis_prompt += f"- {doc.page_content}\n"
+        analysis_prompt += "\nProvide a concise analysis of the key points relevant to the query."
+
+        input_ids = self.tokenizer.encode(analysis_prompt, return_tensors="pt").to(self.doc_retrieval_gen.model.device)
+        # Budget tokens for the completion itself (max_length would count the
+        # prompt too) and decode only the newly generated tokens, so the
+        # prompt text does not leak into the next stage.
+        analysis = self.doc_retrieval_gen.model.generate(input_ids, max_new_tokens=200, num_return_sequences=1)
+        return self.tokenizer.decode(analysis[0][input_ids.shape[-1]:], skip_special_tokens=True)
+
+    def generator_logic(self, query, analyzed_content):
+        generation_prompt = f"Based on the following analysis, generate a comprehensive answer to the query: '{query}'\n\nAnalysis:\n{analyzed_content}\n\nGenerate a detailed response:"
+
+        input_ids = self.tokenizer.encode(generation_prompt, return_tensors="pt").to(self.doc_retrieval_gen.model.device)
+        response = self.doc_retrieval_gen.model.generate(input_ids, max_new_tokens=300, num_return_sequences=1)
+        return self.tokenizer.decode(response[0][input_ids.shape[-1]:], skip_special_tokens=True)
+
+    def coordinator_logic(self, query, final_response):
+        coordination_prompt = f"As a coordinator, review and refine the following response to the query: '{query}'\n\nResponse:\n{final_response}\n\nProvide a final, polished answer:"
+
+        input_ids = self.tokenizer.encode(coordination_prompt, return_tensors="pt").to(self.doc_retrieval_gen.model.device)
+        coordinated_response = self.doc_retrieval_gen.model.generate(input_ids, max_new_tokens=350, num_return_sequences=1)
+        return self.tokenizer.decode(coordinated_response[0][input_ids.shape[-1]:], skip_special_tokens=True)
+
 class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
         self.all_splits = self.load_documents(data_folder)
         self.embeddings = SentenceTransformer(embedding_model_name)
         self.gpu_index = self.create_faiss_index()
         self.tokenizer, self.model = self.initialize_llm(lm_model_id)
+        self.agents = self.initialize_agents()
+
+    def initialize_agents(self):
+        agents = [
+            Agent("Retriever", "Information Retrieval", self, self.tokenizer),
+            Agent("Analyzer", "Content Analysis", self, self.tokenizer),
+            Agent("Generator", "Response Generation", self, self.tokenizer),
+            Agent("Coordinator", "Task Coordination", self, self.tokenizer)
+        ]
+        return agents
 
     def load_documents(self, folder_path):
         loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
@@ -59,6 +118,30 @@ class DocumentRetrievalAndGeneration:
         )
         return tokenizer, model
 
+    def coordinate_agents(self, query):
+        coordinator = next(agent for agent in self.agents if agent.name == "Coordinator")
+
+        # Step 1: Information Retrieval
+        retriever = next(agent for agent in self.agents if agent.name == "Retriever")
+        relevant_docs = retriever.generate_response(query, self.all_splits)
+
+        # Step 2: Content Analysis
+        analyzer = next(agent for agent in self.agents if agent.name == "Analyzer")
+        analyzed_content = analyzer.generate_response(query, relevant_docs)
+
+        # Step 3: Response Generation
+        generator = next(agent for agent in self.agents if agent.name == "Generator")
+        final_response = generator.generate_response(query, analyzed_content)
+
+        # Step 4: Coordination and Refinement
+        coordinated_response = coordinator.generate_response(query, final_response)
+
+        return coordinated_response, "\n".join([doc.page_content for doc in relevant_docs])
+
+    def query_and_generate_response(self, query):
+        # NOTE: the original query_and_generate_response defined later in the
+        # class overrides this name, so qa_infer_gradio calls
+        # coordinate_agents directly instead of relying on this delegator.
+        return self.coordinate_agents(query)
+
     def generate_response_with_timeout(self, input_ids, max_new_tokens=1000):
         try:
             streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
@@ -86,6 +169,7 @@ class DocumentRetrievalAndGeneration:
             print(f"Error in generate_response_with_timeout: {str(e)}")
             return "Text generation process encountered an error"
 
+
     def query_and_generate_response(self, query):
         similarityThreshold = 1
         query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
@@ -142,8 +226,8 @@ class DocumentRetrievalAndGeneration:
         return solution_text, content
 
     def qa_infer_gradio(self, query):
-        response = self.query_and_generate_response(query)
-        return response
+        response, related_queries = self.coordinate_agents(query)
+        return response, related_queries
 
 if __name__ == "__main__":
     embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
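
For context, a minimal sketch of how the updated pipeline would be wired up, assuming the __main__ block (truncated in this diff) builds a Gradio interface around qa_infer_gradio. The lm_model_id and data_folder values below are illustrative placeholders, not taken from the commit; only the embedding model id appears above.

import gradio as gr

if __name__ == "__main__":
    embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
    lm_model_id = 'mistralai/Mistral-7B-Instruct-v0.2'  # placeholder; not shown in this diff
    data_folder = 'sample_data'                         # placeholder; not shown in this diff

    doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)

    # qa_infer_gradio now returns a (response, retrieved context) pair,
    # so the interface needs two output components instead of one.
    interface = gr.Interface(
        fn=doc_retrieval_gen.qa_infer_gradio,
        inputs=gr.Textbox(label="Query"),
        outputs=[gr.Textbox(label="Response"), gr.Textbox(label="Retrieved context")],
    )
    interface.launch()

All four agents share the same underlying model and tokenizer, so a query costs one embedding lookup plus three sequential generate calls (analysis, generation, refinement); the role split buys staged prompting rather than any parallelism.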