import re
import uuid

import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from sklearn.metrics import ConfusionMatrixDisplay, classification_report, confusion_matrix

from utils.epfl_meditron_utils import get_llm_response, gptq_model_options

DATA_FOLDER = "data/"
POC_VERSION = "0.1.0"
MAX_QUESTIONS = 10
AVAILABLE_LANGUAGES = ["DE", "EN", "FR"]

st.set_page_config(page_title='Medgate Whisper PoC', page_icon='public/medgate.png')

# Azure apparently truncates the system message if it exceeds this many tokens
MAX_SYSTEM_MESSAGE_TOKENS = 200

def format_question(q):
    # Strip a numerical prefix, if any, e.g. '12. [...]' -> '[...]'
    res = re.sub(r'^[0-9]+\.\s+', '', q)
    # Replace each document reference, e.g. '[doc1]', by the cited document's title
    citations = st.session_state.get("citations", [])
    if len(citations) > 0:
        for source_ref in re.findall(r'\[doc[0-9]+\]', res):
            citation_number = int(re.findall(r'[0-9]+', source_ref)[0])
            citation_index = citation_number - 1 if citation_number > 0 else 0
            citation = citations[citation_index]
            source_title = citation["title"]
            res = res.replace(source_ref, '[' + source_title + ']')
    return res.strip()
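
# Hypothetical example (citation title made up for illustration): with
# st.session_state["citations"] = [{"title": "Fever guideline"}],
# format_question('1. Do you have a fever? [doc1]') returns
# 'Do you have a fever? [Fever guideline]'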

def get_text_from_row(text):
    # Empty spreadsheet cells come through as NaN; normalize them to an empty string
    if pd.isna(text):
        return ""
    return str(text)

def get_questions_from_df(df, lang, test_scenario_name):
    # Expects one question per row: the question text sits in the '<lang>: Fragen'
    # column, the expected answer in the test scenario column
    questions = []
    for _, row in df.iterrows():
        questions.append({
            "question": row[lang + ": Fragen"],
            "answer": get_text_from_row(row[test_scenario_name]),
            "question_id": uuid.uuid4()
        })
    return questions
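
# Hypothetical example (cell values made up): with lang = 'DE', a row holding
# 'DE: Fragen' = 'Haben Sie Fieber?' and 'Scenario 1' = 'Ja' becomes
# {"question": 'Haben Sie Fieber?', "answer": 'Ja', "question_id": UUID(...)}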

def get_questions(df, lead_symptom, lang, test_scenario_name):
    print(f'{st.session_state.get("lead_symptom")} -> {lead_symptom}')
    print(f'{st.session_state.get("scenario_name")} -> {test_scenario_name}')
    # Only rebuild the cached question list when the symptom or scenario selection changed
    if st.session_state.get("lead_symptom") != lead_symptom or st.session_state.get("scenario_name") != test_scenario_name:
        st.session_state["lead_symptom"] = lead_symptom
        st.session_state["scenario_name"] = test_scenario_name
        symptom_col_name = st.session_state["language"] + ": Symptome"
        df_questions = df[df[symptom_col_name] == lead_symptom]
        st.session_state["questions"] = get_questions_from_df(df_questions, lang, test_scenario_name)
    return st.session_state["questions"]

def display_streamlit_sidebar():
    st.sidebar.title("Local LLM PoC " + str(POC_VERSION))
    st.sidebar.write('**Parameters**')
    form = st.sidebar.form("config_form", clear_on_submit=True)
    model_name_or_path = form.selectbox("Select model", gptq_model_options())
    temperature = form.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["temperature"])
    do_sample = form.checkbox('do_sample', value=st.session_state["do_sample"])
    top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["top_p"])
    top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1, value=st.session_state["top_k"])
    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=512, step=1, value=st.session_state["max_new_tokens"])
    repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=5.0, step=0.01, value=st.session_state["repetition_penalty"])
    submitted = form.form_submit_button("Start session")
    if submitted:
        print('Parameters updated...')
        st.session_state['session_started'] = True
        st.session_state["session_events"] = []
        st.session_state["model_name_or_path"] = model_name_or_path
        st.session_state["temperature"] = temperature
        st.session_state["do_sample"] = do_sample
        st.session_state["top_p"] = top_p
        st.session_state["top_k"] = top_k
        st.session_state["max_new_tokens"] = max_new_tokens
        st.session_state["repetition_penalty"] = repetition_penalty
        st.rerun()

def init_session_state():
    print('init_session_state()')
    st.session_state['session_started'] = False
    st.session_state["session_events"] = []
    st.session_state["model_name_or_path"] = "TheBloke/meditron-7B-GPTQ"
    st.session_state["temperature"] = 0.01
    st.session_state["do_sample"] = True
    st.session_state["top_p"] = 0.95
    st.session_state["top_k"] = 40
    st.session_state["max_new_tokens"] = 512
    st.session_state["repetition_penalty"] = 1.1
    st.session_state["system_prompt"] = "You are a medical expert that provides answers for a medically trained audience"
    st.session_state["prompt"] = ""
    st.session_state["llm_messages"] = []

def get_genders():
    return ['Male', 'Female']

def display_session_overview():
    st.subheader('History of LLM queries')
    st.write(st.session_state["llm_messages"])
    st.subheader("Session costs overview")
    df_session_overview = pd.DataFrame.from_dict(st.session_state["session_events"])
    st.write(df_session_overview)
    if "prompt_tokens" in df_session_overview:
        prompt_tokens = df_session_overview["prompt_tokens"].sum()
        st.write("Prompt tokens: " + str(prompt_tokens))
        prompt_cost = df_session_overview["prompt_cost_chf"].sum()
        st.write("Prompt CHF: " + str(prompt_cost))
        completion_tokens = df_session_overview["completion_tokens"].sum()
        st.write("Completion tokens: " + str(completion_tokens))
        completion_cost = df_session_overview["completion_cost_chf"].sum()
        st.write("Completion CHF: " + str(completion_cost))
        total_cost = df_session_overview["total_cost_chf"].sum()
        st.write("Total costs CHF: " + str(total_cost))
        total_time = df_session_overview["response_time"].sum()
        st.write("Total compute time (ms): " + str(total_time))

def plot_report(title, expected, predicted, display_labels):
    st.markdown('#### ' + title)
    conf_matrix = confusion_matrix(expected, predicted, labels=display_labels)
    conf_matrix_plot = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=display_labels)
    conf_matrix_plot.plot()
    st.pyplot(plt.gcf())
    report = classification_report(expected, predicted, output_dict=True)
    df_report = pd.DataFrame(report).transpose()
    st.write(df_report)
    # Plot per-class precision/recall/f1 as a bar chart, dropping the support
    # counts and the aggregate rows
    df_rp = df_report.drop('support', axis=1)
    df_rp = df_rp.drop(['accuracy', 'macro avg', 'weighted avg'], errors='ignore')
    try:
        ax = df_rp.plot(kind="bar", legend=True)
        for container in ax.containers:
            ax.bar_label(container, fontsize=7)
        plt.xticks(rotation=45)
        plt.legend(loc=(1.04, 0))
        st.pyplot(plt.gcf())
    except Exception:
        # Plotting can fail for degenerate reports (e.g. a single class); skip the chart then
        pass
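
# Hypothetical usage (labels made up for illustration):
# plot_report('Lead symptom triage',
#             expected=['urgent', 'routine', 'urgent'],
#             predicted=['urgent', 'urgent', 'urgent'],
#             display_labels=['urgent', 'routine'])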

def get_prompt_format(model_name):
    # Llama-2 chat models expect the [INST]/<<SYS>> prompt format
    if model_name in ("TheBloke/Llama-2-13B-chat-GPTQ", "TheBloke/Llama-2-7B-Chat-GPTQ"):
        return '''[INST] <<SYS>>
{system_message}
<</SYS>>
{prompt}[/INST]
'''
    # The meditron GPTQ builds expect the ChatML format
    if model_name in ("TheBloke/meditron-7B-GPTQ", "TheBloke/meditron-70B-GPTQ"):
        return '''<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant'''
    return ""

def format_prompt(template, system_message, prompt):
    # Fall back to simple concatenation when the model has no known template
    if template == "":
        return f"{system_message} {prompt}"
    return template.format(system_message=system_message, prompt=prompt)
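
# Sketch of the expected result (example values made up): for the meditron template,
# format_prompt(get_prompt_format("TheBloke/meditron-7B-GPTQ"),
#               "You are a medical expert", "What causes fever?")
# produces:
#   <|im_start|>system
#   You are a medical expert<|im_end|>
#   <|im_start|>user
#   What causes fever?<|im_end|>
#   <|im_start|>assistant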

def display_llm_output():
    st.header("LLM")
    form = st.form('llm')
    prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
    prompt_format = form.text_area('Prompt format', value=prompt_format_str)
    system_prompt = form.text_area('System prompt', value=st.session_state["system_prompt"])
    prompt = form.text_area('Prompt', value=st.session_state["prompt"])
    submitted = form.form_submit_button('Submit')
    if submitted:
        st.session_state["system_prompt"] = system_prompt
        st.session_state["prompt"] = prompt
        formatted_prompt = format_prompt(prompt_format, system_prompt, prompt)
        print(f"Formatted prompt: {formatted_prompt}")
        llm_response = get_llm_response(
            st.session_state["model_name_or_path"],
            st.session_state["temperature"],
            st.session_state["do_sample"],
            st.session_state["top_p"],
            st.session_state["top_k"],
            st.session_state["max_new_tokens"],
            st.session_state["repetition_penalty"],
            formatted_prompt)
        st.write(llm_response)

def main():
    print('Running Local LLM PoC Streamlit app...')
    session_inactive_info = st.empty()
    if "session_started" not in st.session_state or not st.session_state["session_started"]:
        init_session_state()
        display_streamlit_sidebar()
    else:
        display_streamlit_sidebar()
        session_inactive_info.empty()
        display_llm_output()
        display_session_overview()

if __name__ == '__main__':
    main()
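
# To run locally (assuming this file is saved as app.py and Streamlit is installed):
#   streamlit run app.py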