Robin Genolet committed
Commit 5fd44e9 · 1 Parent(s): dad4228

fix: params

Files changed (2)
  1. app.py +16 -90
  2. utils/epfl_meditron_utils.py +7 -3
app.py CHANGED
@@ -10,6 +10,8 @@ import subprocess
 import sys
 import io
 
+import inspect
+
 from utils.default_values import get_system_prompt, get_guidelines_dict
 from utils.epfl_meditron_utils import get_llm_response, gptq_model_options
 from utils.openai_utils import get_available_engines, get_search_query_type_options
@@ -17,73 +19,18 @@ from utils.openai_utils import get_available_engines, get_search_query_type_opti
 from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
 from sklearn.metrics import classification_report
 
-DATA_FOLDER = "data/"
-
 POC_VERSION = "0.1.0"
-MAX_QUESTIONS = 10
-AVAILABLE_LANGUAGES = ["DE", "EN", "FR"]
 
 st.set_page_config(page_title='Medgate Whisper PoC', page_icon='public/medgate.png')
 
-# Azure apparently truncates message if longer than 200, see
-MAX_SYSTEM_MESSAGE_TOKENS = 200
-
-
-def format_question(q):
-    res = q
-
-    # Remove numerical prefixes, if any, e.g. '1. [...]'
-    if re.match(r'^[0-9].\s', q):
-        res = res[3:]
-
-    # Replace doc reference by doc name
-    if len(st.session_state["citations"]) > 0:
-        for source_ref in re.findall(r'\[doc[0-9]+\]', res):
-            citation_number = int(re.findall(r'[0-9]+', source_ref)[0])
-            citation_index = citation_number - 1 if citation_number > 0 else 0
-            citation = st.session_state["citations"][citation_index]
-            source_title = citation["title"]
-            res = res.replace(source_ref, '[' + source_title + ']')
-
-    return res.strip()
-
-
-def get_text_from_row(text):
-    res = str(text)
-    if res == "nan":
-        return ""
-    return res
-def get_questions_from_df(df, lang, test_scenario_name):
-    questions = []
-    for i, row in df.iterrows():
-        questions.append({
-            "question": row[lang + ": Fragen"],
-            "answer": get_text_from_row(row[test_scenario_name]),
-            "question_id": uuid.uuid4()
-        })
-    return questions
-
-
-def get_questions(df, lead_symptom, lang, test_scenario_name):
-    print(str(st.session_state["lead_symptom"]) + " -> " + lead_symptom)
-    print(str(st.session_state["scenario_name"]) + " -> " + test_scenario_name)
-    if st.session_state["lead_symptom"] != lead_symptom or st.session_state["scenario_name"] != test_scenario_name:
-        st.session_state["lead_symptom"] = lead_symptom
-        st.session_state["scenario_name"] = test_scenario_name
-        symptom_col_name = st.session_state["language"] + ": Symptome"
-        df_questions = df[(df[symptom_col_name] == lead_symptom)]
-        st.session_state["questions"] = get_questions_from_df(df_questions, lang, test_scenario_name)
-
-    return st.session_state["questions"]
-
-
 def display_streamlit_sidebar():
     st.sidebar.title("Local LLM PoC " + str(POC_VERSION))
 
     st.sidebar.write('**Parameters**')
     form = st.sidebar.form("config_form", clear_on_submit=True)
 
-    model_name_or_path = form.selectbox("Select model", gptq_model_options())
+    model_name_or_path = form.selectbox("Select model", gptq_model_options(), value=st.session_state["model_index"])
+    model_name_or_path_other = form.text_input('Or input any GPTQ model', value=st.session_state["model_name_or_path_other"])
 
     temperature = form.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["temperature"])
    do_sample = form.checkbox('do_sample', value=st.session_state["do_sample"])
@@ -98,6 +45,15 @@ def display_streamlit_sidebar():
     st.session_state['session_started'] = True
 
     st.session_state["session_events"] = []
+
+    if len(model_name_or_path_other) > 0:
+        st.session_state["model_name"] = model_name_or_path_other
+        st.session_state["model_name_or_path_other"] = model_name_or_path_other
+    else:
+        st.session_state["model_name"] = model_name_or_path
+        st.session_state["model_index"] = gptq_model_options().index(model_name_or_path)
+
+
     st.session_state["model_name_or_path"] = model_name_or_path
     st.session_state["temperature"] = temperature
     st.session_state["do_sample"] = do_sample
@@ -123,10 +79,7 @@ def init_session_state():
     st.session_state["system_prompt"] = "You are a medical expert that provides answers for a medically trained audience"
     st.session_state["prompt"] = ""
     st.session_state["llm_messages"] = []
-
-def get_genders():
-    return ['Male', 'Female']
-
+
 
 def display_session_overview():
     st.subheader('History of LLM queries')
@@ -156,33 +109,6 @@ def display_session_overview():
     st.write("Total compute time (ms): " + str(total_time))
 
 
-def plot_report(title, expected, predicted, display_labels):
-    st.markdown('#### ' + title)
-    conf_matrix = confusion_matrix(expected, predicted, labels=display_labels)
-    conf_matrix_plot = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=display_labels)
-    conf_matrix_plot.plot()
-    st.pyplot(plt.gcf())
-
-    report = classification_report(expected, predicted, output_dict=True)
-    df_report = pd.DataFrame(report).transpose()
-    st.write(df_report)
-
-    df_rp = df_report
-    df_rp = df_rp.drop('support', axis=1)
-    df_rp = df_rp.drop(['accuracy', 'macro avg', 'weighted avg'])
-
-    try:
-        ax = df_rp.plot(kind="bar", legend=True)
-        for container in ax.containers:
-            ax.bar_label(container, fontsize=7)
-        plt.xticks(rotation=45)
-        plt.legend(loc=(1.04, 0))
-        st.pyplot(plt.gcf())
-    except Exception as e:
-        # Out of bounds
-        pass
-
-
 def get_prompt_format(model_name):
     formatted_text = ""
     if model_name == "TheBloke/Llama-2-13B-chat-GPTQ" or model_name== "TheBloke/Llama-2-7B-Chat-GPTQ":
@@ -202,7 +128,7 @@ def get_prompt_format(model_name):
 
     '''
 
-    return formatted_text.replace("\t", "")
+    return inspect.cleandoc(formatted_text)
 
 def format_prompt(template, system_message, prompt):
     if template == "":
@@ -227,7 +153,7 @@ def display_llm_output():
     formatted_prompt = format_prompt(prompt_format, system_prompt, prompt)
     print(f"Formatted prompt: {format_prompt}")
     llm_response = get_llm_response(
-        st.session_state["model_name_or_path"],
+        st.session_state["model_name"],
         st.session_state["temperature"],
         st.session_state["do_sample"],
         st.session_state["top_p"],
utils/epfl_meditron_utils.py CHANGED
@@ -1,4 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+import streamlit as st
 
 def gptq_model_options():
     return [
@@ -19,6 +20,8 @@ def get_llm_response(model_name_or_path, temperature, do_sample, top_p, top_k, m
     print("Formatted prompt:")
     print(formatted_prompt)
 
+    st.session_state["llm_messages"].append(formatted_prompt)
+
     #print("\n\n*** Generate:")
     #input_ids = tokenizer(formatted_prompt, return_tensors='pt').input_ids.cuda()
     #output = model.generate(inputs=input_ids, temperature=temperature, do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=max_new_tokens)
@@ -37,6 +40,7 @@ def get_llm_response(model_name_or_path, temperature, do_sample, top_p, top_k, m
         repetition_penalty=repetition_penalty
     )
 
-    response = pipe(formatted_prompt)[0]['generated_text']
-    print(response)
-    return response
+    pipe_response = pipe(formatted_prompt)
+    st.session_state["llm_messages"].append(pipe_response)
+    print(pipe_response)
+    return pipe_response[0]['generated_text']
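Taken together, this change makes get_llm_response log both the formatted prompt and the raw pipeline output to st.session_state["llm_messages"] before returning only the generated text. A minimal, self-contained sketch of that flow follows; the model/tokenizer loading is not shown in this diff, so the from_pretrained calls and device_map="auto" are assumptions, and the parameter values are whatever the sidebar passes in:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import streamlit as st

def get_llm_response_sketch(model_name_or_path, temperature, do_sample, top_p, top_k,
                            max_new_tokens, repetition_penalty, formatted_prompt):
    # Assumed loading step (not part of the diff); a GPTQ checkpoint needs the
    # matching quantization backend installed, and device_map="auto" needs accelerate.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto")

    # Keep a trace of what was sent to the model, as the commit does.
    st.session_state["llm_messages"].append(formatted_prompt)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

    # Log the raw pipeline output as well, then return only the generated text.
    pipe_response = pipe(formatted_prompt)
    st.session_state["llm_messages"].append(pipe_response)
    return pipe_response[0]['generated_text']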