Robin Genolet committed on
Commit
73c445e
·
1 Parent(s): 6d70053

fix: prompt

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -184,16 +184,14 @@ def plot_report(title, expected, predicted, display_labels):
184
 
185
 
186
  def get_prompt_format(model_name):
187
- if model_name == "TheBloke/Llama-2-13B-chat-GPTQ":
188
  return '''[INST] <<SYS>>
189
- You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
190
  <</SYS>>
191
  {prompt}[/INST]
192
 
193
  '''
194
- if model_name == "TheBloke/Llama-2-7B-Chat-GPTQ":
195
- return "[INST] <<SYS>>You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{prompt}[/INST]"
196
-
197
  if model_name == "TheBloke/meditron-7B-GPTQ" or model_name == "TheBloke/meditron-70B-GPTQ":
198
  return '''<|im_start|>system
199
  {system_message}<|im_end|>
@@ -215,7 +213,7 @@ def display_llm_output():
215
 
216
  prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
217
  prompt_format = form.text_area('Prompt format', value=prompt_format_str)
218
- system_prompt = ""#form.text_area('System prompt', value=st.session_state["system_prompt"])
219
  prompt = form.text_area('Prompt', value=st.session_state["prompt"])
220
 
221
  submitted = form.form_submit_button('Submit')
@@ -225,7 +223,6 @@ def display_llm_output():
225
  st.session_state["prompt"] = prompt
226
  formatted_prompt = format_prompt(prompt_format, system_prompt, prompt)
227
  print(f"Formatted prompt: {format_prompt}")
228
- print(f"top_k: {st.session_state['top_k']}")
229
  llm_response = get_llm_response(
230
  st.session_state["model_name_or_path"],
231
  st.session_state["temperature"],
@@ -236,7 +233,6 @@ def display_llm_output():
236
  st.session_state["repetition_penalty"],
237
  formatted_prompt)
238
  st.write(llm_response)
239
- st.write('Done displaying LLM response')
240
 
241
  def main():
242
  print('Running Local LLM PoC Streamlit app...')
 
184
 
185
 
186
  def get_prompt_format(model_name):
187
+ if model_name == "TheBloke/Llama-2-13B-chat-GPTQ" or model_name== "TheBloke/Llama-2-7B-Chat-GPTQ":
188
  return '''[INST] <<SYS>>
189
+ {system_message}
190
  <</SYS>>
191
  {prompt}[/INST]
192
 
193
  '''
194
+
 
 
195
  if model_name == "TheBloke/meditron-7B-GPTQ" or model_name == "TheBloke/meditron-70B-GPTQ":
196
  return '''<|im_start|>system
197
  {system_message}<|im_end|>
 
213
 
214
  prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
215
  prompt_format = form.text_area('Prompt format', value=prompt_format_str)
216
+ system_prompt = form.text_area('System prompt', value=st.session_state["system_prompt"])
217
  prompt = form.text_area('Prompt', value=st.session_state["prompt"])
218
 
219
  submitted = form.form_submit_button('Submit')
 
223
  st.session_state["prompt"] = prompt
224
  formatted_prompt = format_prompt(prompt_format, system_prompt, prompt)
225
  print(f"Formatted prompt: {format_prompt}")
 
226
  llm_response = get_llm_response(
227
  st.session_state["model_name_or_path"],
228
  st.session_state["temperature"],
 
233
  st.session_state["repetition_penalty"],
234
  formatted_prompt)
235
  st.write(llm_response)
 
236
 
237
  def main():
238
  print('Running Local LLM PoC Streamlit app...')