Robin Genolet committed
Commit 6ed9cc0 · Parent(s): 9b52308

feat: specify params

Files changed (2):
  1. app.py +67 -118
  2. utils/epfl_meditron_utils.py +22 -28
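
In short: the sidebar's GGUF-era inputs (repo, file name, model type, GPU layers) give way to a GPTQ model selector fed by the new gptq_model_options(), plus explicit generation parameters (temperature, do_sample, top_p, top_k, max_new_tokens, repetition_penalty) stored in session state. restart_session() and its triage-prompt state are removed, and get_llm_response() now receives the model name, the generation parameters, and a pre-formatted prompt (built with the new get_prompt_format()/format_prompt() helpers) instead of hard-coding them.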
app.py CHANGED
@@ -11,7 +11,7 @@ import sys
 import io
 
 from utils.default_values import get_system_prompt, get_guidelines_dict
-from utils.epfl_meditron_utils import get_llm_response
+from utils.epfl_meditron_utils import get_llm_response, gptq_model_options
 from utils.openai_utils import get_available_engines, get_search_query_type_options
 
 from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
@@ -83,19 +83,18 @@ def display_streamlit_sidebar():
     st.sidebar.write('**Parameters**')
     form = st.sidebar.form("config_form", clear_on_submit=True)
 
-    model_option = form.selectbox("Quickly select a model", ("llama", "meditron"))
-    model_repo_id = form.text_input(label="Repo", value=model_option)  # value=st.session_state["model_repo_id"]
-    model_filename = form.text_input(label="File name", value=st.session_state["model_filename"])
-    model_type = form.text_input(label="Model type", value=st.session_state["model_type"])
-    gpu_layers = form.slider('GPU Layers', min_value=0,
-                             max_value=100, value=st.session_state['gpu_layers'], step=1)
-
-    system_prompt = ""
-    # form.text_area(label='System prompt',
-    #                value=st.session_state["system_prompt"])
+    model_name_or_path = form.selectbox("Select model", gptq_model_options())
+
+    temperature = form.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.01, value=0.01)
+    do_sample = form.checkbox('do_sample')
+    top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01, value=0.95)
+    top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1, value=40)
+    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=512, step=1, value=32)
+    repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=1.0, step=0.01, value=0.95)
 
     temperature = form.slider('Temperature (0 = deterministic, 1 = more freedom)', min_value=0.0,
                               max_value=1.0, value=st.session_state['temperature'], step=0.1)
+
     top_p = form.slider('top_p (0 = focused, 1 = broader answer range)', min_value=0.0,
                         max_value=1.0, value=st.session_state['top_p'], step=0.1)
@@ -104,22 +103,16 @@ def display_streamlit_sidebar():
     submitted = form.form_submit_button("Start session")
     if submitted and not st.session_state['session_started']:
         print('Parameters updated...')
-        restart_session()
         st.session_state['session_started'] = True
 
-        st.session_state["model_repo_id"] = model_repo_id
-        st.session_state["model_filename"] = model_filename
-        st.session_state["model_type"] = model_type
-        st.session_state['gpu_layers'] = gpu_layers
-
-        st.session_state["questions"] = []
-        st.session_state["lead_symptom"] = None
-        st.session_state["scenario_name"] = None
-        st.session_state["system_prompt"] = system_prompt
-        st.session_state['session_started'] = True
-        st.session_state["session_started"] = True
+        st.session_state["session_events"] = []
+        st.session_state["model_name_or_path"] = model_name_or_path
         st.session_state["temperature"] = temperature
+        st.session_state["do_sample"] = do_sample
         st.session_state["top_p"] = top_p
+        st.session_state["top_k"] = top_k
+        st.session_state["max_new_tokens"] = max_new_tokens
+        st.session_state["repetition_penalty"] = repetition_penalty
 
         st.rerun()
@@ -190,96 +183,20 @@ def get_chat_history_string(chat_history):
         raise Exception('Unknown role: ' + str(i["role"]))
 
     return res
-
-
-def restart_session():
-    print("Resetting params...")
-    st.session_state["emg_class_enabled"] = False
-    st.session_state["enable_llm_summary"] = False
-    st.session_state["num_variants"] = 3
-    st.session_state["lang_index"] = 0
-    st.session_state["llm_message"] = ""
-    st.session_state["llm_messages"] = []
-
-    st.session_state["triage_prompt_variants"] = ['''You are a telemedicine triage agent that decides between the following:
-Emergency: Patient health is at risk if he doesn't speak to a Doctor urgently
-Telecare: Patient can likely be treated remotely
-General Practitioner: Patient should visit a GP for an ad-real consultation''',
-
-'''You are a Doctor assistant that decides if a medical case can likely be treated remotely by a Doctor or not.
-The remote Doctor can write prescriptions and request the patient to provide a picture.
-Provide the triage recommendation first, and then explain your reasoning respecting the format given below:
-Treat remotely: <your reasoning>
-Treat ad-real: <your reasoning>''',
-
-'''You are a medical triage agent working for the telemedicine Company Medgate based in Switzerland.
-You decide if a case can be treated remotely or not, knowing that the remote Doctor can write prescriptions and request pictures.
-Provide the triage recommendation first, and then explain your reasoning respecting the format given below:
-Treat remotely: <your reasoning>
-Treat ad-real: <your reasoning>''']
-
-    st.session_state['nbqs'] = []
-    st.session_state['citations'] = {}
-
-    st.session_state['past_messages'] = []
-    st.session_state["last_request"] = None
-    st.session_state["last_proposal"] = None
-
-    st.session_state['doctor_question'] = ''
-    st.session_state['patient_reply'] = ''
-
-    st.session_state['chat_history_array'] = []
-    st.session_state['chat_history'] = ''
-
-    st.session_state['feed_summary'] = ''
-    st.session_state['summary'] = ''
-
-    st.session_state["selected_guidelines"] = ["General"]
-    st.session_state["guidelines_dict"] = get_guidelines_dict()
-
-    st.session_state["triage_recommendation"] = ''
-
-    st.session_state["session_events"] = []
 
 
 def init_session_state():
     print('init_session_state()')
     st.session_state['session_started'] = False
-    st.session_state['guidelines_ignored'] = False
-    st.session_state['model_index'] = 1
-
-    st.session_state["model_repo_id"] = "TheBloke/meditron-7B-GGUF"
-    st.session_state["model_filename"] = "meditron-7b.Q5_K_S.gguf"
-    st.session_state["model_type"] = "llama"
-    st.session_state['gpu_layers'] = 1
-
-    default_gender_index = 0
-    st.session_state['gender'] = get_genders()[default_gender_index]
-    st.session_state['gender_index'] = default_gender_index
-
-    st.session_state['age'] = 30
-
-    st.session_state['patient_medical_info'] = ''
-
-    default_search_query = 0
-    st.session_state['search_query_type'] = get_search_query_type_options()[default_search_query]
-    st.session_state['search_query_type_index'] = default_search_query
-    st.session_state['engine'] = get_available_engines()[0]
-    st.session_state['temperature'] = 0.0
-    st.session_state['top_p'] = 1.0
-    st.session_state['feed_chat_transcript'] = ''
-
-    st.session_state["llm_model"] = True
-    st.session_state["hugging_face_models"] = True
-    st.session_state["local_models"] = True
-
-    restart_session()
-
-    st.session_state['system_prompt'] = get_system_prompt()
-    st.session_state['system_prompt_after_on_change'] = get_system_prompt()
-
-    st.session_state["summary"] = ''
+    st.session_state["session_events"] = []
+    st.session_state["model_name_or_path"] = "TheBloke/meditron-7B-GPTQ"
+    st.session_state["temperature"] = 0.01
+    st.session_state["do_sample"] = True
+    st.session_state["top_p"] = 0.95
+    st.session_state["top_k"] = 40
+    st.session_state["max_new_tokens"] = 512
+    st.session_state["repetition_penalty"] = 1.1
+    st.session_state["system_message"] = "You are a medical expert that provides answers for a medically trained audience"
 
 
 def get_genders():
     return ['Male', 'Female']
@@ -498,23 +415,55 @@ def get_diarized_f_path(audio_f_name):
     return DATA_FOLDER + base_name + ".txt"
 
 
+def get_prompt_format(model_name):
+    if model_name == "TheBloke/Llama-2-13B-chat-GPTQ":
+        return '''[INST] <<SYS>>
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
+<</SYS>>
+{prompt}[/INST]
+
+'''
+    if model_name == "TheBloke/Llama-2-7B-Chat-GPTQ":
+        return "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{prompt}[/INST]"
+
+    if model_name == "TheBloke/meditron-7B-GPTQ" or model_name == "TheBloke/meditron-70B-GPTQ":
+        return '''<|im_start|>system
+{system_message}<|im_end|>
+<|im_start|>user
+{prompt}<|im_end|>
+<|im_start|>assistant'''
+
+    return ""
+
+
+def format_prompt(template, system_message, prompt):
+    if template == "":
+        return f"{system_message} {prompt}"
+    return template.format(system_message=system_message, prompt=prompt)
+
+
 def display_llm_output():
     st.header("LLM")
 
     form = st.form('llm')
 
-    llm_message = form.text_area('Message', value=st.session_state["llm_message"])
+    prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
+    prompt_format = form.text_area('Prompt format', value=prompt_format_str)
+    system_prompt = form.text_area('System prompt', value=st.session_state["system_message"])
+    prompt = form.text_area('Prompt', value=st.session_state.get("prompt", ""))
 
-    api_submitted = form.form_submit_button('Submit')
+    submitted = form.form_submit_button('Submit')
 
-    if api_submitted:
+    if submitted:
+        formatted_prompt = format_prompt(prompt_format, system_prompt, prompt)
+        print(f"Formatted prompt: {formatted_prompt}")
         llm_response = get_llm_response(
-            st.session_state["model_repo_id"],
-            st.session_state["model_filename"],
-            st.session_state["model_type"],
-            st.session_state["gpu_layers"],
-            "You are a medical assistant",
-            llm_message)
+            st.session_state["model_name_or_path"],
+            st.session_state["temperature"],
+            st.session_state["do_sample"],
+            st.session_state["top_p"],
+            st.session_state["top_k"],
+            st.session_state["max_new_tokens"],
+            st.session_state["repetition_penalty"],
+            formatted_prompt)
         st.write(llm_response)
         st.write('Done displaying LLM response')
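
A minimal sketch (not part of the commit; the example question is invented) showing how the two new helpers compose a meditron-style prompt:

# Sketch: exercises get_prompt_format()/format_prompt() as defined above.
template = get_prompt_format("TheBloke/meditron-7B-GPTQ")
formatted = format_prompt(
    template,
    system_message="You are a medical expert that provides answers for a medically trained audience",
    prompt="What are the red-flag symptoms of acute chest pain?",  # invented example
)
print(formatted)
# <|im_start|>system
# You are a medical expert that provides answers for a medically trained audience<|im_end|>
# <|im_start|>user
# What are the red-flag symptoms of acute chest pain?<|im_end|>
# <|im_start|>assistant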
 
utils/epfl_meditron_utils.py CHANGED
@@ -1,49 +1,43 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
-
-def get_llm_response(repo, filename, model_type, gpu_layers, system_message, prompt):
-
-    model_name_or_path = "TheBloke/meditron-7B-GPTQ"
-    # To use a different branch, change revision
-    # For example: revision="gptq-4bit-128g-actorder_True"
+def gptq_model_options():
+    return [
+        "TheBloke/Llama-2-7B-Chat-GPTQ",
+        "TheBloke/Llama-2-13B-chat-GPTQ",
+        "TheBloke/meditron-7B-GPTQ",
+        "TheBloke/meditron-70B-GPTQ",
+    ]
+
+def get_llm_response(model_name_or_path, temperature, do_sample, top_p, top_k, max_new_tokens, repetition_penalty, formatted_prompt):
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                  device_map="auto",
                                                  trust_remote_code=False,
                                                  revision="main")
 
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
-
-    prompt_template = f'''<|im_start|>system
-{system_message}<|im_end|>
-<|im_start|>user
-{prompt}<|im_end|>
-<|im_start|>assistant
-'''
-
-    print("Template:")
-    print(prompt_template)
+
+    print("Formatted prompt:")
+    print(formatted_prompt)
 
     print("\n\n*** Generate:")
 
-    input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
-    output = model.generate(inputs=input_ids, temperature=0.01, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
-    print(tokenizer.decode(output[0]))
+    input_ids = tokenizer(formatted_prompt, return_tensors='pt').input_ids.cuda()
+    output = model.generate(inputs=input_ids, temperature=temperature, do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=max_new_tokens)
+    print(tokenizer.decode(output[0], skip_special_tokens=True))
 
     print("*** Pipeline:")
     pipe = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_new_tokens=512,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.95,
-        top_k=40,
-        repetition_penalty=1.1
+        max_new_tokens=max_new_tokens,
+        do_sample=do_sample,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty
     )
 
-
-
-    response = pipe(prompt_template)[0]['generated_text']
+    response = pipe(formatted_prompt)[0]['generated_text']
     print(response)
     return response
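
For completeness, a hedged sketch of driving the reworked get_llm_response() with the defaults init_session_state() now sets. It is untested here and assumes a CUDA device plus a GPTQ-capable transformers/auto-gptq install, since the function moves the input ids with .cuda():

# Sketch: wiring the sidebar defaults through the new signature.
from utils.epfl_meditron_utils import get_llm_response, gptq_model_options

model_name = gptq_model_options()[2]  # "TheBloke/meditron-7B-GPTQ"
formatted_prompt = (
    "<|im_start|>system\n"
    "You are a medical expert that provides answers for a medically trained audience<|im_end|>\n"
    "<|im_start|>user\n"
    "Outline the first-line management of community-acquired pneumonia.<|im_end|>\n"  # invented question
    "<|im_start|>assistant"
)
response = get_llm_response(
    model_name,
    temperature=0.01,
    do_sample=True,
    top_p=0.95,
    top_k=40,
    max_new_tokens=512,
    repetition_penalty=1.1,
    formatted_prompt=formatted_prompt,
)
print(response)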