Brandon Toews committed · 68e5278
1 Parent(s): 4f422e9
Add application file
app.py
ADDED
@@ -0,0 +1,164 @@
import streamlit as st
from openai import OpenAI

# Streamlit page configuration
st.set_page_config(page_title="Financial Virtual Assistant", layout="wide")

# Initialize session state for chat history and API client
if "messages" not in st.session_state:
    st.session_state.messages = []
if "client" not in st.session_state:
    base_url = "https://brandontoews--vllm-openai-compatible-serve.modal.run/v1"

    # Initialize OpenAI client
    st.session_state.client = OpenAI(
        api_key=st.secrets["openai_api_key"],  # Replace with your API key or use modal.Secret
        base_url=base_url
    )
if "models" not in st.session_state:
    # Fetch available models from the server
    try:
        models = st.session_state.client.models.list().data
        st.session_state.models = [model.id for model in models]
    except Exception as e:
        st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"]  # Fallback if fetch fails
        st.warning(f"Failed to fetch models: {e}. Using default model.")


# Function to estimate token count (heuristic: ~4 chars per token)
def estimate_token_count(messages):
    total_chars = sum(len(message["content"]) for message in messages)
    return total_chars // 4  # Approximate: 4 characters per token


# Function to truncate messages if token count exceeds limit
def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
    # Always keep the system prompt (if present) and the last N messages
    system_prompt = [msg for msg in messages if msg["role"] == "system"]
    non_system_messages = [msg for msg in messages if msg["role"] != "system"]

    # Estimate current token count
    current_tokens = estimate_token_count(messages)

    # If under the limit, no truncation needed
    if current_tokens <= max_tokens:
        return messages

    # Truncate older non-system messages, keeping the last N
    truncated_non_system_messages = non_system_messages[-keep_last_n:] if len(
        non_system_messages) > keep_last_n else non_system_messages

    # Reconstruct messages: system prompt (if any) + truncated non-system messages
    return system_prompt + truncated_non_system_messages


# Function to get completion from vLLM server
def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
    completion_args = {
        "model": model_id,
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "stream": stream,
    }
    try:
        response = client.chat.completions.create(**completion_args)
        return response
    except Exception as e:
        st.error(f"Error during API call: {e}")
        return None


# Sidebar for configuration
with st.sidebar:
    st.header("Chat Settings")

    # Model selection dropdown
    model_id = st.selectbox(
        "Select Model",
        options=st.session_state.models,
        index=0,
        help="Choose a model available on the vLLM server"
    )

    # System prompt input
    system_prompt = st.text_area(
        "System Prompt",
        value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
        height=100,
        help="Enter a system prompt to guide the model's behavior (optional)"
    )
    if st.button("Apply System Prompt"):
        if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
            st.session_state.messages[0] = {"role": "system", "content": system_prompt}
        else:
            st.session_state.messages.insert(0, {"role": "system", "content": system_prompt})
        st.success("System prompt updated!")

    # Other settings
    temperature = st.slider("Temperature", 0.0, 1.0, 0.2, help="Controls randomness of responses")
    top_p = st.slider("Top P", 0.0, 1.0, 0.85, help="Controls diversity via nucleus sampling")
    max_tokens = st.number_input("Max Tokens", min_value=1, value=512, help="Maximum length of response (optional)")
    if st.button("Clear Chat"):
        st.session_state.messages = (
            [{"role": "system", "content": system_prompt}] if system_prompt else []
        )

# Main chat interface
st.title("Financial Virtual Assistant")
st.write("Chat with a finance-tuned LLM powered by vLLM on Modal. Select a model and customize your system prompt!")

# Display chat history
for message in st.session_state.messages:
    if message["role"] == "system":
        with st.expander("System Prompt", expanded=False):
            st.markdown(f"**System**: {message['content']}")
    else:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

# User input
if prompt := st.chat_input("Type your message here..."):
    # Add user message to history
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Truncate messages if necessary to stay under token limit
    st.session_state.messages = truncate_messages(
        st.session_state.messages,
        max_tokens=2048 - max_tokens,  # Reserve space for the output
        keep_last_n=5
    )
    # Debug: Log token count and messages
    current_tokens = estimate_token_count(st.session_state.messages)
    st.write(f"Debug: Current token count: {current_tokens}")
    st.write(f"Debug: Messages sent to model: {st.session_state.messages}")

    # Get and display assistant response
    with st.chat_message("assistant"):
        response = get_completion(
            st.session_state.client,
            model_id,
            st.session_state.messages,
            stream=True,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens
        )
        if response:
            # Stream the response
            placeholder = st.empty()
            assistant_message = ""
            for chunk in response:
                if chunk.choices[0].delta.content:
                    assistant_message += chunk.choices[0].delta.content
                    placeholder.markdown(assistant_message + "▌")  # Cursor effect
            placeholder.markdown(assistant_message)  # Final message without cursor
            st.session_state.messages.append({"role": "assistant", "content": assistant_message})
        else:
            st.error("Failed to get a response from the server.")

# Instructions
st.caption("Built with Streamlit and vLLM on Modal. Adjust settings in the sidebar and chat away!")
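
To try app.py outside the Space, one minimal setup (a sketch: the pip package list is an assumption and the key value is a placeholder; only the secret name is taken from the code, which reads st.secrets["openai_api_key"]) is to store the key in .streamlit/secrets.toml and start the app with the Streamlit CLI:

    # .streamlit/secrets.toml  (key name matches what app.py reads; value is a placeholder)
    openai_api_key = "sk-..."

    # install assumed minimal dependencies and launch
    pip install streamlit openai
    streamlit run app.py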