Brandon Toews committed on
Commit 156ed6f · 1 Parent(s): 4c2fd83

Add secure application file

Files changed (1)
  app.py  +181  -156
app.py CHANGED
@@ -1,164 +1,189 @@
  import streamlit as st
  from openai import OpenAI

- # Streamlit page configuration
- st.set_page_config(page_title="Financial Virtual Assistant", layout="wide")
-
- # Initialize session state for chat history and API client
- if "messages" not in st.session_state:
-     st.session_state.messages = []
- if "client" not in st.session_state:
-     base_url = f"https://brandontoews--vllm-openai-compatible-serve.modal.run/v1"
-
-     # Initialize OpenAI client
-     st.session_state.client = OpenAI(
-         api_key=st.secrets["openai_api_key"],  # Replace with your API key or use modal.Secret
-         base_url=base_url
-     )
- if "models" not in st.session_state:
-     # Fetch available models from the server
-     try:
-         models = st.session_state.client.models.list().data
-         st.session_state.models = [model.id for model in models]
-     except Exception as e:
-         st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"]  # Fallback if fetch fails
-         st.warning(f"Failed to fetch models: {e}. Using default model.")
-
-
- # Function to estimate token count (heuristic: ~4 chars per token)
- def estimate_token_count(messages):
-     total_chars = sum(len(message["content"]) for message in messages)
-     return total_chars // 4  # Approximate: 4 characters per token
-
-
- # Function to truncate messages if token count exceeds limit
- def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
-     # Always keep the system prompt (if present) and the last N messages
-     system_prompt = [msg for msg in messages if msg["role"] == "system"]
-     non_system_messages = [msg for msg in messages if msg["role"] != "system"]
-
-     # Estimate current token count
-     current_tokens = estimate_token_count(messages)
-
-     # If under the limit, no truncation needed
-     if current_tokens <= max_tokens:
-         return messages
-
-     # Truncate older non-system messages, keeping the last N
-     truncated_non_system_messages = non_system_messages[-keep_last_n:] if len(
-         non_system_messages) > keep_last_n else non_system_messages
-
-     # Reconstruct messages: system prompt (if any) + truncated non-system messages
-     return system_prompt + truncated_non_system_messages
-
-
- # Function to get completion from vLLM server
- def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
-     completion_args = {
-         "model": model_id,
-         "messages": messages,
-         "temperature": temperature,
-         "top_p": top_p,
-         "max_tokens": max_tokens,
-         "stream": stream,
-     }
-     try:
-         response = client.chat.completions.create(**completion_args)
-         return response
-     except Exception as e:
-         st.error(f"Error during API call: {e}")
-         return None
-
-
- # Sidebar for configuration
- with st.sidebar:
-     st.header("Chat Settings")
-
-     # Model selection dropdown
-     model_id = st.selectbox(
-         "Select Model",
-         options=st.session_state.models,
-         index=0,
-         help="Choose a model available on the vLLM server"
-     )
-
-     # System prompt input
-     system_prompt = st.text_area(
-         "System Prompt",
-         value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
-         height=100,
-         help="Enter a system prompt to guide the model's behavior (optional)"
-     )
-     if st.button("Apply System Prompt"):
-         if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
-             st.session_state.messages[0] = {"role": "system", "content": system_prompt}
          else:
-             st.session_state.messages.insert(0, {"role": "system", "content": system_prompt})
-         st.success("System prompt updated!")
-
-     # Other settings
-     temperature = st.slider("Temperature", 0.0, 1.0, 0.2, help="Controls randomness of responses")
-     top_p = st.slider("Top P", 0.0, 1.0, 0.85, help="Controls diversity via nucleus sampling")
-     max_tokens = st.number_input("Max Tokens", min_value=1, value=512, help="Maximum length of response (optional)")
-     if st.button("Clear Chat"):
-         st.session_state.messages = (
-             [{"role": "system", "content": system_prompt}] if system_prompt else []
          )

- # Main chat interface
- st.title("Financial Virtual Assistant")
- st.write("Chat with a finance-tuned LLM powered by vLLM on Modal. Select a model and customize your system prompt!")
-
- # Display chat history
- for message in st.session_state.messages:
-     if message["role"] == "system":
-         with st.expander("System Prompt", expanded=False):
-             st.markdown(f"**System**: {message['content']}")
-     else:
-         with st.chat_message(message["role"]):
-             st.markdown(message["content"])
-
- # User input
- if prompt := st.chat_input("Type your message here..."):
-     # Add user message to history
-     st.session_state.messages.append({"role": "user", "content": prompt})
-     with st.chat_message("user"):
-         st.markdown(prompt)
-
-     # Truncate messages if necessary to stay under token limit
-     st.session_state.messages = truncate_messages(
-         st.session_state.messages,
-         max_tokens=2048 - max_tokens,  # Reserve space for the output
-         keep_last_n=5
-     )
-     # Debug: Log token count and messages
-     current_tokens = estimate_token_count(st.session_state.messages)
-     st.write(f"Debug: Current token count: {current_tokens}")
-     st.write(f"Debug: Messages sent to model: {st.session_state.messages}")
-
-     # Get and display assistant response
-     with st.chat_message("assistant"):
-         response = get_completion(
-             st.session_state.client,
-             model_id,
-             st.session_state.messages,
-             stream=True,
-             temperature=temperature,
-             top_p=top_p,
-             max_tokens=max_tokens
          )
-         if response:
-             # Stream the response
-             placeholder = st.empty()
-             assistant_message = ""
-             for chunk in response:
-                 if chunk.choices[0].delta.content:
-                     assistant_message += chunk.choices[0].delta.content
-                     placeholder.markdown(assistant_message + "▌")  # Cursor effect
-             placeholder.markdown(assistant_message)  # Final message without cursor
-             st.session_state.messages.append({"role": "assistant", "content": assistant_message})
          else:
-             st.error("Failed to get a response from the server.")
-
- # Instructions
- st.caption("Built with Streamlit and vLLM on Modal. Adjust settings in the sidebar and chat away!")

  import streamlit as st
  from openai import OpenAI
+ import os  # needed for the os.getenv() calls below

+ # Authentication function
+ def authenticate():
+     st.title("Financial Virtual Assistant")
+     st.subheader("Login")
+
+     username = st.text_input("Username")
+     password = st.text_input("Password", type="password")
+
+     if st.button("Login"):
+         if username == os.getenv('username') and password == os.getenv('password'):
+             st.session_state.authenticated = True
+             return True
          else:
+             st.error("Invalid username or password")
+             return False
+
+
+ # Check authentication state
+ if "authenticated" not in st.session_state:
+     st.session_state.authenticated = False
+
+ if not st.session_state.authenticated:
+     if authenticate():
+         st.rerun()
+ else:
+     # Streamlit page configuration
+     st.set_page_config(page_title="Financial Virtual Assistant", layout="wide")
+
+     # Initialize session state for chat history and API client
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+     if "client" not in st.session_state:
+         base_url = f"https://brandontoews--vllm-openai-compatible-serve.modal.run/v1"
+
+         # Initialize OpenAI client
+         st.session_state.client = OpenAI(
+             api_key=os.getenv('openai_api_key'),  # Replace with your API key or use modal.Secret
+             base_url=base_url
+         )
+     if "models" not in st.session_state:
+         # Fetch available models from the server
+         try:
+             models = st.session_state.client.models.list().data
+             st.session_state.models = [model.id for model in models]
+         except Exception as e:
+             st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"]  # Fallback if fetch fails
+             st.warning(f"Failed to fetch models: {e}. Using default model.")
+
+
+     # Function to estimate token count (heuristic: ~4 chars per token)
+     def estimate_token_count(messages):
+         total_chars = sum(len(message["content"]) for message in messages)
+         return total_chars // 4  # Approximate: 4 characters per token
+
+
+     # Function to truncate messages if token count exceeds limit
+     def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
+         # Always keep the system prompt (if present) and the last N messages
+         system_prompt = [msg for msg in messages if msg["role"] == "system"]
+         non_system_messages = [msg for msg in messages if msg["role"] != "system"]
+
+         # Estimate current token count
+         current_tokens = estimate_token_count(messages)
+
+         # If under the limit, no truncation needed
+         if current_tokens <= max_tokens:
+             return messages
+
+         # Truncate older non-system messages, keeping the last N
+         truncated_non_system_messages = non_system_messages[-keep_last_n:] if len(
+             non_system_messages) > keep_last_n else non_system_messages
+
+         # Reconstruct messages: system prompt (if any) + truncated non-system messages
+         return system_prompt + truncated_non_system_messages
+
+
+     # Function to get completion from vLLM server
+     def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
+         completion_args = {
+             "model": model_id,
+             "messages": messages,
+             "temperature": temperature,
+             "top_p": top_p,
+             "max_tokens": max_tokens,
+             "stream": stream,
+         }
+         try:
+             response = client.chat.completions.create(**completion_args)
+             return response
+         except Exception as e:
+             st.error(f"Error during API call: {e}")
+             return None
+
+
+     # Sidebar for configuration
+     with st.sidebar:
+         st.header("Chat Settings")
+
+         # Model selection dropdown
+         model_id = st.selectbox(
+             "Select Model",
+             options=st.session_state.models,
+             index=0,
+             help="Choose a model available on the vLLM server"
          )

+         # System prompt input
+         system_prompt = st.text_area(
+             "System Prompt",
+             value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
+             height=100,
+             help="Enter a system prompt to guide the model's behavior (optional)"
          )
+         if st.button("Apply System Prompt"):
+             if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
+                 st.session_state.messages[0] = {"role": "system", "content": system_prompt}
+             else:
+                 st.session_state.messages.insert(0, {"role": "system", "content": system_prompt})
+             st.success("System prompt updated!")
+
+         # Other settings
+         temperature = st.slider("Temperature", 0.0, 1.0, 0.2, help="Controls randomness of responses")
+         top_p = st.slider("Top P", 0.0, 1.0, 0.85, help="Controls diversity via nucleus sampling")
+         max_tokens = st.number_input("Max Tokens", min_value=1, value=512, help="Maximum length of response (optional)")
+         if st.button("Clear Chat"):
+             st.session_state.messages = (
+                 [{"role": "system", "content": system_prompt}] if system_prompt else []
+             )
+
+     # Main chat interface
+     st.title("Financial Virtual Assistant")
+     st.write("Chat with a finance-tuned LLM powered by vLLM on Modal. Select a model and customize your system prompt!")
+
+     # Display chat history
+     for message in st.session_state.messages:
+         if message["role"] == "system":
+             with st.expander("System Prompt", expanded=False):
+                 st.markdown(f"**System**: {message['content']}")
          else:
+             with st.chat_message(message["role"]):
+                 st.markdown(message["content"])
+
+     # User input
+     if prompt := st.chat_input("Type your message here..."):
+         # Add user message to history
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.markdown(prompt)
+
+         # Truncate messages if necessary to stay under token limit
+         st.session_state.messages = truncate_messages(
+             st.session_state.messages,
+             max_tokens=2048 - max_tokens,  # Reserve space for the output
+             keep_last_n=5
+         )
+         # Debug: Log token count and messages
+         current_tokens = estimate_token_count(st.session_state.messages)
+         st.write(f"Debug: Current token count: {current_tokens}")
+         st.write(f"Debug: Messages sent to model: {st.session_state.messages}")
+
+         # Get and display assistant response
+         with st.chat_message("assistant"):
+             response = get_completion(
+                 st.session_state.client,
+                 model_id,
+                 st.session_state.messages,
+                 stream=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 max_tokens=max_tokens
+             )
+             if response:
+                 # Stream the response
+                 placeholder = st.empty()
+                 assistant_message = ""
+                 for chunk in response:
+                     if chunk.choices[0].delta.content:
+                         assistant_message += chunk.choices[0].delta.content
+                         placeholder.markdown(assistant_message + "▌")  # Cursor effect
+                 placeholder.markdown(assistant_message)  # Final message without cursor
+                 st.session_state.messages.append({"role": "assistant", "content": assistant_message})
+             else:
+                 st.error("Failed to get a response from the server.")
+
+     # Instructions
+     st.caption("Built with Streamlit and vLLM on Modal. Adjust settings in the sidebar and chat away!")
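The core of this commit is the login gate at the top of app.py: the credentials typed into the form are compared against the "username" and "password" environment variables, a flag is stored in st.session_state, and st.rerun() reloads the page into the authenticated branch. A minimal, self-contained sketch of that pattern (an illustration only, not the full app, assuming the same environment variable names as the committed code):

    import os
    import streamlit as st

    # Sketch of the authentication gate used in app.py.
    # Credentials are read from the "username" and "password"
    # environment variables, as in the committed code.
    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False

    if not st.session_state.authenticated:
        st.title("Financial Virtual Assistant")
        st.subheader("Login")
        username = st.text_input("Username")
        password = st.text_input("Password", type="password")
        if st.button("Login"):
            if username == os.getenv("username") and password == os.getenv("password"):
                st.session_state.authenticated = True
                st.rerun()  # reload so the main app renders on the next run
            else:
                st.error("Invalid username or password")
    else:
        st.write("Authenticated - the chat interface would be built here.")

The updated app also reads an "openai_api_key" environment variable for the OpenAI-compatible vLLM endpoint; all three values would typically be supplied as secrets by the hosting environment rather than hard-coded.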