import streamlit as st
from openai import OpenAI
import os
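# NOTE: login credentials and the API key are read from environment variables
# ('username', 'password', 'openai_api_key'), which are assumed to be provided
# by the deployment (e.g. as Space secrets or a modal.Secret).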
# Authentication function
def authenticate():
    st.title("Financial Virtual Assistant")
    st.subheader("Login")
    username = st.text_input("Username")
    password = st.text_input("Password", type="password")
    if st.button("Login"):
        if username == os.getenv('username') and password == os.getenv('password'):
            st.session_state.authenticated = True
            return True
        else:
            st.error("Invalid username or password")
    return False

# Check authentication state
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False

if not st.session_state.authenticated:
    if authenticate():
        st.rerun()
else:
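    # Main application (only reached after a successful login)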
    # Streamlit page configuration
    st.set_page_config(page_title="Financial Virtual Assistant", layout="wide")

    # Initialize session state for chat history and API client
    if "messages" not in st.session_state:
        st.session_state.messages = []

    if "client" not in st.session_state:
        base_url = "https://brandontoews--vllm-openai-compatible-serve.modal.run/v1"
        # Initialize OpenAI client
        st.session_state.client = OpenAI(
            api_key=os.getenv('openai_api_key'),  # Replace with your API key or use modal.Secret
            base_url=base_url
        )

    if "models" not in st.session_state:
        # Fetch available models from the server
        try:
            models = st.session_state.client.models.list().data
            st.session_state.models = [model.id for model in models]
        except Exception as e:
            st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"]  # Fallback if fetch fails
            st.warning(f"Failed to fetch models: {e}. Using default model.")
    # Function to estimate token count (heuristic: ~4 chars per token)
    def estimate_token_count(messages):
        total_chars = sum(len(message["content"]) for message in messages)
        return total_chars // 4  # Approximate: 4 characters per token

    # Function to truncate messages if token count exceeds limit
    def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
        # Always keep the system prompt (if present) and the last N messages
        system_prompt = [msg for msg in messages if msg["role"] == "system"]
        non_system_messages = [msg for msg in messages if msg["role"] != "system"]
        # Estimate current token count
        current_tokens = estimate_token_count(messages)
        # If under the limit, no truncation needed
        if current_tokens <= max_tokens:
            return messages
        # Truncate older non-system messages, keeping the last N
        truncated_non_system_messages = (
            non_system_messages[-keep_last_n:]
            if len(non_system_messages) > keep_last_n
            else non_system_messages
        )
        # Reconstruct messages: system prompt (if any) + truncated non-system messages
        return system_prompt + truncated_non_system_messages
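    # For example, with keep_last_n=5 a long conversation collapses to the system
    # prompt plus the five most recent messages once the estimated token count
    # exceeds max_tokens.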
    # Function to get completion from vLLM server
    def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
        completion_args = {
            "model": model_id,
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "stream": stream,
        }
        try:
            response = client.chat.completions.create(**completion_args)
            return response
        except Exception as e:
            st.error(f"Error during API call: {e}")
            return None
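    # With stream=True the call returns an iterator of completion chunks; on any
    # API error it returns None, so callers check the result before streaming it.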
    # Sidebar for configuration
    with st.sidebar:
        st.header("Chat Settings")

        # Model selection dropdown
        model_id = st.selectbox(
            "Select Model",
            options=st.session_state.models,
            index=0,
            help="Choose a model available on the vLLM server"
        )

        # System prompt input
        system_prompt = st.text_area(
            "System Prompt",
            value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
            height=100,
            help="Enter a system prompt to guide the model's behavior (optional)"
        )
        if st.button("Apply System Prompt"):
            if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
                st.session_state.messages[0] = {"role": "system", "content": system_prompt}
            else:
                st.session_state.messages.insert(0, {"role": "system", "content": system_prompt})
            st.success("System prompt updated!")

        # Other settings
        temperature = st.slider("Temperature", 0.0, 1.0, 0.2, help="Controls randomness of responses")
        top_p = st.slider("Top P", 0.0, 1.0, 0.85, help="Controls diversity via nucleus sampling")
        max_tokens = st.number_input("Max Tokens", min_value=1, value=512, help="Maximum length of response (optional)")

        if st.button("Clear Chat"):
            st.session_state.messages = (
                [{"role": "system", "content": system_prompt}] if system_prompt else []
            )
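    # The sidebar selections (model_id, temperature, top_p, max_tokens) are re-read
    # on every Streamlit rerun and passed directly to the completion request below.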
    # Main chat interface
    st.title("Financial Virtual Assistant")
    st.write("Chat with a finance-tuned LLM powered by vLLM on Modal. Select a model and customize your system prompt!")

    # Display chat history
    for message in st.session_state.messages:
        if message["role"] == "system":
            with st.expander("System Prompt", expanded=False):
                st.markdown(f"**System**: {message['content']}")
        else:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
    # User input
    if prompt := st.chat_input("Type your message here..."):
        # Add user message to history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Truncate messages if necessary to stay under token limit
        st.session_state.messages = truncate_messages(
            st.session_state.messages,
            max_tokens=2048 - max_tokens,  # Reserve space for the output
            keep_last_n=5
        )

        # Debug: Log token count and messages
        current_tokens = estimate_token_count(st.session_state.messages)
        st.write(f"Debug: Current token count: {current_tokens}")
        st.write(f"Debug: Messages sent to model: {st.session_state.messages}")

        # Get and display assistant response
        with st.chat_message("assistant"):
            response = get_completion(
                st.session_state.client,
                model_id,
                st.session_state.messages,
                stream=True,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens
            )
            if response:
                # Stream the response
                placeholder = st.empty()
                assistant_message = ""
                for chunk in response:
                    if chunk.choices[0].delta.content:
                        assistant_message += chunk.choices[0].delta.content
                        placeholder.markdown(assistant_message + "▌")  # Cursor effect
                placeholder.markdown(assistant_message)  # Final message without cursor
                st.session_state.messages.append({"role": "assistant", "content": assistant_message})
            else:
                st.error("Failed to get a response from the server.")
    # Instructions
    st.caption("Built with Streamlit and vLLM on Modal. Adjust settings in the sidebar and chat away!")