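# Financial Virtual Assistant: a Streamlit chat front end for a finance-tuned
# LLM served by vLLM on Modal, with a simple login gate, model selection, a
# configurable system prompt, and streaming responses.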
import streamlit as st
from openai import OpenAI
import os

# Streamlit page configuration: set_page_config must be the first Streamlit
# command in the script, so it is applied here (this also styles the login page)
st.set_page_config(page_title="Financial Virtual Assistant", layout="wide")


# Authentication function
def authenticate():
    st.title("Financial Virtual Assistant")
    st.subheader("Login")

    username = st.text_input("Username")
    password = st.text_input("Password", type="password")

    if st.button("Login"):
        if username == os.getenv('username') and password == os.getenv('password'):
            st.session_state.authenticated = True
            return True
        else:
            st.error("Invalid username or password")
    return False
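
# Note: the credential check above compares against lowercase "username" and
# "password" environment variables, which must be set for login to succeed.
# For deployed apps, Streamlit's secrets store is a common alternative
# (assuming a .streamlit/secrets.toml defines the same keys):
#   username == st.secrets["username"] and password == st.secrets["password"]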


# Check authentication state
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False

if not st.session_state.authenticated:
    if authenticate():
        st.rerun()
else:

    # Initialize session state for chat history and API client
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "client" not in st.session_state:
        base_url = f"https://brandontoews--vllm-openai-compatible-serve.modal.run/v1"

        # Initialize OpenAI client
        st.session_state.client = OpenAI(
            api_key=os.getenv('openai_api_key'),  # Replace with your API key or use modal.Secret
            base_url=base_url
        )
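
        # The vLLM server speaks the standard OpenAI REST API, so the same
        # endpoint can be probed directly, e.g.:
        #   curl https://brandontoews--vllm-openai-compatible-serve.modal.run/v1/models \
        #        -H "Authorization: Bearer $openai_api_key"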
    if "models" not in st.session_state:
        # Fetch available models from the server
        try:
            models = st.session_state.client.models.list().data
            st.session_state.models = [model.id for model in models]
        except Exception as e:
            st.session_state.models = ["neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16"]  # Fallback if fetch fails
            st.warning(f"Failed to fetch models: {e}. Using default model.")


    # Function to estimate token count (heuristic: ~4 chars per token)
    def estimate_token_count(messages):
        total_chars = sum(len(message["content"]) for message in messages)
        return total_chars // 4  # Approximate: 4 characters per token
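
    # The ~4 chars/token heuristic is rough. For an exact count, one could use
    # the model's own tokenizer (assuming the transformers package is installed):
    #   from transformers import AutoTokenizer
    #   tok = AutoTokenizer.from_pretrained("neuralmagic/Mistral-7B-Instruct-v0.3-quantized.w8a16")
    #   tokens = sum(len(tok.encode(m["content"])) for m in messages)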


    # Function to truncate messages if token count exceeds limit
    def truncate_messages(messages, max_tokens=2048, keep_last_n=5):
        # Always keep the system prompt (if present) and the last N messages
        system_prompt = [msg for msg in messages if msg["role"] == "system"]
        non_system_messages = [msg for msg in messages if msg["role"] != "system"]

        # Estimate current token count
        current_tokens = estimate_token_count(messages)

        # If under the limit, no truncation needed
        if current_tokens <= max_tokens:
            return messages

        # Keep only the last N non-system messages; Python slicing returns the
        # whole list unchanged when it has fewer than keep_last_n items
        truncated_non_system_messages = non_system_messages[-keep_last_n:]

        # Reconstruct messages: system prompt (if any) + truncated non-system messages
        return system_prompt + truncated_non_system_messages
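
    # Example: with keep_last_n=5, an over-budget history of
    #   [system, u1, a1, u2, a2, u3, a3]
    # becomes
    #   [system, a1, u2, a2, u3, a3]
    # This is best-effort: the kept messages may still exceed max_tokens if
    # they are individually long.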


    # Function to get completion from vLLM server
    def get_completion(client, model_id, messages, stream=True, temperature=0.2, top_p=0.85, max_tokens=512):
        completion_args = {
            "model": model_id,
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "stream": stream,
        }
        try:
            response = client.chat.completions.create(**completion_args)
            return response
        except Exception as e:
            st.error(f"Error during API call: {e}")
            return None
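
    # With stream=True the call returns an iterator of chunks, consumed in the
    # chat loop below; with stream=False it would return a single completion
    # whose text is at response.choices[0].message.content.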


    # Sidebar for configuration
    with st.sidebar:
        st.header("Chat Settings")

        # Model selection dropdown
        model_id = st.selectbox(
            "Select Model",
            options=st.session_state.models,
            index=0,
            help="Choose a model available on the vLLM server"
        )

        # System prompt input
        system_prompt = st.text_area(
            "System Prompt",
            value="You are a finance expert, providing clear, accurate, and concise answers to financial questions.",
            height=100,
            help="Enter a system prompt to guide the model's behavior (optional)"
        )
        if st.button("Apply System Prompt"):
            if st.session_state.messages and st.session_state.messages[0]["role"] == "system":
                st.session_state.messages[0] = {"role": "system", "content": system_prompt}
            else:
                st.session_state.messages.insert(0, {"role": "system", "content": system_prompt})
            st.success("System prompt updated!")

        # Other settings
        temperature = st.slider("Temperature", 0.0, 1.0, 0.2, help="Controls randomness of responses")
        top_p = st.slider("Top P", 0.0, 1.0, 0.85, help="Controls diversity via nucleus sampling")
        max_tokens = st.number_input("Max Tokens", min_value=1, value=512, help="Maximum number of tokens in the generated response")
        if st.button("Clear Chat"):
            st.session_state.messages = (
                [{"role": "system", "content": system_prompt}] if system_prompt else []
            )

    # Main chat interface
    st.title("Financial Virtual Assistant")
    st.write("Chat with a finance-tuned LLM powered by vLLM on Modal. Select a model and customize your system prompt!")

    # Display chat history
    for message in st.session_state.messages:
        if message["role"] == "system":
            with st.expander("System Prompt", expanded=False):
                st.markdown(f"**System**: {message['content']}")
        else:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    # User input
    if prompt := st.chat_input("Type your message here..."):
        # Add user message to history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Truncate messages if necessary to stay under token limit
        st.session_state.messages = truncate_messages(
            st.session_state.messages,
            max_tokens=2048 - max_tokens,  # Reserve space for the output
            keep_last_n=5
        )
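        # e.g. with the default max_tokens=512, the prompt budget above is
        # 2048 - 512 = 1536 estimated tokens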
        # Debug: Log token count and messages
        current_tokens = estimate_token_count(st.session_state.messages)
        st.write(f"Debug: Current token count: {current_tokens}")
        st.write(f"Debug: Messages sent to model: {st.session_state.messages}")

        # Get and display assistant response
        with st.chat_message("assistant"):
            response = get_completion(
                st.session_state.client,
                model_id,
                st.session_state.messages,
                stream=True,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens
            )
            if response:
                # Stream the response
                placeholder = st.empty()
                assistant_message = ""
                for chunk in response:
                    if chunk.choices and chunk.choices[0].delta.content:
                        assistant_message += chunk.choices[0].delta.content
                        placeholder.markdown(assistant_message + "▌")  # Cursor effect
                placeholder.markdown(assistant_message)  # Final message without cursor
                st.session_state.messages.append({"role": "assistant", "content": assistant_message})
            else:
                st.error("Failed to get a response from the server.")

    # Instructions
    st.caption("Built with Streamlit and vLLM on Modal. Adjust settings in the sidebar and chat away!")