FinVA_Demo / app.py
Brandon Toews
Fixed secure application file
ee845c8
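"""Financial Virtual Assistant - Streamlit chat frontend for a vLLM server on Modal.

Summary added for orientation (inferred from the code below): a simple
username/password gate, checked against the `username` and `password`
environment variables, protects a chat interface that talks to an
OpenAI-compatible vLLM endpoint using the `openai_api_key` environment
variable. Typically launched with `streamlit run app.py`.
"""
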
import streamlit as st
from openai import OpenAI
import os


# Authentication function
def authenticate():
    st.title("Financial Virtual Assistant")
    st.subheader("Login")
    username = st.text_input("Username")
    password = st.text_input("Password", type="password")
    if st.button("Login"):
        if username == os.getenv('username') and password == os.getenv('password'):
            st.session_state.authenticated = True
            return True
        else:
            st.error("Invalid username or password")
    return False
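
# Hardening sketch (assumption, not part of the original logic): comparing the
# submitted credentials in constant time avoids leaking information through
# timing differences. `hmac.compare_digest` is in the standard library.
#
# import hmac
# def credentials_match(user: str, pwd: str) -> bool:
#     return (hmac.compare_digest(user.encode(), os.getenv('username', '').encode())
#             and hmac.compare_digest(pwd.encode(), os.getenv('password', '').encode()))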

# Check authentication state
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False

if not st.session_state.authenticated:
    if authenticate():
        st.rerun()
else:
    # Streamlit page configuration
    st.set_page_config(page_title="Financial Virtual Assistant", layout="wide")

    # Initialize session state for chat history and API client
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "client" not in st.session_state:
        base_url = "https://brandontoews--vllm-openai-compatible-serve.modal.run/v1"
        # Initialize OpenAI client
        st.session_state.client = OpenAI(
            api_key=os.getenv('openai_api_key'),  # Replace with your API key or use modal.Secret
            base_url=base_url
        )
if "models" not in st.session_state:
# Fetch available models from the server
try:
models = st.session_state.client.models.list().data
st.session_state.models = [model.id for model in models]
except Exception as e:
st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"] # Fallback if fetch fails
st.warning(f"Failed to fetch models: {e}. Using default model.")

    # Function to estimate token count (heuristic: ~4 chars per token)
    def estimate_token_count(messages):
        total_chars = sum(len(message["content"]) for message in messages)
        return total_chars // 4  # Approximate: 4 characters per token
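    # Optional refinement (untested sketch, not in the original): count tokens
    # exactly with the model's own tokenizer instead of the 4-chars-per-token
    # heuristic. Assumes the `transformers` package is installed and the model
    # repo is reachable from this environment.
    #
    # from transformers import AutoTokenizer
    # _tokenizer = AutoTokenizer.from_pretrained(
    #     "neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"
    # )
    # def exact_token_count(messages):
    #     return sum(len(_tokenizer.encode(msg["content"])) for msg in messages)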

    # Function to truncate messages if token count exceeds limit
    def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
        # Always keep the system prompt (if present) and the last N messages
        system_prompt = [msg for msg in messages if msg["role"] == "system"]
        non_system_messages = [msg for msg in messages if msg["role"] != "system"]
        # Estimate current token count
        current_tokens = estimate_token_count(messages)
        # If under the limit, no truncation needed
        if current_tokens <= max_tokens:
            return messages
        # Truncate older non-system messages, keeping the last N
        truncated_non_system_messages = (
            non_system_messages[-keep_last_n:]
            if len(non_system_messages) > keep_last_n
            else non_system_messages
        )
        # Reconstruct messages: system prompt (if any) + truncated non-system messages
        return system_prompt + truncated_non_system_messages
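    # Note: truncation is a single cutoff to the last `keep_last_n` non-system
    # messages; if those are still long, the request can exceed `max_tokens`
    # because the count is only an estimate and is not re-checked afterwards.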

    # Function to get completion from vLLM server
    def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
        completion_args = {
            "model": model_id,
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "stream": stream,
        }
        try:
            response = client.chat.completions.create(**completion_args)
            return response
        except Exception as e:
            st.error(f"Error during API call: {e}")
            return None
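    # Usage note: with stream=True the return value is an iterator of chunks (as
    # consumed in the chat loop below); with stream=False the full reply would
    # instead be read from response.choices[0].message.content.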

    # Sidebar for configuration
    with st.sidebar:
        st.header("Chat Settings")
        # Model selection dropdown
        model_id = st.selectbox(
            "Select Model",
            options=st.session_state.models,
            index=0,
            help="Choose a model available on the vLLM server"
        )
        # System prompt input
        system_prompt = st.text_area(
            "System Prompt",
            value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
            height=100,
            help="Enter a system prompt to guide the model's behavior (optional)"
        )
        if st.button("Apply System Prompt"):
            if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
                st.session_state.messages[0] = {"role": "system", "content": system_prompt}
            else:
                st.session_state.messages.insert(0, {"role": "system", "content": system_prompt})
            st.success("System prompt updated!")
        # Other settings
        temperature = st.slider("Temperature", 0.0, 1.0, 0.2, help="Controls randomness of responses")
        top_p = st.slider("Top P", 0.0, 1.0, 0.85, help="Controls diversity via nucleus sampling")
        max_tokens = st.number_input("Max Tokens", min_value=1, value=512, help="Maximum length of response (optional)")
        if st.button("Clear Chat"):
            st.session_state.messages = (
                [{"role": "system", "content": system_prompt}] if system_prompt else []
            )

    # Main chat interface
    st.title("Financial Virtual Assistant")
    st.write("Chat with a finance-tuned LLM powered by vLLM on Modal. Select a model and customize your system prompt!")

    # Display chat history
    for message in st.session_state.messages:
        if message["role"] == "system":
            with st.expander("System Prompt", expanded=False):
                st.markdown(f"**System**: {message['content']}")
        else:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    # User input
    if prompt := st.chat_input("Type your message here..."):
        # Add user message to history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Truncate messages if necessary to stay under token limit
        st.session_state.messages = truncate_messages(
            st.session_state.messages,
            max_tokens=2048 - max_tokens,  # Reserve space for the output
            keep_last_n=5
        )

        # Debug: Log token count and messages
        current_tokens = estimate_token_count(st.session_state.messages)
        st.write(f"Debug: Current token count: {current_tokens}")
        st.write(f"Debug: Messages sent to model: {st.session_state.messages}")

        # Get and display assistant response
        with st.chat_message("assistant"):
            response = get_completion(
                st.session_state.client,
                model_id,
                st.session_state.messages,
                stream=True,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens
            )
            if response:
                # Stream the response chunk by chunk
                placeholder = st.empty()
                assistant_message = ""
                for chunk in response:
                    if chunk.choices[0].delta.content:
                        assistant_message += chunk.choices[0].delta.content
                        placeholder.markdown(assistant_message + "▌")  # Cursor effect
                placeholder.markdown(assistant_message)  # Final message without cursor
                st.session_state.messages.append({"role": "assistant", "content": assistant_message})
            else:
                st.error("Failed to get a response from the server.")

    # Instructions
    st.caption("Built with Streamlit and vLLM on Modal. Adjust settings in the sidebar and chat away!")