Commit 6ed9cc0 · Robin Genolet committed
Parent(s): 9b52308

feat: specify params

Files changed:
- app.py (+67 -118)
- utils/epfl_meditron_utils.py (+22 -28)
app.py
CHANGED
@@ -11,7 +11,7 @@ import sys
 import io
 
 from utils.default_values import get_system_prompt, get_guidelines_dict
-from utils.epfl_meditron_utils import get_llm_response
+from utils.epfl_meditron_utils import get_llm_response, gptq_model_options
 from utils.openai_utils import get_available_engines, get_search_query_type_options
 
 from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
@@ -83,19 +83,18 @@ def display_streamlit_sidebar():
     st.sidebar.write('**Parameters**')
     form = st.sidebar.form("config_form", clear_on_submit=True)
 
-
-
-
-
-
-
-
-
-    #form.text_area(label='System prompt',
-    #               value=st.session_state["system_prompt"])
+    model_name_or_path = form.selectbox("Select model", gptq_model_options())
+
+    temperature = form.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.01, value=0.01)
+    do_sample = form.checkbox('do_sample')
+    top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01, value=0.95)
+    top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1, value=40)
+    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=512, step=1, value=32)
+    repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=1.0, step=0.01, value=0.95)
 
     temperature = form.slider('Temperature (0 = deterministic, 1 = more freedom)', min_value=0.0,
                               max_value=1.0, value=st.session_state['temperature'], step=0.1)
+
     top_p = form.slider('top_p (0 = focused, 1 = broader answer range)', min_value=0.0,
                         max_value=1.0, value=st.session_state['top_p'], step=0.1)
 
@@ -104,22 +103,16 @@ def display_streamlit_sidebar():
     submitted = form.form_submit_button("Start session")
     if submitted and not st.session_state['session_started']:
         print('Parameters updated...')
-        restart_session()
         st.session_state['session_started'] = True
 
-        st.session_state["
-        st.session_state["
-        st.session_state["model_type"] = model_type
-        st.session_state['gpu_layers'] = gpu_layers
-
-        st.session_state["questions"] = []
-        st.session_state["lead_symptom"] = None
-        st.session_state["scenario_name"] = None
-        st.session_state["system_prompt"] = system_prompt
-        st.session_state['session_started'] = True
-        st.session_state["session_started"] = True
+        st.session_state["session_events"] = []
+        st.session_state["model_name_or_path"] = model_name_or_path
         st.session_state["temperature"] = temperature
+        st.session_state["do_sample"] = do_sample
         st.session_state["top_p"] = top_p
+        st.session_state["top_k"] = top_k
+        st.session_state["max_new_tokens"] = max_new_tokens
+        st.session_state["repetition_penalty"] = repetition_penalty
 
         st.rerun()
 
@@ -190,96 +183,20 @@ def get_chat_history_string(chat_history):
         raise Exception('Unknown role: ' + str(i["role"]))
 
     return res
-
-
-def restart_session():
-    print("Resetting params...")
-    st.session_state["emg_class_enabled"] = False
-    st.session_state["enable_llm_summary"] = False
-    st.session_state["num_variants"] = 3
-    st.session_state["lang_index"] = 0
-    st.session_state["llm_message"] = ""
-    st.session_state["llm_messages"] = []
-
-    st.session_state["triage_prompt_variants"] = ['''You are a telemedicine triage agent that decides between the following:
-Emergency: Patient health is at risk if he doesn't speak to a Doctor urgently
-Telecare: Patient can likely be treated remotely
-General Practitioner: Patient should visit a GP for an ad-real consultation''',
-
-'''You are a Doctor assistant that decides if a medical case can likely be treated remotely by a Doctor or not.
-The remote Doctor can write prescriptions and request the patient to provide a picture.
-Provide the triage recommendation first, and then explain your reasoning respecting the format given below:
-Treat remotely: <your reasoning>
-Treat ad-real: <your reasoning>''',
-
-'''You are a medical triage agent working for the telemedicine Company Medgate based in Switzerland.
-You decide if a case can be treated remotely or not, knowing that the remote Doctor can write prescriptions and request pictures.
-Provide the triage recommendation first, and then explain your reasoning respecting the format given below:
-Treat remotely: <your reasoning>
-Treat ad-real: <your reasoning>''']
-
-    st.session_state['nbqs'] = []
-    st.session_state['citations'] = {}
-
-    st.session_state['past_messages'] = []
-    st.session_state["last_request"] = None
-    st.session_state["last_proposal"] = None
-
-    st.session_state['doctor_question'] = ''
-    st.session_state['patient_reply'] = ''
-
-    st.session_state['chat_history_array'] = []
-    st.session_state['chat_history'] = ''
-
-    st.session_state['feed_summary'] = ''
-    st.session_state['summary'] = ''
-
-    st.session_state["selected_guidelines"] = ["General"]
-    st.session_state["guidelines_dict"] = get_guidelines_dict()
-
-    st.session_state["triage_recommendation"] = ''
-
-    st.session_state["session_events"] = []
-
+
 
 def init_session_state():
     print('init_session_state()')
     st.session_state['session_started'] = False
-    st.session_state[
-    st.session_state[
-
-    st.session_state["
-    st.session_state["
-    st.session_state["
-    st.session_state[
-
-
-    st.session_state['gender'] = get_genders()[default_gender_index]
-    st.session_state['gender_index'] = default_gender_index
-
-    st.session_state['age'] = 30
-
-    st.session_state['patient_medical_info'] = ''
-
-    default_search_query = 0
-    st.session_state['search_query_type'] = get_search_query_type_options()[default_search_query]
-    st.session_state['search_query_type_index'] = default_search_query
-    st.session_state['engine'] = get_available_engines()[0]
-    st.session_state['temperature'] = 0.0
-    st.session_state['top_p'] = 1.0
-    st.session_state['feed_chat_transcript'] = ''
-
-    st.session_state["llm_model"] = True
-    st.session_state["hugging_face_models"] = True
-    st.session_state["local_models"] = True
-
-    restart_session()
-
-    st.session_state['system_prompt'] = get_system_prompt()
-    st.session_state['system_prompt_after_on_change'] = get_system_prompt()
-
-    st.session_state["summary"] = ''
-
+    st.session_state["session_events"] = []
+    st.session_state["model_name_or_path"] = "TheBloke/meditron-7B-GPTQ"
+    st.session_state["temperature"] = 0.01
+    st.session_state["do_sample"] = True
+    st.session_state["top_p"] = 0.95
+    st.session_state["top_k"] = 40
+    st.session_state["max_new_tokens"] = 512
+    st.session_state["repetition_penalty"] = 1.1
+    st.session_state["system_message"] = "You are a medical expert that provides answers for a medically trained audience"
 
 def get_genders():
     return ['Male', 'Female']
@@ -498,23 +415,55 @@ def get_diarized_f_path(audio_f_name):
     return DATA_FOLDER + base_name + ".txt"
 
 
+def get_prompt_format(model_name):
+    if model_name == "TheBloke/Llama-2-13B-chat-GPTQ":
+        return '''[INST] <<SYS>>
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
+<</SYS>>
+{prompt}[/INST]
+
+'''
+    if model_name == "TheBloke/Llama-2-7B-Chat-GPTQ":
+        return "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{prompt}[/INST]"
+
+    if model_name == "TheBloke/meditron-7B-GPTQ" or model_name == "TheBloke/meditron-70B-GPTQ":
+        return '''<|im_start|>system
+{system_message}<|im_end|>
+<|im_start|>user
+{prompt}<|im_end|>
+<|im_start|>assistant'''
+
+    return ""
+
+
+def format_prompt(template, system_message, prompt):
+    if template == "":
+        return f"{system_message} {prompt}"
+    return template.format(system_message=system_message, prompt=prompt)
+
+
 def display_llm_output():
     st.header("LLM")
 
     form = st.form('llm')
 
-
+    prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
+    prompt_format = form.text_area('Prompt format', value=prompt_format_str)
+    system_prompt = form.text_area('System prompt', value=st.session_state["system_prompt"])
+    prompt = form.text_area('Prompt', value=st.session_state["prompt"])
 
-
+    submitted = form.form_submit_button('Submit')
 
-    if
+    if submitted:
+        formatted_prompt = format_prompt(prompt_format, system_prompt, prompt)
+        print(f"Formatted prompt: {format_prompt}")
     llm_response = get_llm_response(
-        st.session_state["
-        st.session_state["
-        st.session_state["
-        st.session_state["
-        "
-
+        st.session_state["model_name"],
+        st.session_state["temperature"],
+        st.session_state["do_sample"],
+        st.session_state["top_p"],
+        st.session_state["top_k"],
+        st.session_state["max_new_tokens"],
+        st.session_state["repetition_penalty"],
+        formatted_prompt)
     st.write(llm_response)
     st.write('Done displaying LLM response')
 
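For reference, a minimal sketch of how the two helpers added to app.py compose a prompt. It assumes get_prompt_format() and format_prompt() from the diff above are in scope (importing the full Streamlit app would execute it), and the question text is illustrative, not from the commit.

# Sketch: build a ChatML-style prompt for the meditron GPTQ checkpoints.
# Assumes get_prompt_format() and format_prompt() from app.py above are in scope;
# the system message matches the default seeded by init_session_state().
template = get_prompt_format("TheBloke/meditron-7B-GPTQ")
formatted = format_prompt(
    template,
    system_message="You are a medical expert that provides answers for a medically trained audience",
    prompt="What are the red-flag symptoms for acute chest pain?",  # illustrative prompt
)
print(formatted)
# <|im_start|>system
# You are a medical expert that provides answers for a medically trained audience<|im_end|>
# <|im_start|>user
# What are the red-flag symptoms for acute chest pain?<|im_end|>
# <|im_start|>assistant

An unrecognised model name falls through to the empty template, in which case format_prompt() simply concatenates the system message and the prompt.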
utils/epfl_meditron_utils.py
CHANGED
@@ -1,49 +1,43 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
-
-
-
-
-
-
+def gptq_model_options():
+    return [
+        "TheBloke/Llama-2-7B-Chat-GPTQ",
+        "TheBloke/Llama-2-13B-chat-GPTQ",
+        "TheBloke/meditron-7B-GPTQ",
+        "TheBloke/meditron-70B-GPTQ",
+    ]
+
+def get_llm_response(model_name_or_path, temperature, do_sample, top_p, top_k, max_new_tokens, repetition_penalty, formatted_prompt):
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                  device_map="auto",
                                                  trust_remote_code=False,
                                                  revision="main")
 
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
-
-
-
-<|im_start|>user
-{prompt}<|im_end|>
-<|im_start|>assistant
-'''
-
-    print("Template:")
-    print(prompt_template)
+
+    print("Formatted prompt:")
+    print(formatted_prompt)
 
     print("\n\n*** Generate:")
 
-    input_ids = tokenizer(
-    output = model.generate(inputs=input_ids, temperature=
-    print(tokenizer.decode(output[0]))
+    input_ids = tokenizer(formatted_prompt, return_tensors='pt').input_ids.cuda()
+    output = model.generate(inputs=input_ids, temperature=temperature, do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=max_new_tokens)
+    print(tokenizer.decode(output[0], skip_special_tokens=True))
 
     print("*** Pipeline:")
     pipe = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_new_tokens=
-        do_sample=
-        temperature=
-        top_p=
-        top_k=
-        repetition_penalty=
+        max_new_tokens=max_new_tokens,
+        do_sample=do_sample,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_p,
+        repetition_penalty=repetition_penalty
     )
 
-
-
-    response = pipe(prompt_template)[0]['generated_text']
+    response = pipe(formatted_prompt)[0]['generated_text']
     print(response)
     return response
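A minimal usage sketch of the refactored helper, wired with the defaults that init_session_state() now seeds in app.py. The prompt string is illustrative, and a CUDA-capable GPU is assumed since the input ids are moved with .cuda().

# Sketch: call the utility directly with the sidebar defaults from app.py.
# The question below is illustrative; a GPU is assumed (input ids are sent to CUDA).
from utils.epfl_meditron_utils import gptq_model_options, get_llm_response

print(gptq_model_options())  # the four GPTQ checkpoints offered in the sidebar

response = get_llm_response(
    "TheBloke/meditron-7B-GPTQ",  # model_name_or_path
    0.01,                         # temperature
    True,                         # do_sample
    0.95,                         # top_p
    40,                           # top_k
    512,                          # max_new_tokens
    1.1,                          # repetition_penalty
    "<|im_start|>system\n"
    "You are a medical expert that provides answers for a medically trained audience<|im_end|>\n"
    "<|im_start|>user\n"
    "How should a suspected ankle sprain be managed?<|im_end|>\n"  # illustrative question
    "<|im_start|>assistant",
)
print(response)

Note that the helper runs both a direct model.generate() call and a transformers text-generation pipeline on the same prompt; only the pipeline's generated_text is returned.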