import streamlit as st
from openai import OpenAI
import os


# Authentication function
def authenticate():
    st.title("Financial Virtual Assistant")
    st.subheader("Login")
    username = st.text_input("Username")
    password = st.text_input("Password", type="password")
    if st.button("Login"):
        if username == os.getenv('username') and password == os.getenv('password'):
            st.session_state.authenticated = True
            return True
        else:
            st.error("Invalid username or password")
    return False


# Check authentication state
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False

if not st.session_state.authenticated:
    if authenticate():
        st.rerun()
else:
    # Streamlit page configuration (first Streamlit call on an authenticated run)
    st.set_page_config(page_title="Financial Virtual Assistant", layout="wide")

    # Initialize session state for chat history and API client
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "client" not in st.session_state:
        base_url = "https://brandontoews--vllm-openai-compatible-serve.modal.run/v1"
        # Initialize OpenAI client pointed at the vLLM server
        st.session_state.client = OpenAI(
            api_key=os.getenv('openai_api_key'),  # Replace with your API key or use modal.Secret
            base_url=base_url,
        )
    if "models" not in st.session_state:
        # Fetch available models from the server
        try:
            models = st.session_state.client.models.list().data
            st.session_state.models = [model.id for model in models]
        except Exception as e:
            # Fallback if fetch fails
            st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"]
            st.warning(f"Failed to fetch models: {e}. Using default model.")

    # Function to estimate token count (heuristic: ~4 chars per token)
    def estimate_token_count(messages):
        total_chars = sum(len(message["content"]) for message in messages)
        return total_chars // 4  # Approximate: 4 characters per token

    # Function to truncate messages if token count exceeds limit
    def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
        # Always keep the system prompt (if present) and the last N messages
        system_prompt = [msg for msg in messages if msg["role"] == "system"]
        non_system_messages = [msg for msg in messages if msg["role"] != "system"]

        # Estimate current token count
        current_tokens = estimate_token_count(messages)

        # If under the limit, no truncation needed
        if current_tokens <= max_tokens:
            return messages

        # Truncate older non-system messages, keeping the last N
        if len(non_system_messages) > keep_last_n:
            truncated_non_system_messages = non_system_messages[-keep_last_n:]
        else:
            truncated_non_system_messages = non_system_messages

        # Reconstruct messages: system prompt (if any) + truncated non-system messages
        return system_prompt + truncated_non_system_messages

    # Function to get completion from vLLM server
    def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
        completion_args = {
            "model": model_id,
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "stream": stream,
        }
        try:
            return client.chat.completions.create(**completion_args)
        except Exception as e:
            st.error(f"Error during API call: {e}")
            return None
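    # Optional: a more accurate counter than the ~4-chars-per-token heuristic
    # above. A minimal sketch, assuming the `transformers` package is installed
    # and the fallback model's tokenizer is downloadable from the Hugging Face
    # Hub; the app does not call this helper, it is a drop-in replacement for
    # estimate_token_count. Note it counts message bodies only and ignores
    # chat-template overhead.
    try:
        from transformers import AutoTokenizer

        _tokenizer = AutoTokenizer.from_pretrained(
            "neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"
        )

        def estimate_token_count_exact(messages):
            # Sum the tokenized length of every message body
            return sum(len(_tokenizer.encode(msg["content"])) for msg in messages)
    except ImportError:
        pass  # transformers not installed; keep the heuristic estimate_token_count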
    # Sidebar for configuration
    with st.sidebar:
        st.header("Chat Settings")

        # Model selection dropdown
        model_id = st.selectbox(
            "Select Model",
            options=st.session_state.models,
            index=0,
            help="Choose a model available on the vLLM server",
        )

        # System prompt input
        system_prompt = st.text_area(
            "System Prompt",
            value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
            height=100,
            help="Enter a system prompt to guide the model's behavior (optional)",
        )
        if st.button("Apply System Prompt"):
            if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
                st.session_state.messages[0] = {"role": "system", "content": system_prompt}
            else:
                st.session_state.messages.insert(0, {"role": "system", "content": system_prompt})
            st.success("System prompt updated!")

        # Other settings
        temperature = st.slider("Temperature", 0.0, 1.0, 0.2, help="Controls randomness of responses")
        top_p = st.slider("Top P", 0.0, 1.0, 0.85, help="Controls diversity via nucleus sampling")
        max_tokens = st.number_input("Max Tokens", min_value=1, value=512, help="Maximum length of response (optional)")
        if st.button("Clear Chat"):
            st.session_state.messages = (
                [{"role": "system", "content": system_prompt}] if system_prompt else []
            )

    # Main chat interface
    st.title("Financial Virtual Assistant")
    st.write("Chat with a finance-tuned LLM powered by vLLM on Modal. Select a model and customize your system prompt!")

    # Display chat history
    for message in st.session_state.messages:
        if message["role"] == "system":
            with st.expander("System Prompt", expanded=False):
                st.markdown(f"**System**: {message['content']}")
        else:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    # User input
    if prompt := st.chat_input("Type your message here..."):
        # Add user message to history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Truncate messages if necessary to stay under the token limit
        st.session_state.messages = truncate_messages(
            st.session_state.messages,
            max_tokens=2048 - max_tokens,  # Reserve space for the output
            keep_last_n=5,
        )

        # Debug: log token count and messages
        current_tokens = estimate_token_count(st.session_state.messages)
        st.write(f"Debug: Current token count: {current_tokens}")
        st.write(f"Debug: Messages sent to model: {st.session_state.messages}")

        # Get and display assistant response
        with st.chat_message("assistant"):
            response = get_completion(
                st.session_state.client,
                model_id,
                st.session_state.messages,
                stream=True,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
            )
            if response:
                # Stream the response into a placeholder as chunks arrive
                placeholder = st.empty()
                assistant_message = ""
                for chunk in response:
                    # Guard against chunks with no choices (e.g. a trailing usage chunk)
                    if chunk.choices and chunk.choices[0].delta.content:
                        assistant_message += chunk.choices[0].delta.content
                        placeholder.markdown(assistant_message + "▌")  # Cursor effect
                placeholder.markdown(assistant_message)  # Final message without cursor
                st.session_state.messages.append({"role": "assistant", "content": assistant_message})
            else:
                st.error("Failed to get a response from the server.")

    # Instructions
    st.caption("Built with Streamlit and vLLM on Modal. Adjust settings in the sidebar and chat away!")
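# How to run, as a sketch: the filename and credential values below are
# assumptions, not from the source. The app reads `username`, `password`, and
# `openai_api_key` from the environment (see authenticate() and the OpenAI
# client setup above), so export them before launching with Streamlit:
#
#   export username=admin password=change-me openai_api_key=your-key
#   streamlit run app.py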