Spaces:
Brandon Toews committed
Commit · 156ed6f
1 Parent(s): 4c2fd83
Add secure application file
app.py CHANGED
@@ -1,164 +1,189 @@

Old version:

    import streamlit as st
    from openai import OpenAI

-   #
-
-
-
-
-
-
-
-
-
-
-
-           base_url=base_url
-       )
-   if "models" not in st.session_state:
-       # Fetch available models from the server
-       try:
-           models = st.session_state.client.models.list().data
-           st.session_state.models = [model.id for model in models]
-       except Exception as e:
-           st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"]  # Fallback if fetch fails
-           st.warning(f"Failed to fetch models: {e}. Using default model.")
-
-
-   # Function to estimate token count (heuristic: ~4 chars per token)
-   def estimate_token_count(messages):
-       total_chars = sum(len(message["content"]) for message in messages)
-       return total_chars // 4  # Approximate: 4 characters per token
-
-
-   # Function to truncate messages if token count exceeds limit
-   def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
-       # Always keep the system prompt (if present) and the last N messages
-       system_prompt = [msg for msg in messages if msg["role"] == "system"]
-       non_system_messages = [msg for msg in messages if msg["role"] != "system"]
-
-       # Estimate current token count
-       current_tokens = estimate_token_count(messages)
-
-       # If under the limit, no truncation needed
-       if current_tokens <= max_tokens:
-           return messages
-
-       # Truncate older non-system messages, keeping the last N
-       truncated_non_system_messages = non_system_messages[-keep_last_n:] if len(
-           non_system_messages) > keep_last_n else non_system_messages
-
-       # Reconstruct messages: system prompt (if any) + truncated non-system messages
-       return system_prompt + truncated_non_system_messages
-
-
-   # Function to get completion from vLLM server
-   def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
-       completion_args = {
-           "model": model_id,
-           "messages": messages,
-           "temperature": temperature,
-           "top_p": top_p,
-           "max_tokens": max_tokens,
-           "stream": stream,
-       }
-       try:
-           response = client.chat.completions.create(**completion_args)
-           return response
-       except Exception as e:
-           st.error(f"Error during API call: {e}")
-           return None
-
-
-   # Sidebar for configuration
-   with st.sidebar:
-       st.header("Chat Settings")
-
-       # Model selection dropdown
-       model_id = st.selectbox(
-           "Select Model",
-           options=st.session_state.models,
-           index=0,
-           help="Choose a model available on the vLLM server"
-       )
-
-       # System prompt input
-       system_prompt = st.text_area(
-           "System Prompt",
-           value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
-           height=100,
-           help="Enter a system prompt to guide the model's behavior (optional)"
-       )
-       if st.button("Apply System Prompt"):
-           if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
-               st.session_state.messages[0] = {"role": "system", "content": system_prompt}
            else:
-               st.
-
-
-
-
-
-
-
-
-
            )

-   #
-   st.
-
-
-
-
-       if message["role"] == "system":
-           with st.expander("System Prompt", expanded=False):
-               st.markdown(f"**System**: {message['content']}")
-       else:
-           with st.chat_message(message["role"]):
-               st.markdown(message["content"])
-
-   # User input
-   if prompt := st.chat_input("Type your message here..."):
-       # Add user message to history
-       st.session_state.messages.append({"role": "user", "content": prompt})
-       with st.chat_message("user"):
-           st.markdown(prompt)
-
-       # Truncate messages if necessary to stay under token limit
-       st.session_state.messages = truncate_messages(
-           st.session_state.messages,
-           max_tokens=2048 - max_tokens,  # Reserve space for the output
-           keep_last_n=5
-       )
-       # Debug: Log token count and messages
-       current_tokens = estimate_token_count(st.session_state.messages)
-       st.write(f"Debug: Current token count: {current_tokens}")
-       st.write(f"Debug: Messages sent to model: {st.session_state.messages}")
-
-       # Get and display assistant response
-       with st.chat_message("assistant"):
-           response = get_completion(
-               st.session_state.client,
-               model_id,
-               st.session_state.messages,
-               stream=True,
-               temperature=temperature,
-               top_p=top_p,
-               max_tokens=max_tokens
            )
-           if
-
-
-
-
-
-
-
-
            else:
-               st.
-
-
-

New version:

    import streamlit as st
    from openai import OpenAI

+   # Authentication function
+   def authenticate():
+       st.title("Financial Virtual Assistant")
+       st.subheader("Login")
+
+       username = st.text_input("Username")
+       password = st.text_input("Password", type="password")
+
+       if st.button("Login"):
+           if username == os.getenv('username') and password == os.getenv('password'):
+               st.session_state.authenticated = True
+               return True
            else:
+               st.error("Invalid username or password")
+       return False
+
+
+   # Check authentication state
+   if "authenticated" not in st.session_state:
+       st.session_state.authenticated = False
+
+   if not st.session_state.authenticated:
+       if authenticate():
+           st.rerun()
+   else:
+       # Streamlit page configuration
+       st.set_page_config(page_title="Financial Virtual Assistant", layout="wide")
+
+       # Initialize session state for chat history and API client
+       if "messages" not in st.session_state:
+           st.session_state.messages = []
+       if "client" not in st.session_state:
+           base_url = f"https://brandontoews--vllm-openai-compatible-serve.modal.run/v1"
+
+           # Initialize OpenAI client
+           st.session_state.client = OpenAI(
+               api_key=os.getenv('openai_api_key'),  # Replace with your API key or use modal.Secret
+               base_url=base_url
+           )
+       if "models" not in st.session_state:
+           # Fetch available models from the server
+           try:
+               models = st.session_state.client.models.list().data
+               st.session_state.models = [model.id for model in models]
+           except Exception as e:
+               st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"]  # Fallback if fetch fails
+               st.warning(f"Failed to fetch models: {e}. Using default model.")
+
+
+       # Function to estimate token count (heuristic: ~4 chars per token)
+       def estimate_token_count(messages):
+           total_chars = sum(len(message["content"]) for message in messages)
+           return total_chars // 4  # Approximate: 4 characters per token
+
+
+       # Function to truncate messages if token count exceeds limit
+       def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
+           # Always keep the system prompt (if present) and the last N messages
+           system_prompt = [msg for msg in messages if msg["role"] == "system"]
+           non_system_messages = [msg for msg in messages if msg["role"] != "system"]
+
+           # Estimate current token count
+           current_tokens = estimate_token_count(messages)
+
+           # If under the limit, no truncation needed
+           if current_tokens <= max_tokens:
+               return messages
+
+           # Truncate older non-system messages, keeping the last N
+           truncated_non_system_messages = non_system_messages[-keep_last_n:] if len(
+               non_system_messages) > keep_last_n else non_system_messages
+
+           # Reconstruct messages: system prompt (if any) + truncated non-system messages
+           return system_prompt + truncated_non_system_messages
+
+
+       # Function to get completion from vLLM server
+       def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
+           completion_args = {
+               "model": model_id,
+               "messages": messages,
+               "temperature": temperature,
+               "top_p": top_p,
+               "max_tokens": max_tokens,
+               "stream": stream,
+           }
+           try:
+               response = client.chat.completions.create(**completion_args)
+               return response
+           except Exception as e:
+               st.error(f"Error during API call: {e}")
+               return None
+
+
+       # Sidebar for configuration
+       with st.sidebar:
+           st.header("Chat Settings")
+
+           # Model selection dropdown
+           model_id = st.selectbox(
+               "Select Model",
+               options=st.session_state.models,
+               index=0,
+               help="Choose a model available on the vLLM server"
            )

+           # System prompt input
+           system_prompt = st.text_area(
+               "System Prompt",
+               value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
+               height=100,
+               help="Enter a system prompt to guide the model's behavior (optional)"
            )
+           if st.button("Apply System Prompt"):
+               if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
+                   st.session_state.messages[0] = {"role": "system", "content": system_prompt}
+               else:
+                   st.session_state.messages.insert(0, {"role": "system", "content": system_prompt})
+               st.success("System prompt updated!")
+
+           # Other settings
+           temperature = st.slider("Temperature", 0.0, 1.0, 0.2, help="Controls randomness of responses")
+           top_p = st.slider("Top P", 0.0, 1.0, 0.85, help="Controls diversity via nucleus sampling")
+           max_tokens = st.number_input("Max Tokens", min_value=1, value=512, help="Maximum length of response (optional)")
+           if st.button("Clear Chat"):
+               st.session_state.messages = (
+                   [{"role": "system", "content": system_prompt}] if system_prompt else []
+               )
+
+       # Main chat interface
+       st.title("Financial Virtual Assistant")
+       st.write("Chat with a finance-tuned LLM powered by vLLM on Modal. Select a model and customize your system prompt!")
+
+       # Display chat history
+       for message in st.session_state.messages:
+           if message["role"] == "system":
+               with st.expander("System Prompt", expanded=False):
+                   st.markdown(f"**System**: {message['content']}")
            else:
+               with st.chat_message(message["role"]):
+                   st.markdown(message["content"])
+
+       # User input
+       if prompt := st.chat_input("Type your message here..."):
+           # Add user message to history
+           st.session_state.messages.append({"role": "user", "content": prompt})
+           with st.chat_message("user"):
+               st.markdown(prompt)
+
+           # Truncate messages if necessary to stay under token limit
+           st.session_state.messages = truncate_messages(
+               st.session_state.messages,
+               max_tokens=2048 - max_tokens,  # Reserve space for the output
+               keep_last_n=5
+           )
+           # Debug: Log token count and messages
+           current_tokens = estimate_token_count(st.session_state.messages)
+           st.write(f"Debug: Current token count: {current_tokens}")
+           st.write(f"Debug: Messages sent to model: {st.session_state.messages}")
+
+           # Get and display assistant response
+           with st.chat_message("assistant"):
+               response = get_completion(
+                   st.session_state.client,
+                   model_id,
+                   st.session_state.messages,
+                   stream=True,
+                   temperature=temperature,
+                   top_p=top_p,
+                   max_tokens=max_tokens
+               )
+               if response:
+                   # Stream the response
+                   placeholder = st.empty()
+                   assistant_message = ""
+                   for chunk in response:
+                       if chunk.choices[0].delta.content:
+                           assistant_message += chunk.choices[0].delta.content
+                           placeholder.markdown(assistant_message + "▌")  # Cursor effect
+                   placeholder.markdown(assistant_message)  # Final message without cursor
+                   st.session_state.messages.append({"role": "assistant", "content": assistant_message})
+               else:
+                   st.error("Failed to get a response from the server.")
+
+       # Instructions
+       st.caption("Built with Streamlit and vLLM on Modal. Adjust settings in the sidebar and chat away!")
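
One detail worth flagging in the new version as shown in this hunk: authenticate() and the OpenAI client setup both call os.getenv(), but no import of the os module appears anywhere in the added lines. Assuming it is not imported elsewhere in the file, a minimal fix would be to add it alongside the existing imports, for example:

    import os  # assumption: required for the os.getenv() calls in authenticate() and the client setup
    import streamlit as st
    from openai import OpenAI

With that in place, the username, password, and openai_api_key values are expected as environment variables (for example, set as Space secrets) before launching the app with streamlit run app.py.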