himel06 committed on
Commit 1fbc044 · verified · Parent(s): 9533e0e

Update app.py

Files changed (1): app.py +54 -84
app.py CHANGED
@@ -15,20 +15,18 @@ from pathlib import Path
 import chromadb
 from unidecode import unidecode
 
-from transformers import AutoTokenizer
-import transformers
-import torch
-import tqdm
-import accelerate
 import re
 
 # LLM model to use
 llm_model = "mistralai/Mistral-7B-Instruct-v0.2"
 
+# Directory where PDFs are stored
+pdf_directory = "data"
 
-# Load PDF document and create doc splits
-def load_doc(list_file_path, chunk_size, chunk_overlap):
-    loaders = [PyPDFLoader(x) for x in list_file_path]
+# Load PDF documents from the specified directory and create doc splits
+def load_docs_from_directory(directory_path, chunk_size, chunk_overlap):
+    pdf_files = [os.path.join(directory_path, f) for f in os.listdir(directory_path) if f.endswith('.pdf')]
+    loaders = [PyPDFLoader(file) for file in pdf_files]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
@@ -36,8 +34,7 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
         chunk_size=chunk_size,
         chunk_overlap=chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
-    return doc_splits
-
+    return doc_splits, pdf_files
 
 # Create vector database
 def create_db(splits, collection_name):
@@ -51,7 +48,6 @@ def create_db(splits, collection_name):
     )
     return vectordb
 
-
 # Load vector database
 def load_db():
     embedding = HuggingFaceEmbeddings()
@@ -59,7 +55,6 @@ def load_db():
         embedding_function=embedding)
     return vectordb
 
-
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     progress(0.5, desc="Initializing HF Hub...")
@@ -90,7 +85,6 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     progress(0.9, desc="Done!")
     return qa_chain
 
-
 # Generate collection name for vector database
 def create_collection_name(filepath):
    collection_name = Path(filepath).stem
@@ -108,26 +102,21 @@ def create_collection_name(filepath):
     print('Collection name: ', collection_name)
     return collection_name
 
-
 # Initialize database
-def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
-    list_file_path = [x.name for x in list_file_obj if x is not None]
-    progress(0.1, desc="Creating collection name...")
-    collection_name = create_collection_name(list_file_path[0])
-    progress(0.25, desc="Loading document...")
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+def initialize_database(directory_path, chunk_size, chunk_overlap, progress=gr.Progress()):
+    progress(0.1, desc="Loading documents from directory...")
+    doc_splits, pdf_files = load_docs_from_directory(directory_path, chunk_size, chunk_overlap)
+    collection_name = create_collection_name(pdf_files[0])
     progress(0.5, desc="Generating vector database...")
     vector_db = create_db(doc_splits, collection_name)
-    progress(0.9, desc="Done!")
+    progress(0.9, desc="Database initialization complete!")
     return vector_db, collection_name, "Complete!"
 
-
 def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     print("LLM model: ", llm_model)
     qa_chain = initialize_llmchain(llm_model, llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "Complete!"
 
-
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
@@ -135,7 +124,6 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
 
-
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
@@ -153,15 +141,6 @@ def conversation(qa_chain, message, history):
     return qa_chain, gr.update(
         value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
-
-def upload_file(file_obj):
-    list_file_path = []
-    for idx, file in enumerate(file_obj):
-        file_path = file_obj.name
-        list_file_path.append(file_path)
-    return list_file_path
-
-
 def demo():
     with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
@@ -178,62 +157,53 @@ def demo():
         <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
         """)
 
-        with gr.Tab("Step 1 - Upload PDF"):
-            with gr.Row():
-                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True,
-                                    label="Upload your PDF documents (single or multiple)")
-
-        with gr.Tab("Step 2 - Process document"):
-            with gr.Row():
-                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index",
-                                  info="Choose your vector database")
-            with gr.Accordion("Advanced options - Document text splitter", open=False):
-                with gr.Row():
-                    slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size",
-                                                  info="Chunk size", interactive=True)
-                with gr.Row():
-                    slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap",
-                                                     info="Chunk overlap", interactive=True)
-            with gr.Row():
-                db_progress = gr.Textbox(label="Vector database initialization", value="None")
-            with gr.Row():
-                db_btn = gr.Button("Generate vector database")
-
-        with gr.Tab("Step 3 - Initialize QA chain"):
-            with gr.Row():
-                slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature",
-                                               info="Model temperature", interactive=True)
-            with gr.Row():
-                slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens",
-                                             info="Model max tokens", interactive=True)
-            with gr.Row():
-                slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples",
-                                        info="Model top-k samples", interactive=True)
-            with gr.Row():
-                llm_progress = gr.Textbox(value="None", label="QA chain initialization")
-            with gr.Row():
-                qachain_btn = gr.Button("Initialize Question Answering chain")
-
-        with gr.Tab("Step 4 - Chatbot"):
-            chatbot = gr.Chatbot(height=300)
-            with gr.Accordion("Advanced - Document references", open=False):
-                with gr.Row():
-                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                    source1_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                    source2_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                    source3_page = gr.Number(label="Page", scale=1)
-            with gr.Row():
-                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
-            with gr.Row():
-                submit_btn = gr.Button("Submit message")
-                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+        gr.Markdown("<h4>Step 1 - Process and Load Documents from 'data' Folder</h4>")
+        with gr.Row():
+            slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size",
+                                          info="Chunk size", interactive=True)
+        with gr.Row():
+            slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap",
+                                             info="Chunk overlap", interactive=True)
+        with gr.Row():
+            db_progress = gr.Textbox(label="Vector database initialization", value="None")
+        with gr.Row():
+            db_btn = gr.Button("Generate vector database")
+
+        gr.Markdown("<h4>Step 2 - Initialize QA chain</h4>")
+        with gr.Row():
+            slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature",
+                                           info="Model temperature", interactive=True)
+        with gr.Row():
+            slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens",
+                                         info="Model max tokens", interactive=True)
+        with gr.Row():
+            slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples",
+                                    info="Model top-k samples", interactive=True)
+        with gr.Row():
+            llm_progress = gr.Textbox(value="None", label="QA chain initialization")
+        with gr.Row():
+            qachain_btn = gr.Button("Initialize Question Answering chain")
+
+        gr.Markdown("<h4>Step 3 - Chatbot</h4>")
+        chatbot = gr.Chatbot(height=300)
+        with gr.Accordion("Advanced - Document references", open=False):
+            with gr.Row():
+                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                source1_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                source2_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                source3_page = gr.Number(label="Page", scale=1)
+        with gr.Row():
+            msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
+        with gr.Row():
+            submit_btn = gr.Button("Submit message")
+            clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
 
         db_btn.click(initialize_database, \
-                     inputs=[document, slider_chunk_size, slider_chunk_overlap], \
+                     inputs=[pdf_directory, slider_chunk_size, slider_chunk_overlap], \
                      outputs=[vector_db, collection_name, db_progress])
         qachain_btn.click(initialize_LLM, \
                      inputs=[slider_temperature, slider_maxtokens, slider_topk, vector_db], \
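A note on the new loader: load_docs_from_directory calls os.listdir and os.path.join, yet no "import os" appears in the hunks above, so the commit assumes os is already imported near the top of app.py, outside the diff context. A minimal standalone sketch of the new loading path (assuming the LangChain imports app.py already uses for PyPDFLoader and RecursiveCharacterTextSplitter; exact import paths depend on the installed LangChain version):

import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

def load_docs_from_directory(directory_path, chunk_size, chunk_overlap):
    # Collect every *.pdf in the directory; other files are ignored.
    pdf_files = [os.path.join(directory_path, f)
                 for f in os.listdir(directory_path)
                 if f.endswith('.pdf')]
    # Load each PDF into per-page Document objects.
    pages = []
    for pdf in pdf_files:
        pages.extend(PyPDFLoader(pdf).load())
    # Split pages into overlapping chunks sized for embedding.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(pages), pdf_files

Returning pdf_files alongside the splits is what lets the new initialize_database derive the Chroma collection name from the first PDF's filename via create_collection_name(pdf_files[0]).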
 
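One caveat on the rewired click handler: inputs=[pdf_directory, ...] passes the plain string defined at module level, while Gradio's click(fn, inputs, outputs) expects a list of components, so the directory value would not be forwarded to initialize_database as written. A hedged sketch of one conventional fix, holding the constant in a gr.State component (the name pdf_dir_state is ours, not the commit's; collection_name is assumed to be a gr.State created next to vector_db earlier in demo(), outside the hunks shown):

# Sketch of alternative wiring, not the commit's code: Gradio event inputs
# must be components, so the fixed path is wrapped in gr.State, a component
# that carries an arbitrary Python value across callbacks.
pdf_dir_state = gr.State(pdf_directory)  # pdf_dir_state is a hypothetical name

db_btn.click(initialize_database,
             inputs=[pdf_dir_state, slider_chunk_size, slider_chunk_overlap],
             outputs=[vector_db, collection_name, db_progress])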