from openai import OpenAI
import streamlit as st
import numpy as np
from PIL import Image
from time import perf_counter
import itertools

# Page configuration
st.set_page_config(
    page_title="Unify Router Demo",
    page_icon="./assets/unify_spiral.png",
    layout="wide",
    initial_sidebar_state="collapsed",
)
router_avatar = np.array(Image.open("./assets/unify_spiral.png"))

# Custom font
with open("./style.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

# Info message
st.info(
    body="This demo is only a preview. Check out our [Chat UI](https://unify.ai/chat) for the full experience, including more endpoints, and extra customization!",
    icon="ℹ️"
)

# Parameter choices
strategies = {
    "🏃 fastest": "tks-per-sec",
    "⌛ most responsive": "ttft",
    "💵 cheapest": "input-cost",
}
models = {
    "🦙 Llama2 70B Chat": "llama-2-70b-chat",
    "💨 Mixtral 8x7B Instruct": "mixtral-8x7b-instruct-v0.1",
    "💎 Gemma 7B": "gemma-7b-it",
}
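
# Note: the router is addressed like a single model whose name follows the
# "<model>@<strategy>" convention, e.g. "llama-2-70b-chat@ttft"; this string
# is assembled with "@".join() in the chat completion request below.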

# Body
params_col, chat_col = st.columns([1, 3])

with params_col:

    st.image(
        "./assets/unify_logo.png",
        use_column_width="auto",
    )
    st.markdown("Send your prompts to the best LLM endpoint and optimize performance, all with a **single API**")

    strategy = st.selectbox(
        label="I want the",
        options=tuple(strategies.keys()),
        help="Choose the metric to optimize the routing for. "
             "Fastest picks the endpoint with the highest output tokens per second. "
             "Most responsive picks the endpoint with the lowest time to first token. "
             "Cheapest picks the endpoint with the lowest input token cost.",
    )
    model = st.selectbox(
        label="endpoint for",
        options=tuple(models.keys()),
        help="Select a model to optimize for. The same model can be offered by "
             "different endpoint providers. The router lets you find the optimal "
             "endpoint for your chosen model, target metric, and input prompt.",
    )
    with st.expander("Advanced Inputs"):
        max_tokens = st.slider(
            label="Maximum Number Of Tokens",
            min_value=100,
            max_value=2000,
            value=500,
            step=100,
            help="The maximum number of tokens that can be generated.",
        )
        temperature = st.slider(
            label="Temperature",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.5,
            help="The model's output randomness. Higher values give more random outputs.",
        )

with chat_col:

    # Initializing the message history state and an empty chat container
    if "messages" not in st.session_state:
        st.session_state.messages = []
    msgs = st.container(height=350)

    # Writing conversation history
    for msg in st.session_state.messages:
        if msg["role"] == "user":
            msgs.chat_message(msg["role"]).write(msg["content"])
        else:
            msgs.chat_message(msg["role"], avatar=router_avatar).write(msg["content"])

    # Preparing the OpenAI-compatible client, pointed at the Unify API
    client = OpenAI(
        base_url="https://api.unify.ai/v0/",
        api_key=st.secrets["UNIFY_API"]
    )
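    # Note: st.secrets reads from .streamlit/secrets.toml, so a UNIFY_API
    # entry must be defined there for the client above to authenticate.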

    # Processing prompt box input
    if prompt := st.chat_input("Enter your prompt..."):

        # Displaying the user prompt and saving it in the message history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with msgs.chat_message("user"):
            st.write(prompt)
    
        # Displaying the output and metrics, and saving the output in the
        # message history
        with msgs.status("Routing your prompt...", expanded=True):
            # Sending the prompt to the routed model endpoint
            start = perf_counter()
            stream = client.chat.completions.create(
                model="@".join([models[model], strategies[strategy]]),
                messages=[
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
                stream=True,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            # Writing the answer progressively. tee() duplicates the stream so
            # one copy can be rendered while the other is kept for the metrics
            # computed below.
            stream, stream_copy = itertools.tee(stream)
            st.write_stream(stream)
            chunks = list(stream_copy)
            # The timer is stopped only after the stream has been fully
            # consumed; stopping it right after create() would measure the
            # request setup time, not the time to complete the response.
            time_to_completion = round(perf_counter() - start, 2)

            # Computing metrics from the usage data attached to the final
            # stream chunk
            last_chunk = chunks[-1]
            cost = round(last_chunk.usage["cost"], 6)
            output_tokens = last_chunk.usage["completion_tokens"]
            tokens_per_second = round(output_tokens / time_to_completion, 2)

            # Displaying model, provider, and metrics. The provider name is
            # recovered from the "<model>@<provider>" id echoed back by the API.
            provider = " ".join(chunks[0].model.split("@")[-1].split("-")).title()
            provider = provider.replace(" Ai", " AI")
            st.markdown(f"Model: **{model}**. Provider: **{provider}**")
            st.markdown(
                f"**{tokens_per_second}** Tokens Per Second - \
                  **{time_to_completion}** Seconds to complete - \
                  **{cost:.6f}** $"
                )

        # Saving the output to the message history
        output_chunks = [chunk.choices[0].delta.content or "" for chunk in chunks]
        response = "".join(output_chunks)
        st.session_state.messages.append({"role": "assistant", "content": response})

    # Clear chat button
    if st.button("Clear Chat", key="clear"):
        msgs.empty()
        st.session_state.messages = []
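
# To run the demo locally (assuming this file is saved as app.py and the
# UNIFY_API key is set in .streamlit/secrets.toml):
#
#   streamlit run app.py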