Fecalisboa committed · Commit 80c8e97 · verified · 1 parent: c162f7c

Update app.py

Files changed (1):
  1. app.py +62 -47
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
 import os
-import getpass
 from pathlib import Path
 import re
 from unidecode import unidecode
@@ -10,7 +9,7 @@ from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
@@ -18,9 +17,7 @@ from langchain_community.llms import HuggingFaceEndpoint
 import torch
 api_token = os.getenv("HF_TOKEN")
 
-
-
-list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
+list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
 # Load PDF document and create doc splits
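Note on this hunk: the second model id loses its version suffix ("mistralai/Mistral-7B-Instruct-v0.3" becomes "mistralai/Mistral-7B-Instruct"). HuggingFaceEndpoint only resolves repo_id against the Hub when a request is made, so an id with no matching repo fails at chat time rather than at startup. A minimal guard, sketched with huggingface_hub (the validate_repo_ids helper is hypothetical, not part of this commit):

from huggingface_hub import model_info
from huggingface_hub.utils import RepositoryNotFoundError

def validate_repo_ids(candidates, token=None):
    # Keep only ids that resolve on the Hub; model_info raises
    # RepositoryNotFoundError for ids that do not exist.
    valid = []
    for repo_id in candidates:
        try:
            model_info(repo_id, token=token)
            valid.append(repo_id)
        except RepositoryNotFoundError:
            print(f"Skipping unknown repo id: {repo_id}")
    return valid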
@@ -66,39 +63,20 @@ def create_db(splits, collection_name, db_type):
 
     return vectordb
 
-# Load vector database
-def load_db():
-    embedding = HuggingFaceEmbeddings()
-    vectordb = Chroma(
-        embedding_function=embedding)
-    return vectordb
-
 # Initialize langchain LLM chain
-def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
 
     progress(0.5, desc="Initializing HF Hub...")
 
-
-    if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            huggingfacehub_api_token=api_token,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-        )
-
-    else:
-
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            huggingfacehub_api_token=api_token,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-        )
-
+    llm = HuggingFaceEndpoint(
+        repo_id=llm_model,
+        huggingfacehub_api_token=api_token,
+        temperature=temperature,
+        max_new_tokens=max_tokens,
+        top_k=top_k,
+    )
+
     progress(0.75, desc="Defining buffer memory...")
     memory = ConversationBufferMemory(
         memory_key="chat_history",
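The hunk above collapses two byte-identical HuggingFaceEndpoint branches into one call and threads the new initial_prompt parameter through. Its context also ends mid-call, before the remaining ConversationBufferMemory arguments: because the chain is built with return_source_documents=True, it returns two output keys ("answer" and "source_documents"), and the memory must be told which one to persist or LangChain raises a "one output key expected" error when saving context. A sketch of the arguments the chain needs, assuming the hidden lines do not already set them:

from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(
    memory_key="chat_history",
    output_key="answer",  # required: the chain also returns "source_documents"
    return_messages=True,
)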
@@ -110,18 +88,19 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
-        chain_type="stuff",
+        chain_type="stuff",
         memory=memory,
         return_source_documents=True,
         verbose=False,
     )
+    qa_chain({"question": initial_prompt})  # Initialize with the initial prompt
     progress(0.9, desc="Done!")
     return qa_chain
 
 # Generate collection name for vector database
 def create_collection_name(filepath):
     collection_name = Path(filepath).stem
-    collection_name = collection_name.replace(" ", "-")
+    collection_name = collection_name.replace(" ", "-")
     collection_name = unidecode(collection_name)
     collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
     collection_name = collection_name[:50]
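Two observations on this hunk. The restored replace(" ", "-") is harmless but redundant, since the re.sub('[^A-Za-z0-9]+', '-', ...) that follows already collapses spaces into hyphens. More importantly, the new warm-up call runs with whatever the initial_prompt state holds, and the gr.State backing it (introduced later in this diff) is None until the user clicks "Set Prompt". A hedged guard, assuming the warm-up is meant to be optional:

# Sketch (not in the commit): skip the warm-up turn when no prompt is set,
# since the backing gr.State() starts out as None.
if initial_prompt:
    qa_chain({"question": initial_prompt})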
@@ -147,10 +126,10 @@ def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progr
     progress(0.9, desc="Done!")
     return vector_db, collection_name, "Complete!"
 
-def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
     llm_name = list_llm[llm_option]
     print("llm_name: ", llm_name)
-    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
+    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, initial_prompt, progress)
     return qa_chain, "Complete!"
 
 def format_chat_history(message, chat_history):
@@ -162,7 +141,6 @@ def format_chat_history(message, chat_history):
 
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
-
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
     response_answer = response["answer"]
     if "Helpful Answer:" in response_answer:
@@ -178,6 +156,13 @@ def conversation(qa_chain, message, history):
     new_history = history + [(message, response_answer)]
     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
+def conversation_no_doc(llm, message, history):
+    formatted_chat_history = format_chat_history(message, history)
+    response = llm({"question": message, "chat_history": formatted_chat_history})
+    response_answer = response["answer"]
+    new_history = history + [(message, response_answer)]
+    return llm, gr.update(value=""), new_history
+
 def upload_file(file_obj):
     list_file_path = []
     for file in file_obj:
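The new conversation_no_doc treats its llm argument like a chain: it calls it with a dict and reads response["answer"]. A bare HuggingFaceEndpoint expects a string prompt, and nothing in this commit ever assigns the llm_no_doc state wired to this handler (see the event section below), so the first click on the no-document tab would pass None in. One way to back that tab, sketched under the assumption that a ConversationChain is acceptable; its keys are "input" and "response", so the handler must match:

import gradio as gr

# Hypothetical handler aligned with ConversationChain's "input"/"response"
# keys; the chain's own memory tracks history, so none is passed explicitly.
def conversation_no_doc(chain, message, history):
    response = chain({"input": message})
    new_history = history + [(message, response["response"])]
    return chain, gr.update(value=""), new_history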
@@ -189,7 +174,9 @@ def demo():
         vector_db = gr.State()
         qa_chain = gr.State()
         collection_name = gr.State()
-
+        initial_prompt = gr.State()
+        llm_no_doc = gr.State()
+
         gr.Markdown(
             """<center><h2>PDF-based chatbot</center></h2>
             <h3>Ask any questions about your PDF documents</h3>""")
@@ -197,7 +184,7 @@ def demo():
         """<b>Note:</b> Esta é a lucIAna, primeira Versão da IA para seus PDF documentos.
         Este chatbot leva em consideração perguntas anteriores ao gerar respostas (por meio de memória conversacional) e inclui referências a documentos para fins de clareza.
         """)
-
+
         with gr.Tab("Step 1 - Upload PDF"):
             with gr.Row():
                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
@@ -215,7 +202,13 @@ def demo():
             with gr.Row():
                 db_btn = gr.Button("Generate vector database")
 
-        with gr.Tab("Step 3 - Initialize QA chain"):
+        with gr.Tab("Step 3 - Set Initial Prompt"):
+            with gr.Row():
+                prompt_input = gr.Textbox(label="Initial Prompt", lines=5, value="Você é um advogado sênior, onde seu papel é analisar e trazer as informações sem inventar, dando a sua melhor opinião sempre trazendo contexto e referência. Aprenda o que é jurisprudência.")
+            with gr.Row():
+                set_prompt_btn = gr.Button("Set Prompt")
+
+        with gr.Tab("Step 4 - Initialize QA chain"):
             with gr.Row():
                 llm_btn = gr.Radio(list_llm_simple,
                     label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model")
@@ -231,7 +224,7 @@ def demo():
             with gr.Row():
                 qachain_btn = gr.Button("Initialize Question Answering chain")
 
-        with gr.Tab("Step 4 - Chatbot"):
+        with gr.Tab("Step 5 - Chatbot with document"):
             chatbot = gr.Chatbot(height=300)
             with gr.Accordion("Advanced - Document references", open=False):
                 with gr.Row():
@@ -248,19 +241,30 @@ def demo():
             with gr.Row():
                 submit_btn = gr.Button("Submit message")
                 clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
-
+
+        with gr.Tab("Step 6 - Chatbot without document"):
+            chatbot_no_doc = gr.Chatbot(height=300)
+            with gr.Row():
+                msg_no_doc = gr.Textbox(placeholder="Type message to chat with lucIAna", container=True)
+            with gr.Row():
+                submit_btn_no_doc = gr.Button("Submit message")
+                clear_btn_no_doc = gr.ClearButton([msg_no_doc, chatbot_no_doc], value="Clear conversation")
+
         # Preprocessing events
         db_btn.click(initialize_database,
                      inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
                      outputs=[vector_db, collection_name, db_progress])
+        set_prompt_btn.click(lambda prompt: prompt,
+                             inputs=prompt_input,
+                             outputs=initial_prompt)
         qachain_btn.click(initialize_LLM,
-                          inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
+                          inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
                           outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
                           inputs=None,
                           outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
                           queue=False)
 
-        # Chatbot events
+        # Chatbot events with document
         msg.submit(conversation,
                    inputs=[qa_chain, msg, chatbot],
                    outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
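The set_prompt_btn.click(lambda prompt: prompt, ...) wiring added above is the standard Gradio idiom for snapshotting a Textbox value into a gr.State. A self-contained example of the pattern:

import gradio as gr

with gr.Blocks() as example:
    stored = gr.State()                 # holds the snapshot between events
    box = gr.Textbox(label="Prompt")
    btn = gr.Button("Set")
    # Clicking copies the textbox's current value into the State object.
    btn.click(lambda value: value, inputs=box, outputs=stored)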
@@ -273,7 +277,18 @@ def demo():
                    inputs=None,
                    outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
                    queue=False)
-    demo.queue().launch(debug=True)
+
+        # Chatbot events without document
+        submit_btn_no_doc.click(conversation_no_doc,
+                                inputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
+                                outputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
+                                queue=False)
+        clear_btn_no_doc.click(lambda:[None,""],
+                               inputs=None,
+                               outputs=[chatbot_no_doc, msg_no_doc],
+                               queue=False)
+
+    demo.queue().launch(debug=True, share=True)
 
 if __name__ == "__main__":
-    demo()
+    demo()
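On the launch change: share=True creates a public tunnel link for local runs, but on a hosted Space, Gradio ignores the flag (the app is already public) and just logs a warning. If both environments matter, a conditional keeps the intent explicit; SPACE_ID is an environment variable set by the Spaces runtime:

import os

# Sketch: request a share link only outside a hosted Space.
on_space = os.getenv("SPACE_ID") is not None
demo.queue().launch(debug=True, share=not on_space)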