Robin Genolet committed on
Commit 4f4f63f · 1 Parent(s): 2467547

fix: indent

Files changed (1)
  1. app.py +10 -167
app.py CHANGED
@@ -92,8 +92,6 @@ def display_streamlit_sidebar():
     max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=512, step=1, value=st.session_state["max_new_tokens"])
     repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=5.0, step=0.01, value=st.session_state["repetition_penalty"])

-    form.write('Best practice is to only modify temperature or top_p, not both')
-
     submitted = form.form_submit_button("Start session")
     if submitted and not st.session_state['session_started']:
         print('Parameters updated...')
@@ -122,7 +120,7 @@ def init_session_state():
     st.session_state["top_k"] = 40
     st.session_state["max_new_tokens"] = 512
     st.session_state["repetition_penalty"] = 1.1
-    st.session_state["system_message"] = "You are a medical expert that provides answers for a medically trained audience"
+    st.session_state["system_prompt"] = "You are a medical expert that provides answers for a medically trained audience"

 def get_genders():
     return ['Male', 'Female']
@@ -156,135 +154,6 @@ def display_session_overview():
     st.write("Total compute time (ms): " + str(total_time))


-def remove_question(question_id):
-    st.session_state["questions"] = [value for value in st.session_state["questions"] if
-                                     str(value["question_id"]) != str(question_id)]
-    st.rerun()
-
-
-def get_prompt_from_lead_symptom(df_config, df_prompt, lead_symptom, lang, fallback=True):
-    de_lead_symptom = lead_symptom
-
-    if lang != "DE":
-        df_lead_symptom = df_config[df_config[lang + ": Symptome"] == lead_symptom]
-        de_lead_symptom = df_lead_symptom["DE: Symptome"].iloc[0]
-        print("DE lead symptom: " + de_lead_symptom)
-
-    for i, row in df_prompt.iterrows():
-        if de_lead_symptom in row["Questionnaire"]:
-            return row["Prompt"]
-
-    warning_text = "No guidelines found for lead symptom " + lead_symptom + " (" + de_lead_symptom + ")"
-    if fallback:
-        st.toast(warning_text + ", using generic prompt", icon='🚨')
-        return st.session_state["system_prompt"]
-    st.toast(warning_text, icon='🚨')
-
-    return ""
-
-
-def get_scenarios(df):
-    return [v for v in df.columns.values if v.startswith('TLC') or v.startswith('GP')]
-
-
-def get_gender_age_from_test_scenario(test_scenario):
-    try:
-        result = re.search(r"([FM])(\d+)", test_scenario)
-        res_age = int(result.group(2))
-        gender = result.group(1)
-        res_gender = None
-        if gender == "M":
-            res_gender = "Male"
-        elif gender == "F":
-            res_gender = "Female"
-        else:
-            raise Exception('Unexpected gender')
-
-        return res_gender, res_age
-
-    except:
-        st.error("Unable to extract name, gender; using 30M as default")
-        return "Male", 30
-
-def get_freetext_to_reco(reco_freetext_cased, emg_class_enabled=False):
-    reco_freetext = ""
-    if reco_freetext_cased:
-        reco_freetext = reco_freetext_cased.lower()
-
-    if reco_freetext.startswith('treat remotely') or reco_freetext.startswith('telecare'):
-        return 'TELECARE'
-    if reco_freetext.startswith('treat ad-real') or reco_freetext.startswith('gp') \
-            or reco_freetext.startswith('general practitioner'):
-        return 'GP'
-    if reco_freetext.startswith('emergency') or reco_freetext.startswith('emg') \
-            or reco_freetext.startswith('urgent'):
-        if emg_class_enabled:
-            return 'EMERGENCY'
-        return 'GP'
-
-    if "gp" in reco_freetext or 'general practitioner' in reco_freetext \
-            or "nicht über tele" in reco_freetext or 'durch einen arzt erford' in reco_freetext \
-            or "persönliche untersuchung erfordert" in reco_freetext:
-        return 'GP'
-
-    if ("telecare" in reco_freetext or 'telemed' in reco_freetext or
-            'can be treated remotely' in reco_freetext):
-        return 'TELECARE'
-
-    if ('emergency' in reco_freetext or 'urgent' in reco_freetext or
-            'not be treated remotely' in reco_freetext or "nicht tele" in reco_freetext):
-        return 'GP'
-
-    warning_msg = 'Cannot extract reco from LLM text: ' + reco_freetext
-    st.toast(warning_msg)
-    print(warning_msg)
-    return 'TRIAGE_IMPOSSIBLE'
-
-
-def get_structured_reco(row, index, emg_class_enabled):
-    freetext_reco_col_name = "llm_reco_freetext_" + str(index)
-    freetext_reco = row[freetext_reco_col_name].lower()
-    return get_freetext_to_reco(freetext_reco, emg_class_enabled)
-
-
-def add_expected_dispo(row, emg_class_enabled):
-    disposition = row["disposition"]
-    if disposition == "GP" or disposition == "TELECARE":
-        return disposition
-    if disposition == "EMERGENCY":
-        if emg_class_enabled:
-            return "EMERGENCY"
-        return "GP"
-
-    raise Exception("Missing disposition for row " + str(row.name) + " with summary " + row["case_summary"])
-
-
-def get_test_scenarios(df):
-    res = []
-    for col in df.columns.values:
-        if str(col).startswith('GP') or str(col).startswith('TLC'):
-            res.append(col)
-    return res
-
-
-def get_transcript(df, test_scenario, lang):
-    transcript = ""
-    for i, row in df.iterrows():
-        transcript += "\nDoctor: " + row[lang + ": Fragen"]
-        transcript += ", Patient: " + str(row[test_scenario])
-    return transcript
-
-
-def get_expected_from_scenario(test_scenario):
-    reco = test_scenario.split('_')[0]
-    if reco == "GP":
-        return "GP"
-    elif reco == "TLC":
-        return "TELECARE"
-    else:
-        raise Exception('Unexpected reco: ' + reco)
-
-
 def plot_report(title, expected, predicted, display_labels):
     st.markdown('#### ' + title)
     conf_matrix = confusion_matrix(expected, predicted, labels=display_labels)
@@ -312,49 +181,23 @@ def plot_report(title, expected, predicted, display_labels):
     pass


-def get_complete_prompt(generic_prompt, guidelines_prompt):
-    complete_prompt = ""
-    if generic_prompt:
-        complete_prompt += generic_prompt
-
-    if generic_prompt and guidelines_prompt:
-        complete_prompt += ".\n\n"
-
-    if guidelines_prompt:
-        complete_prompt += guidelines_prompt
-
-    return complete_prompt
-
-
-def run_command(args):
-    """Run command, transfer stdout/stderr back into Streamlit and manage error"""
-    cmd = ' '.join(args)
-    result = subprocess.run(cmd, capture_output=True, text=True)
-    print(result)
-
-def get_diarized_f_path(audio_f_name):
-    # TODO p2: Quick hack, cleaner with os or regexes
-    base_name = audio_f_name.split('.')[0]
-    return DATA_FOLDER + base_name + ".txt"
-
-
 def get_prompt_format(model_name):
     if model_name == "TheBloke/Llama-2-13B-chat-GPTQ":
         return '''[INST] <<SYS>>
-        You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
-        <</SYS>>
-        {prompt}[/INST]
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
+<</SYS>>
+{prompt}[/INST]

-        '''
+'''
     if model_name == "TheBloke/Llama-2-7B-Chat-GPTQ":
         return "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{prompt}[/INST]"

     if model_name == "TheBloke/meditron-7B-GPTQ" or model_name == "TheBloke/meditron-70B-GPTQ":
         return '''<|im_start|>system
-        {system_message}<|im_end|>
-        <|im_start|>user
-        {prompt}<|im_end|>
-        <|im_start|>assistant'''
+{system_message}<|im_end|>
+<|im_start|>user
+{prompt}<|im_end|>
+<|im_start|>assistant'''

     return ""
@@ -370,7 +213,7 @@ def display_llm_output():

     prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
     prompt_format = form.text_area('Prompt format', value=prompt_format_str)
-    system_prompt = form.text_area('System prompt', value=st.session_state["system_prompt"])
+    system_prompt = ""#form.text_area('System prompt', value=st.session_state["system_prompt"])
     prompt = form.text_area('Prompt', value=st.session_state["prompt"])

     submitted = form.form_submit_button('Submit')
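
A note on the template hunks above: Python triple-quoted strings keep their leading whitespace, so indenting a prompt template to line up with the surrounding code injects those spaces into every prompt sent to the model, which is presumably what this commit's flush-left rewrite addresses. A minimal illustrative sketch of the pitfall, plus textwrap.dedent as a standard-library way to keep the source indented; this snippet is not part of app.py:

import textwrap

def indented_template():
    # The leading spaces inside the string below are part of the string,
    # and therefore end up in the prompt handed to the model.
    return '''<|im_start|>system
        {system_message}<|im_end|>
        <|im_start|>user
        {prompt}<|im_end|>
        <|im_start|>assistant'''

def flush_left_template():
    # textwrap.dedent strips the common leading whitespace, so the source
    # stays readable while the rendered template is flush-left.
    return textwrap.dedent('''\
        <|im_start|>system
        {system_message}<|im_end|>
        <|im_start|>user
        {prompt}<|im_end|>
        <|im_start|>assistant''')

assert "\n        <|im_start|>user" in indented_template()  # indent leaked into the prompt
assert "\n<|im_start|>user" in flush_left_template()        # clean, flush-left markup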
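
For context on how these templates are consumed: the {system_message} and {prompt} placeholders suggest they are filled via str.format before the text goes to the model (an assumption; the call site is not shown in this diff). A self-contained sketch using the meditron ChatML template and the system prompt set in init_session_state, with a made-up user question:

MEDITRON_TEMPLATE = '''<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant'''

# Hypothetical usage; the question text is invented for illustration.
full_prompt = MEDITRON_TEMPLATE.format(
    system_message="You are a medical expert that provides answers for a medically trained audience",
    prompt="Which red flags for acute headache warrant urgent in-person referral?",
)
print(full_prompt)  # flush-left ChatML, ready to tokenize and send to the model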