Robin Genolet committed
Commit: 4ac6668
Parent: a358346

feat: memoize model
Files changed:
- app.py (+5 -5)
- utils/epfl_meditron_utils.py (+10 -2)
app.py
CHANGED
@@ -36,7 +36,7 @@ def display_streamlit_sidebar():
     do_sample = form.checkbox('do_sample', value=st.session_state["do_sample"])
     top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["top_p"])
     top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1, value=st.session_state["top_k"])
-    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=
+    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=4096, step=1, value=st.session_state["max_new_tokens"])
     repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=5.0, step=0.01, value=st.session_state["repetition_penalty"])

     submitted = form.form_submit_button("Start session")
@@ -76,7 +76,7 @@ def init_session_state():
     st.session_state["do_sample"] = True
     st.session_state["top_p"] = 0.95
     st.session_state["top_k"] = 40
-    st.session_state["max_new_tokens"] =
+    st.session_state["max_new_tokens"] = 4096
     st.session_state["repetition_penalty"] = 1.1
     st.session_state["system_prompt"] = "You are a medical expert that provides answers for a medically trained audience"
     st.session_state["prompt"] = ""
@@ -143,9 +143,9 @@ def display_llm_output():
     form = st.form('llm')

     prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
-    prompt_format = form.text_area('Prompt format', value=prompt_format_str)
-    system_prompt = form.text_area('System
-    prompt = form.text_area('Prompt', value=st.session_state["prompt"])
+    prompt_format = form.text_area('Prompt format', value=prompt_format_str, height=300)
+    system_prompt = form.text_area('System message', value=st.session_state["system_prompt"], height=300)
+    prompt = form.text_area('Prompt', value=st.session_state["prompt"], height=400)

     submitted = form.form_submit_button('Submit')
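For orientation, a minimal standalone sketch (not part of the Space's code) of the pattern these hunks rely on: each widget reads its default from st.session_state, so the 4096 default seeded in init_session_state() becomes the slider's initial value, and the value is written back only when the form is submitted. The form key and the write-back on submit are assumptions for illustration.

import streamlit as st

# Seed the default once per session (mirrors init_session_state()).
if "max_new_tokens" not in st.session_state:
    st.session_state["max_new_tokens"] = 4096

form = st.form("params")  # hypothetical form key
# The slider's starting position comes from session state.
max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=4096,
                             step=1, value=st.session_state["max_new_tokens"])
submitted = form.form_submit_button("Start session")

if submitted:
    # Persist the chosen value so later reruns keep it.
    st.session_state["max_new_tokens"] = max_new_tokens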
utils/epfl_meditron_utils.py
CHANGED
@@ -9,11 +9,19 @@ def gptq_model_options():
     "TheBloke/meditron-70B-GPTQ",
 ]

+loaded_model = None
+loaded_model_name = ""
+
 def get_llm_response(model_name_or_path, temperature, do_sample, top_p, top_k, max_new_tokens, repetition_penalty, formatted_prompt):
-
+    if loaded_model != model_name_or_path:
+        global loaded_model
+        global loaded_model_name
+        loaded_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
             device_map="auto",
             trust_remote_code=False,
             revision="main")
+        loaded_model_name = model_name_or_path
+

     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

@@ -30,7 +38,7 @@ def get_llm_response(model_name_or_path, temperature, do_sample, top_p, top_k, max_new_tokens, repetition_penalty, formatted_prompt):
     print("*** Pipeline:")
     pipe = pipeline(
         "text-generation",
-        model=
+        model=loaded_model,
         tokenizer=tokenizer,
         max_new_tokens=max_new_tokens,
         do_sample=do_sample,
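For reference, a minimal sketch of the module-level memoization this commit is going for, written as a hypothetical helper get_cached_model rather than the committed code. Two details matter for this pattern in Python: the global declarations must appear before the names are first used inside the function, and the reload check should compare the cached model name (a string) rather than the model object itself, so the sketch keys the cache on loaded_model_name.

from transformers import AutoModelForCausalLM

# Module-level cache: the last loaded model and the name it was loaded with.
loaded_model = None
loaded_model_name = ""

def get_cached_model(model_name_or_path):
    """Return the cached model, reloading only when a different model name is requested."""
    global loaded_model, loaded_model_name  # declare before first use in this scope
    if loaded_model_name != model_name_or_path:
        loaded_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                            device_map="auto",
                                                            trust_remote_code=False,
                                                            revision="main")
        loaded_model_name = model_name_or_path
    return loaded_model

get_llm_response() could then fetch the model via get_cached_model(model_name_or_path) and pass it to pipeline(model=..., tokenizer=tokenizer, ...). Because module-level globals live as long as the Python process serving the Space, repeated requests for the same model reuse the already loaded weights instead of calling from_pretrained again.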