Fecalisboa committed
Commit fa8e56a · verified · 1 Parent(s): 064a4fe

Update app.py

Files changed (1)
  1. app.py +31 -138
app.py CHANGED
@@ -14,51 +14,13 @@ from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
 from langchain_community.llms import HuggingFaceEndpoint
-from huggingface_hub import InferenceClient
 import torch
-
 api_token = os.getenv("HF_TOKEN")

-client = InferenceClient(
-    "mistralai/Mistral-7B-Instruct-v0.3"
-)
-
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]

-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
-
-def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
-
-    formatted_prompt = format_prompt(prompt, history)
-
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-        yield output
-    return output
-
+# Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
@@ -68,6 +30,7 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits

+# Create vector database
 def create_db(splits, collection_name, db_type):
     embedding = HuggingFaceEmbeddings()

@@ -100,8 +63,10 @@ def create_db(splits, collection_name, db_type):

     return vectordb

+# Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
+
     progress(0.5, desc="Initializing HF Hub...")

     llm = HuggingFaceEndpoint(
@@ -177,30 +142,12 @@ def conversation(qa_chain, message, history):
     new_history = history + [(message, response_answer)]
     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page

-def conversation_no_doc(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
-
-    formatted_prompt = format_prompt(prompt, history)
-
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-        yield output
-    return output
+def conversation_no_doc(llm, message, history):
+    formatted_chat_history = format_chat_history(message, history)
+    response = llm({"question": message, "chat_history": formatted_chat_history})
+    response_answer = response["answer"]
+    new_history = history + [(message, response_answer)]
+    return llm, gr.update(value=""), new_history

 def upload_file(file_obj):
     list_file_path = []
@@ -208,33 +155,6 @@ def upload_file(file_obj):
     list_file_path.append(file.name)
     return list_file_path

-def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
-    list_file_path = [x.name for x in list_file_obj if x is not None]
-    progress(0.1, desc="Creating collection name...")
-    collection_name = create_collection_name(list_file_path[0])
-    progress(0.25, desc="Loading document...")
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-    progress(0.5, desc="Generating vector database...")
-    vector_db = create_db(doc_splits, collection_name, db_type)
-    progress(0.9, desc="Done!")
-    return vector_db, collection_name, "Complete!"
-
-def create_collection_name(filepath):
-    collection_name = Path(filepath).stem
-    collection_name = collection_name.replace(" ", "-")
-    collection_name = unidecode(collection_name)
-    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
-    collection_name = collection_name[:50]
-    if len(collection_name) < 3:
-        collection_name = collection_name + 'xyz'
-    if not collection_name[0].isalnum():
-        collection_name = 'A' + collection_name[1:]
-    if not collection_name[-1].isalnum():
-        collection_name = collection_name[:-1] + 'Z'
-    print('Filepath: ', filepath)
-    print('Collection name: ', collection_name)
-    return collection_name
-
 def demo():
     with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
@@ -309,58 +229,27 @@ def demo():
             clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

         with gr.Tab("Step 6 - Chatbot without document"):
+            with gr.Row():
+                llm_no_doc_btn = gr.Radio(list_llm_simple,
+                    label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model for chatbot without document")
+            with gr.Accordion("Advanced options - LLM model", open=False):
+                with gr.Row():
+                    slider_temperature_no_doc = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+                with gr.Row():
+                    slider_maxtokens_no_doc = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
+                with gr.Row():
+                    slider_topk_no_doc = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+            with gr.Row():
+                llm_no_doc_progress = gr.Textbox(value="None", label="LLM initialization for chatbot without document")
+            with gr.Row():
+                llm_no_doc_init_btn = gr.Button("Initialize LLM for Chatbot without document")
             chatbot_no_doc = gr.Chatbot(height=300)
-            additional_inputs=[
-                gr.Slider(
-                    label="Temperature",
-                    value=0.9,
-                    minimum=0.0,
-                    maximum=1.0,
-                    step=0.05,
-                    interactive=True,
-                    info="Higher values produce more diverse outputs",
-                ),
-                gr.Slider(
-                    label="Max new tokens",
-                    value=256,
-                    minimum=0,
-                    maximum=1048,
-                    step=64,
-                    interactive=True,
-                    info="The maximum numbers of new tokens",
-                ),
-                gr.Slider(
-                    label="Top-p (nucleus sampling)",
-                    value=0.90,
-                    minimum=0.0,
-                    maximum=1,
-                    step=0.05,
-                    interactive=True,
-                    info="Higher values sample more low-probability tokens",
-                ),
-                gr.Slider(
-                    label="Repetition penalty",
-                    value=1.2,
-                    minimum=1.0,
-                    maximum=2.0,
-                    step=0.05,
-                    interactive=True,
-                    info="Penalize repeated tokens",
-                )
-            ]
             with gr.Row():
                 msg_no_doc = gr.Textbox(placeholder="Type message to chat with lucIAna", container=True)
             with gr.Row():
                 submit_btn_no_doc = gr.Button("Submit message")
                 clear_btn_no_doc = gr.ClearButton([msg_no_doc, chatbot_no_doc], value="Clear conversation")

-            chat_interface = gr.ChatInterface(
-                fn=generate,
-                chatbot=chatbot_no_doc,
-                additional_inputs=additional_inputs,
-                title="Mistral 7B v0.3"
-            )
-
         # Preprocessing events
         db_btn.click(initialize_database,
             inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
@@ -368,7 +257,7 @@ def demo():
         set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
             inputs=prompt_input,
             outputs=initial_prompt)
-        qachain_btn.click(initialize_llmchain,
+        qachain_btn.click(initialize_LLM,
             inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
             outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
             inputs=None,
@@ -390,9 +279,13 @@ def demo():
             queue=False)

         # Initialize LLM without document for conversation
+        llm_no_doc_init_btn.click(initialize_llm_no_doc,
+            inputs=[llm_no_doc_btn, slider_temperature_no_doc, slider_maxtokens_no_doc, slider_topk_no_doc, initial_prompt],
+            outputs=[llm_no_doc, llm_no_doc_progress])
+
         submit_btn_no_doc.click(conversation_no_doc,
-            inputs=[msg_no_doc, chatbot_no_doc],
-            outputs=[msg_no_doc, chatbot_no_doc],
+            inputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
+            outputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
             queue=False)
         clear_btn_no_doc.click(lambda:[None,""],
             inputs=None,
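
The new wiring references initialize_llm_no_doc, format_chat_history, and a llm_no_doc state that sit outside the hunks shown in this commit. Below is a minimal sketch of how those pieces could fit together, assuming the no-document tab mirrors initialize_llmchain by wrapping a HuggingFaceEndpoint model in a LangChain LLMChain whose output key matches the response["answer"] lookup in conversation_no_doc. The helper bodies, prompt template, and defaults here are illustrative only (not the committed implementation), and they rely on app.py's existing module-level gr, api_token, and list_llm.

# Hypothetical sketch (not part of this commit): plausible shape of the
# no-document chatbot helpers referenced by the event wiring above.
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

def format_chat_history(message, chat_history):
    # Flatten Gradio's [(user, bot), ...] history into plain strings.
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

def initialize_llm_no_doc(llm_option, temperature, max_tokens, top_k, initial_prompt, progress=gr.Progress()):
    # llm_option arrives as an index because the gr.Radio uses type="index".
    llm_model = list_llm[llm_option]
    progress(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    prompt = PromptTemplate(
        input_variables=["question", "chat_history"],
        template=initial_prompt + "\n\nChat history:\n{chat_history}\n\nQuestion: {question}\nAnswer:",
    )
    # output_key="answer" so conversation_no_doc can read response["answer"].
    llm_chain = LLMChain(llm=llm, prompt=prompt, output_key="answer")
    progress(0.9, desc="Done!")
    return llm_chain, "Complete!"

For the outputs=[llm_no_doc, ...] bindings to work, a matching llm_no_doc = gr.State() would also have to be declared next to vector_db inside demo(); that declaration is likewise not visible in the hunks above.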