Robin Genolet committed
Commit 4ac6668 · 1 Parent(s): a358346

feat: memoize model

Files changed (2)
  1. app.py +5 -5
  2. utils/epfl_meditron_utils.py +10 -2
app.py CHANGED
@@ -36,7 +36,7 @@ def display_streamlit_sidebar():
     do_sample = form.checkbox('do_sample', value=st.session_state["do_sample"])
     top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["top_p"])
     top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1, value=st.session_state["top_k"])
-    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=16384, step=1, value=st.session_state["max_new_tokens"])
+    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=4096, step=1, value=st.session_state["max_new_tokens"])
     repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=5.0, step=0.01, value=st.session_state["repetition_penalty"])
 
     submitted = form.form_submit_button("Start session")
@@ -76,7 +76,7 @@ def init_session_state():
     st.session_state["do_sample"] = True
     st.session_state["top_p"] = 0.95
     st.session_state["top_k"] = 40
-    st.session_state["max_new_tokens"] = 512
+    st.session_state["max_new_tokens"] = 4096
     st.session_state["repetition_penalty"] = 1.1
     st.session_state["system_prompt"] = "You are a medical expert that provides answers for a medically trained audience"
     st.session_state["prompt"] = ""
@@ -143,9 +143,9 @@ def display_llm_output():
     form = st.form('llm')
 
     prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
-    prompt_format = form.text_area('Prompt format', value=prompt_format_str)
-    system_prompt = form.text_area('System prompt', value=st.session_state["system_prompt"])
-    prompt = form.text_area('Prompt', value=st.session_state["prompt"])
+    prompt_format = form.text_area('Prompt format', value=prompt_format_str, height=300)
+    system_prompt = form.text_area('System message', value=st.session_state["system_prompt"], height=300)
+    prompt = form.text_area('Prompt', value=st.session_state["prompt"], height=400)
 
     submitted = form.form_submit_button('Submit')
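As a side note, a minimal, self-contained Streamlit sketch of the widgets this hunk adjusts, using the new slider bound and text_area heights; the form key, default values, and the trailing st.write are illustrative only and not part of the commit:

```python
# Illustrative sketch: mirrors the new max_new_tokens bound (4096) and the
# explicit text_area heights introduced above. Form key and defaults are placeholders.
import streamlit as st

form = st.form("llm_settings")
max_new_tokens = form.slider(
    label="max_new_tokens", min_value=32, max_value=4096, step=1, value=4096
)
prompt = form.text_area("Prompt", value="", height=400)
submitted = form.form_submit_button("Submit")

if submitted:
    st.write(f"Will generate up to {max_new_tokens} new tokens for a {len(prompt)}-character prompt.")
```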
 
utils/epfl_meditron_utils.py CHANGED
@@ -9,11 +9,19 @@ def gptq_model_options():
         "TheBloke/meditron-70B-GPTQ",
     ]
 
+loaded_model = None
+loaded_model_name = ""
+
 def get_llm_response(model_name_or_path, temperature, do_sample, top_p, top_k, max_new_tokens, repetition_penalty, formatted_prompt):
-    model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+    if loaded_model != model_name_or_path:
+        global loaded_model
+        global loaded_model_name
+        loaded_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                  device_map="auto",
                                                  trust_remote_code=False,
                                                  revision="main")
+        loaded_model_name = model_name_or_path
+
 
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
 
@@ -30,7 +38,7 @@ def get_llm_response(model_name_or_path, temperature, do_sample, top_p, top_k, max_new_tokens, repetition_penalty, formatted_prompt):
     print("*** Pipeline:")
     pipe = pipeline(
         "text-generation",
-        model=model,
+        model=loaded_model,
         tokenizer=tokenizer,
         max_new_tokens=max_new_tokens,
         do_sample=do_sample,
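A note on the memoization hunk above: as committed, the cache check compares the model object (loaded_model) against the model name string, and the global statements appear after loaded_model is already referenced in the if condition, which CPython rejects at compile time ("name 'loaded_model' is used prior to global declaration"). Below is a minimal sketch of the presumably intended pattern, assuming a single-process app that caches one model at a time; the get_model helper and its factoring out of get_llm_response are illustrative, not part of the commit:

```python
# Sketch of module-level memoization (assumption: one process, one cached model).
from transformers import AutoModelForCausalLM

loaded_model = None
loaded_model_name = ""

def get_model(model_name_or_path):
    # Declare the globals before reading them, and compare against the cached *name*
    # rather than the model object, so a reload only happens on a model switch.
    global loaded_model, loaded_model_name
    if loaded_model_name != model_name_or_path:
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map="auto",
            trust_remote_code=False,
            revision="main",
        )
        loaded_model_name = model_name_or_path
    return loaded_model
```

Because Streamlit reruns the script on every interaction but keeps imported modules loaded, a module-level cache like this survives reruns within the same process; st.cache_resource would be the framework-native alternative for the same effect.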