WilliamGazeley committed on
Commit 954e857
1 Parent(s): 558d9e8

Update system prompt and UI input

Files changed (1)
  1. app.py +39 -2
app.py CHANGED
@@ -3,7 +3,20 @@ import huggingface_hub
 import streamlit as st
 from vllm import LLM, SamplingParams
 
-sys_msg = "You are a super intelligent automated financial advisor created by IRAI. Your feel your purpose is to make use of your deep and broad understanding of finance by helping answer user questions about finance accurately, truthfully, and concisely."
+
+@st.cache(show_spinner=False)
+def get_system_message():
+    return """#Context:
+You are an AI-based automated expert financial advisor named IRAI. You have a comprehensive understanding of finance and investing because you have trained on an extensive dataset of financial news, analyst reports, books, company filings, earnings call transcripts, and finance websites.
+#Objective:
+Answer questions accurately and truthfully given the data you have trained on. You do not have access to up-to-date current market data; this will be available in the future.
+#Style and tone:
+Please answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
+#Audience:
+The questions will be asked by top technology executives and CFOs of large fintech companies and successful startups.
+#Response:
+Answer, concise yet insightful."""
+
 
 @st.cache_resource(show_spinner=False)
 def init_llm():
@@ -15,10 +28,12 @@ def init_llm():
 
 def get_response(prompt):
     try:
+        sys_msg = get_system_message()
         convo = [
             {"role": "system", "content": sys_msg},
             {"role": "user", "content": prompt},
         ]
+        llm = init_llm()
         prompts = [llm.get_tokenizer().apply_chat_template(convo, tokenize=False)]
         sampling_params = SamplingParams(temperature=0.3, top_p=0.95, max_tokens=500, stop_token_ids=[128009])
         outputs = llm.generate(prompts, sampling_params)
@@ -27,14 +42,36 @@ def get_response(prompt):
     except Exception as e:
         return f"An error occurred: {str(e)}"
 
+
+def get_response(prompt, custom_sys_msg):
+    try:
+        convo = [
+            {"role": "system", "content": custom_sys_msg},
+            {"role": "user", "content": prompt},
+        ]
+        prompts = [llm.get_tokenizer().apply_chat_template(convo, tokenize=False)]
+        sampling_params = SamplingParams(temperature=0.3, top_p=0.95, max_tokens=2000, stop_token_ids=[128009])
+        outputs = llm.generate(prompts, sampling_params)
+        for output in outputs:
+            return output.outputs[0].text
+    except Exception as e:
+        return f"An error occurred: {str(e)}"
+
 def main():
     st.title("LLM-ADE 9B Demo")
 
+    # Retrieve the default system message
+    sys_msg = get_system_message()
+
+    # UI for editable preprompt
+    user_modified_sys_msg = st.text_area("Preprompt: ", value=sys_msg, height=200)
+
     input_text = st.text_area("Enter your text here:", value="", height=200)
+
     if st.button("Generate"):
         if input_text:
             with st.spinner('Generating response...'):
-                response_text = get_response(input_text)
+                response_text = get_response(input_text, user_modified_sys_msg)
             st.write(response_text)
         else:
             st.warning("Please enter some text to generate a response.")