Brandon Toews committed on
Commit 68e5278 · 1 Parent(s): 4f422e9

Add application file

Files changed (1)
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
+ import streamlit as st
+ from openai import OpenAI
+
+ # Streamlit page configuration
+ st.set_page_config(page_title="Financial Virtual Assistant", layout="wide")
+
+ # Initialize session state for chat history and API client
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+ if "client" not in st.session_state:
+     base_url = f"https://brandontoews--vllm-openai-compatible-serve.modal.run/v1"
+
+     # Initialize OpenAI client
+     st.session_state.client = OpenAI(
+         api_key=st.secrets["openai_api_key"],  # Replace with your API key or use modal.Secret
+         base_url=base_url
+     )
+ if "models" not in st.session_state:
+     # Fetch available models from the server
+     try:
+         models = st.session_state.client.models.list().data
+         st.session_state.models = [model.id for model in models]
+     except Exception as e:
+         st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"]  # Fallback if fetch fails
+         st.warning(f"Failed to fetch models: {e}. Using default model.")
+
+
+ # Function to estimate token count (heuristic: ~4 chars per token)
+ def estimate_token_count(messages):
+     total_chars = sum(len(message["content"]) for message in messages)
+     return total_chars // 4  # Approximate: 4 characters per token
+
+
+ # Function to truncate messages if token count exceeds limit
+ def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
+     # Always keep the system prompt (if present) and the last N messages
+     system_prompt = [msg for msg in messages if msg["role"] == "system"]
+     non_system_messages = [msg for msg in messages if msg["role"] != "system"]
+
+     # Estimate current token count
+     current_tokens = estimate_token_count(messages)
+
+     # If under the limit, no truncation needed
+     if current_tokens <= max_tokens:
+         return messages
+
+     # Truncate older non-system messages, keeping the last N
+     truncated_non_system_messages = non_system_messages[-keep_last_n:] if len(
+         non_system_messages) > keep_last_n else non_system_messages
+
+     # Reconstruct messages: system prompt (if any) + truncated non-system messages
+     return system_prompt + truncated_non_system_messages
+
+
+ # Function to get completion from vLLM server
+ def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
+     completion_args = {
+         "model": model_id,
+         "messages": messages,
+         "temperature": temperature,
+         "top_p": top_p,
+         "max_tokens": max_tokens,
+         "stream": stream,
+     }
+     try:
+         response = client.chat.completions.create(**completion_args)
+         return response
+     except Exception as e:
+         st.error(f"Error during API call: {e}")
+         return None
+
+
+ # Sidebar for configuration
+ with st.sidebar:
+     st.header("Chat Settings")
+
+     # Model selection dropdown
+     model_id = st.selectbox(
+         "Select Model",
+         options=st.session_state.models,
+         index=0,
+         help="Choose a model available on the vLLM server"
+     )
+
+     # System prompt input
+     system_prompt = st.text_area(
+         "System Prompt",
+         value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
+         height=100,
+         help="Enter a system prompt to guide the model's behavior (optional)"
+     )
+     if st.button("Apply System Prompt"):
+         if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
+             st.session_state.messages[0] = {"role": "system", "content": system_prompt}
+         else:
+             st.session_state.messages.insert(0, {"role": "system", "content": system_prompt})
+         st.success("System prompt updated!")
+
+     # Other settings
+     temperature = st.slider("Temperature", 0.0, 1.0, 0.2, help="Controls randomness of responses")
+     top_p = st.slider("Top P", 0.0, 1.0, 0.85, help="Controls diversity via nucleus sampling")
+     max_tokens = st.number_input("Max Tokens", min_value=1, value=512, help="Maximum length of response (optional)")
+     if st.button("Clear Chat"):
+         st.session_state.messages = (
+             [{"role": "system", "content": system_prompt}] if system_prompt else []
+         )
+
+ # Main chat interface
+ st.title("Financial Virtual Assistant")
+ st.write("Chat with a finance-tuned LLM powered by vLLM on Modal. Select a model and customize your system prompt!")
+
+ # Display chat history
+ for message in st.session_state.messages:
+     if message["role"] == "system":
+         with st.expander("System Prompt", expanded=False):
+             st.markdown(f"**System**: {message['content']}")
+     else:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+ # User input
+ if prompt := st.chat_input("Type your message here..."):
+     # Add user message to history
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     # Truncate messages if necessary to stay under token limit
+     st.session_state.messages = truncate_messages(
+         st.session_state.messages,
+         max_tokens=2048 - max_tokens,  # Reserve space for the output
+         keep_last_n=5
+     )
+     # Debug: Log token count and messages
+     current_tokens = estimate_token_count(st.session_state.messages)
+     st.write(f"Debug: Current token count: {current_tokens}")
+     st.write(f"Debug: Messages sent to model: {st.session_state.messages}")
+
+     # Get and display assistant response
+     with st.chat_message("assistant"):
+         response = get_completion(
+             st.session_state.client,
+             model_id,
+             st.session_state.messages,
+             stream=True,
+             temperature=temperature,
+             top_p=top_p,
+             max_tokens=max_tokens
+         )
+         if response:
+             # Stream the response
+             placeholder = st.empty()
+             assistant_message = ""
+             for chunk in response:
+                 if chunk.choices[0].delta.content:
+                     assistant_message += chunk.choices[0].delta.content
+                     placeholder.markdown(assistant_message + "▌")  # Cursor effect
+             placeholder.markdown(assistant_message)  # Final message without cursor
+             st.session_state.messages.append({"role": "assistant", "content": assistant_message})
+         else:
+             st.error("Failed to get a response from the server.")
+
+ # Instructions
+ st.caption("Built with Streamlit and vLLM on Modal. Adjust settings in the sidebar and chat away!")