syedmudassir16 committed
Commit bd811e9
1 Parent(s): ebcd296

Update app.py

Files changed (1): app.py +84 -10
app.py CHANGED
@@ -15,6 +15,11 @@ import gradio as gr
 import re
 from threading import Thread
 
+from llama_index.core import VectorStoreIndex, Document
+from llama_index.core.tools import QueryEngineTool, ToolMetadata
+from llama_index.agent.openai import OpenAIAgent
+
+
 class Agent:
     def __init__(self, name, role, doc_retrieval_gen, tokenizer):
         self.name = name
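Note: the three new imports assume LlamaIndex's post-0.10 split packaging, where llama_index.core ships in the llama-index-core package and OpenAIAgent in the separate llama-index-agent-openai package; under the older monolithic llama_index distribution these module paths do not resolve.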
@@ -62,6 +67,71 @@ class Agent:
         coordinated_response = self.doc_retrieval_gen.model.generate(input_ids, max_length=350, num_return_sequences=1)
         return self.tokenizer.decode(coordinated_response[0], skip_special_tokens=True)
 
+class MultiDocumentAgentSystem:
+    def __init__(self, documents_dict, llm, embed_model):
+        self.llm = llm
+        self.embed_model = embed_model
+        self.document_agents = {}
+        self.create_document_agents(documents_dict)
+        self.top_agent = self.create_top_agent()
+
+    def create_document_agents(self, documents_dict):
+        for doc_name, doc_content in documents_dict.items():
+            vector_index = VectorStoreIndex.from_documents([Document(doc_content)])
+            summary_index = VectorStoreIndex.from_documents([Document(doc_content)])
+
+            vector_query_engine = vector_index.as_query_engine(similarity_top_k=2)
+            summary_query_engine = summary_index.as_query_engine()
+
+            query_engine_tools = [
+                QueryEngineTool(
+                    query_engine=vector_query_engine,
+                    metadata=ToolMetadata(
+                        name=f"vector_tool_{doc_name}",
+                        description=f"Useful for specific questions about {doc_name}",
+                    ),
+                ),
+                QueryEngineTool(
+                    query_engine=summary_query_engine,
+                    metadata=ToolMetadata(
+                        name=f"summary_tool_{doc_name}",
+                        description=f"Useful for summarizing content about {doc_name}",
+                    ),
+                ),
+            ]
+
+            self.document_agents[doc_name] = OpenAIAgent.from_tools(
+                query_engine_tools,
+                llm=self.llm,
+                verbose=True,
+                system_prompt=f"You are an agent designed to answer queries about {doc_name}.",
+            )
+
+    def create_top_agent(self):
+        all_tools = []
+        for doc_name, agent in self.document_agents.items():
+            doc_tool = QueryEngineTool(
+                query_engine=agent,
+                metadata=ToolMetadata(
+                    name=f"tool_{doc_name}",
+                    description=f"Use this tool for questions about {doc_name}",
+                ),
+            )
+            all_tools.append(doc_tool)
+
+        obj_index = VectorStoreIndex.from_objects(all_tools, embed_model=self.embed_model)
+
+        return OpenAIAgent.from_tools(
+            all_tools,
+            llm=self.llm,
+            verbose=True,
+            system_prompt="You are an agent designed to answer queries about multiple documents.",
+            tool_retriever=obj_index.as_retriever(similarity_top_k=3),
+        )
+
+    def query(self, user_input):
+        return self.top_agent.chat(user_input)
+
 class DocumentRetrievalAndGeneration:
     def __init__(self, embedding_model_name, lm_model_id, data_folder):
         self.all_splits = self.load_documents(data_folder)
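Two details of the new class are worth flagging. Document(doc_content) relies on a positional argument, while current LlamaIndex releases document the keyword form Document(text=...); building the "summary" engine on a second VectorStoreIndex also just duplicates the vector index, where a SummaryIndex is the purpose-built choice. Further down, VectorStoreIndex.from_objects is not a documented API; the tool-retrieval pattern this code reaches for is ObjectIndex.from_objects, with the top agent given the retriever rather than the full tool list. A minimal corrective sketch, assuming LlamaIndex >= 0.10 (doc_content and all_tools refer to the variables in the hunk above):

from llama_index.core import Document, SummaryIndex, VectorStoreIndex
from llama_index.core.objects import ObjectIndex

doc = Document(text=doc_content)                    # explicit keyword form
vector_index = VectorStoreIndex.from_documents([doc])
summary_index = SummaryIndex.from_documents([doc])  # purpose-built for summaries
summary_query_engine = summary_index.as_query_engine(response_mode="tree_summarize")

# ObjectIndex, not VectorStoreIndex, exposes from_objects; the documented
# pattern hands the top agent a tool retriever instead of every tool at once.
obj_index = ObjectIndex.from_objects(all_tools, index_cls=VectorStoreIndex)
tool_retriever = obj_index.as_retriever(similarity_top_k=3)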
@@ -69,6 +139,10 @@ class DocumentRetrievalAndGeneration:
         self.gpu_index = self.create_faiss_index()
         self.tokenizer, self.model = self.initialize_llm(lm_model_id)
         self.agents = self.initialize_agents()
+        documents_dict = self.load_documents(data_folder)
+        self.multi_doc_system = MultiDocumentAgentSystem(documents_dict, self.model, self.embeddings)
+
+
 
     def initialize_agents(self):
         agents = [
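One wiring issue in this hunk: self.model is a raw Hugging Face transformers model, yet MultiDocumentAgentSystem hands it to OpenAIAgent, which only accepts LlamaIndex's OpenAI LLM (self.embeddings is likewise presumed to be set on a line outside this hunk). For a locally hosted model, a prompting-based agent is the closer fit; a hedged sketch, assuming the llama-index-llms-huggingface package is installed:

from llama_index.core.agent import ReActAgent
from llama_index.llms.huggingface import HuggingFaceLLM

# Wrap the already-loaded transformers model and tokenizer for LlamaIndex.
wrapped_llm = HuggingFaceLLM(model=self.model, tokenizer=self.tokenizer)
# ReActAgent drives tools through prompting, so it works without OpenAI
# function calling; query_engine_tools is the per-document list built above.
agent = ReActAgent.from_tools(query_engine_tools, llm=wrapped_llm, verbose=True)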
@@ -80,15 +154,14 @@ class DocumentRetrievalAndGeneration:
         return agents
 
     def load_documents(self, folder_path):
-        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
-        documents = loader.load()
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
-        all_splits = text_splitter.split_documents(documents)
-        print('Length of documents:', len(documents))
-        print("LEN of all_splits", len(all_splits))
-        for i in range(3):
-            print(all_splits[i].page_content)
-        return all_splits
+        documents_dict = {}
+        for file_name in os.listdir(folder_path):
+            if file_name.endswith('.txt'):
+                file_path = os.path.join(folder_path, file_name)
+                with open(file_path, 'r', encoding='utf-8') as file:
+                    content = file.read()
+                documents_dict[file_name[:-4]] = content  # Use filename without .txt as key
+        return documents_dict
 
     def create_faiss_index(self):
         all_texts = [split.page_content for split in self.all_splits]
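This rewrite introduces a regression visible within the diff: load_documents now returns a dict of raw strings, yet self.all_splits still feeds create_faiss_index, which iterates it expecting objects with a page_content attribute (iterating a dict yields its keys, so the very next context line raises AttributeError). Note also that __init__ now calls load_documents twice, reading every file from disk twice. One option is to return both representations and unpack them once in __init__ (self.all_splits, documents_dict = self.load_documents(data_folder)); a sketch, assuming the original LangChain loader imports remain at the top of app.py:

def load_documents(self, folder_path):
    # Chunked splits keep the FAISS index working as before.
    loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
    documents = loader.load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
    all_splits = splitter.split_documents(documents)
    # Raw per-file text, keyed by filename without extension, for the agent
    # system (TextLoader records the file path under metadata["source"]).
    documents_dict = {
        os.path.splitext(os.path.basename(doc.metadata["source"]))[0]: doc.page_content
        for doc in documents
    }
    return all_splits, documents_dict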
@@ -137,7 +210,8 @@
         return coordinated_response, "\n".join([doc.page_content for doc in relevant_docs])
 
     def query_and_generate_response(self, query):
-        return self.coordinate_agents(query)
+        response = self.multi_doc_system.query(query)
+        return str(response), ""
 
 
     def generate_response_with_timeout(self, input_ids, max_new_tokens=1000):
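With this change every query is routed through the top-level agent, and the second element of the returned tuple (previously the concatenated retrieved context) is now always an empty string; the coordinate_agents path and the FAISS index no longer participate in the main flow. A usage sketch (constructor arguments are illustrative placeholders, not values from this Space):

# Illustrative values; the real model IDs and folder live elsewhere in app.py.
system = DocumentRetrievalAndGeneration(
    embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
    lm_model_id="mistralai/Mistral-7B-Instruct-v0.2",
    data_folder="data",
)
answer, context = system.query_and_generate_response("Summarize the documents")
print(answer)   # str() of the top agent's chat response
print(context)  # "" after this commit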
 