Files changed (1)
  1. app.py +11 -333
app.py CHANGED
@@ -22,347 +22,25 @@ import tqdm
 import accelerate
 import re
 
-
-
-# default_persist_directory = './chroma_HF/'
 list_llm = ["HuggingFaceH4/zephyr-7b-beta", "mistralai/Mistral-7B-Instruct-v0.2"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
-
-# Load PDF document and create doc splits
-def load_doc(list_file_path, chunk_size, chunk_overlap):
-    # Processing for one document only
-    # loader = PyPDFLoader(file_path)
-    # pages = loader.load()
-    loaders = [PyPDFLoader(x) for x in list_file_path]
-    pages = []
-    for loader in loaders:
-        pages.extend(loader.load())
-    # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size = chunk_size,
-        chunk_overlap = chunk_overlap)
-    doc_splits = text_splitter.split_documents(pages)
-    return doc_splits
-
-
-# Create vector database
-def create_db(splits, collection_name):
-    embedding = HuggingFaceEmbeddings()
-    new_client = chromadb.EphemeralClient()
-    vectordb = Chroma.from_documents(
-        documents=splits,
-        embedding=embedding,
-        client=new_client,
-        collection_name=collection_name,
-        # persist_directory=default_persist_directory
-    )
-    return vectordb
-
-
-# Load vector database
-def load_db():
-    embedding = HuggingFaceEmbeddings()
-    vectordb = Chroma(
-        # persist_directory=default_persist_directory,
-        embedding_function=embedding)
-    return vectordb
-
-
-# Initialize langchain LLM chain
-def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    progress(0.1, desc="Initializing HF tokenizer...")
-    # HuggingFacePipeline uses local model
-    # Note: it will download model locally...
-    # tokenizer=AutoTokenizer.from_pretrained(llm_model)
-    # progress(0.5, desc="Initializing HF pipeline...")
-    # pipeline=transformers.pipeline(
-    #     "text-generation",
-    #     model=llm_model,
-    #     tokenizer=tokenizer,
-    #     torch_dtype=torch.bfloat16,
-    #     trust_remote_code=True,
-    #     device_map="auto",
-    #     # max_length=1024,
-    #     max_new_tokens=max_tokens,
-    #     do_sample=True,
-    #     top_k=top_k,
-    #     num_return_sequences=1,
-    #     eos_token_id=tokenizer.eos_token_id
-    # )
-    # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
-
-    # HuggingFaceHub uses HF inference endpoints
-    progress(0.5, desc="Initializing HF Hub...")
-    # Use of trust_remote_code as model_kwargs
-    # Warning: langchain issue
-    # URL: https://github.com/langchain-ai/langchain/issues/6080
-    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-            load_in_8bit = True,
-        )
-    elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1","mosaicml/mpt-7b-instruct"]:
-        raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )
-    elif llm_model == "microsoft/phi-2":
-        raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-            trust_remote_code = True,
-            torch_dtype = "auto",
-        )
-    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
-            temperature = temperature,
-            max_new_tokens = 250,
-            top_k = top_k,
-        )
-    elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
-        raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )
-    else:
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )
-
-    progress(0.75, desc="Defining buffer memory...")
-    memory = ConversationBufferMemory(
-        memory_key="chat_history",
-        output_key='answer',
-        return_messages=True
-    )
-    # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
-    retriever=vector_db.as_retriever()
-    progress(0.8, desc="Defining retrieval chain...")
-    qa_chain = ConversationalRetrievalChain.from_llm(
-        llm,
-        retriever=retriever,
-        chain_type="stuff",
-        memory=memory,
-        # combine_docs_chain_kwargs={"prompt": your_prompt})
-        return_source_documents=True,
-        #return_generated_question=False,
-        verbose=False,
-    )
-    progress(0.9, desc="Done!")
-    return qa_chain
-
-
-
-# Generate collection name for vector database
-# - Use filepath as input, ensuring unicode text
-def create_collection_name(filepath):
-    # Extract filename without extension
-    collection_name = Path(filepath).stem
-    # Fix potential issues from naming convention
-    ## Remove space
-    collection_name = collection_name.replace(" ","-")
-    ## ASCII transliterations of Unicode text
-    collection_name = unidecode(collection_name)
-    ## Remove special characters
-    #collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
-    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
-    ## Limit length to 50 characters
-    collection_name = collection_name[:50]
-    ## Minimum length of 3 characters
-    if len(collection_name) < 3:
-        collection_name = collection_name + 'xyz'
-    ## Enforce start and end as alphanumeric character
-    if not collection_name[0].isalnum():
-        collection_name = 'A' + collection_name[1:]
-    if not collection_name[-1].isalnum():
-        collection_name = collection_name[:-1] + 'Z'
-    print('Filepath: ', filepath)
-    print('Collection name: ', collection_name)
-    return collection_name
-
-
-# Initialize database
-def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
-    # Create list of documents (when valid)
-    list_file_path = [x.name for x in list_file_obj if x is not None]
-    # Create collection_name for vector database
-    progress(0.1, desc="Creating collection name...")
-    collection_name = create_collection_name(list_file_path[0])
-    progress(0.25, desc="Loading document...")
-    # Load document and create splits
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-    # Create or load vector database
-    progress(0.5, desc="Generating vector database...")
-    # global vector_db
-    vector_db = create_db(doc_splits, collection_name)
-    progress(0.9, desc="Done!")
-    return vector_db, collection_name, "Complete!"
-
-
-def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    # print("llm_option",llm_option)
-    llm_name = list_llm[llm_option]
-    print("llm_name: ",llm_name)
-    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
-    return qa_chain, "Complete!"
-
-
-def format_chat_history(message, chat_history):
-    formatted_chat_history = []
-    for user_message, bot_message in chat_history:
-        formatted_chat_history.append(f"User: {user_message}")
-        formatted_chat_history.append(f"Assistant: {bot_message}")
-    return formatted_chat_history
-
-
-def conversation(qa_chain, message, history):
-    formatted_chat_history = format_chat_history(message, history)
-    #print("formatted_chat_history",formatted_chat_history)
-
-    # Generate response using QA chain
-    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-    response_answer = response["answer"]
-    if response_answer.find("Helpful Answer:") != -1:
-        response_answer = response_answer.split("Helpful Answer:")[-1]
-    response_sources = response["source_documents"]
-    response_source1 = response_sources[0].page_content.strip()
-    response_source2 = response_sources[1].page_content.strip()
-    response_source3 = response_sources[2].page_content.strip()
-    # Langchain sources are zero-based
-    response_source1_page = response_sources[0].metadata["page"] + 1
-    response_source2_page = response_sources[1].metadata["page"] + 1
-    response_source3_page = response_sources[2].metadata["page"] + 1
-    # print ('chat response: ', response_answer)
-    # print('DB source', response_sources)
-
-    # Append user message and response to chat history
-    new_history = history + [(message, response_answer)]
-    # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
-
-def upload_file(file_obj):
-    list_file_path = []
-    for idx, file in enumerate(file_obj):
-        file_path = file_obj.name
-        list_file_path.append(file_path)
-    # print(file_path)
-    # initialize_database(file_path, progress)
-    return list_file_path
-
+def summarize_document(document_text):
+    # Your summarization code here
+    summary = "The document covers various topics such as X, Y, and Z, providing detailed insights into each aspect."
+    return summary
 
 def demo():
     with gr.Blocks(theme="base") as demo:
-        vector_db = gr.State()
-        qa_chain = gr.State()
-        collection_name = gr.State()
+        gr.Markdown("<center><h2>PDF Summarizer</center></h2>")
 
-        gr.Markdown(
-        """<center><h2>PDF-based chatbot</center></h2>
-        <h3>Ask any questions about your PDF documents</h3>""")
-
+        text_input = gr.Textbox(placeholder="Paste your document text here", label="Document Text")
+        summarize_btn = gr.Button("Summarize")
+        summary_output = gr.Textbox(label="Summary", interactive=False)
 
-        with gr.Tab("Step 1 - Upload PDF"):
-            with gr.Row():
-                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-                # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
+        summarize_btn.click(summarize_document, inputs=[text_input], outputs=[summary_output])
 
-        with gr.Tab("Step 2 - Process document"):
-            with gr.Row():
-                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
-            with gr.Accordion("Advanced options - Document text splitter", open=False):
-                with gr.Row():
-                    slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-                with gr.Row():
-                    slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-            with gr.Row():
-                db_progress = gr.Textbox(label="Vector database initialization", value="None")
-            with gr.Row():
-                db_btn = gr.Button("Generate vector database")
-
-        with gr.Tab("Step 3 - Initialize QA chain"):
-            with gr.Row():
-                llm_btn = gr.Radio(list_llm_simple, \
-                    label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
-            with gr.Accordion("Advanced options - LLM model", open=False):
-                with gr.Row():
-                    slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-                with gr.Row():
-                    slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
-                with gr.Row():
-                    slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
-            with gr.Row():
-                llm_progress = gr.Textbox(value="None",label="QA chain initialization")
-            with gr.Row():
-                qachain_btn = gr.Button("Initialize Question Answering chain")
-
-        with gr.Tab("Step 4 - Chatbot"):
-            chatbot = gr.Chatbot(height=300)
-            with gr.Accordion("Advanced - Document references", open=False):
-                with gr.Row():
-                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                    source1_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                    source2_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                    source3_page = gr.Number(label="Page", scale=1)
-            with gr.Row():
-                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
-            with gr.Row():
-                submit_btn = gr.Button("Submit message")
-                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
-
-        # Preprocessing events
-        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
-        db_btn.click(initialize_database, \
-            inputs=[document, slider_chunk_size, slider_chunk_overlap], \
-            outputs=[vector_db, collection_name, db_progress])
-        qachain_btn.click(initialize_LLM, \
-            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
-            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
-            inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-
-        # Chatbot events
-        msg.submit(conversation, \
-            inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-        submit_btn.click(conversation, \
-            inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-        clear_btn.click(lambda:[None,"",0,"",0,"",0], \
-            inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-    demo.queue().launch(debug=True)
-
+    demo.launch()
 
 if __name__ == "__main__":
-    demo()
+    demo()
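Note: the new `summarize_document` function is only a stub that returns a fixed string. Below is a minimal sketch of one way it could be filled in, assuming the `transformers` summarization pipeline is an acceptable backend for this Space; the chunk size and length limits are illustrative assumptions, not part of this change.

```python
from transformers import pipeline

# Load a summarization pipeline once at import time
# (downloads a default summarization model on first run).
summarizer = pipeline("summarization")

def summarize_document(document_text):
    # Summarization models accept a limited input length, so split long text into
    # fixed-size character chunks, summarize each, and join the partial summaries.
    chunk_size = 2000  # characters per chunk; illustrative heuristic
    chunks = [document_text[i:i + chunk_size] for i in range(0, len(document_text), chunk_size)]
    partial_summaries = [
        summarizer(chunk, max_length=150, min_length=30, do_sample=False)[0]["summary_text"]
        for chunk in chunks
    ]
    return " ".join(partial_summaries)
```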