import datetime

from openai import OpenAI
import gradio as gr

from utils import COMMUNITY_POSTFIX_URL, get_model_config, log_message, check_format, models_config

print(f"Gradio version: {gr.__version__}")

DEFAULT_MODEL_NAME = "Apriel-Nemotron-15b-Thinker"


chat_start_count = 0
model_config = None
client = None


def setup_model(model_name, initial=False):
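    """Configure the global OpenAI client for `model_name` using its vLLM endpoint.

    Returns the HTML feedback-link description, or None when `initial` is True.
    """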
    global model_config, client
    model_config = get_model_config(model_name)
    log_message(f"update_model() --> Model config: {model_config}")
    client = OpenAI(
        api_key=model_config.get('AUTH_TOKEN'),
        base_url=model_config.get('VLLM_API_URL')
    )

    _model_hf_name = model_config.get("MODEL_HF_URL").split('https://huggingface.co/')[1]
    _link = f"<a href='{model_config.get('MODEL_HF_URL')}{COMMUNITY_POSTFIX_URL}' target='_blank'>{_model_hf_name}</a>"
    _description = f"Please use the community section on this space to provide feedback! {_link}"

    print(f"Switched to model {_model_hf_name}")

    return None if initial else _description


def chat_fn(message, history):
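    """Stream a chat completion for `message`, given the Gradio `history`.

    For reasoning models, chain-of-thought is rendered in a collapsible
    "Thought" bubble and the final answer in a separate assistant message;
    each iteration yields the most recent assistant message(s) for re-render.
    """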
    log_message(f"{'-' * 80}")
    log_message(f"chat_fn() --> Message: {message}")
    log_message(f"chat_fn() --> History: {history}")

    global chat_start_count
    chat_start_count += 1
    print(f"{datetime.datetime.now()}: chat_start_count: {chat_start_count}, "
          f"turns: {len(history or []) // 3}")

    is_reasoning = model_config.get("REASONING")

    # Remove any assistant messages with metadata from history for multiple turns
    log_message(f"Original History: {history}")
    check_format(history, "messages")
    history = [item for item in history if
               not (isinstance(item, dict) and
                    item.get("role") == "assistant" and
                    isinstance(item.get("metadata"), dict) and
                    item.get("metadata", {}).get("title") is not None)]
    log_message(f"Updated History: {history}")
    check_format(history, "messages")

    history.append({"role": "user", "content": message})
    log_message(f"History with user message: {history}")
    check_format(history, "messages")

    # Create the streaming response
    try:
        stream = client.chat.completions.create(
            model=model_config.get('MODEL_NAME'),
            messages=history,
            temperature=0.8,
            stream=True
        )
    except Exception as e:
        print(f"Error: {e}")
        yield gr.ChatMessage(
            role="assistant",
            content="😔 The model is unavailable at the moment. Please try again later.",
        )
        return

    if is_reasoning:
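        # Placeholder "Thought" bubble; its content is overwritten as
        # reasoning tokens stream in below.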
        history.append(gr.ChatMessage(
            role="assistant",
            content="Thinking...",
            metadata={"title": "🧠 Thought"}
        ))
        log_message(f"History added thinking: {history}")
        check_format(history, "messages")

    output = ""
    completion_started = False
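    # Reasoning models emit chain-of-thought first, then wrap the answer in
    # [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE]; we split on the opening
    # marker to route text to the thought bubble vs. the visible reply.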
    for chunk in stream:
        # Extract the new content from the delta field
        content = chunk.choices[0].delta.content or ""  # delta.content is None on role-only chunks
        output += content

        if is_reasoning:
            parts = output.split("[BEGIN FINAL RESPONSE]")

            if len(parts) > 1:
                # Strip closing markers as they stream in; removesuffix() only
                # trims the end, unlike replace(), which could hit mid-string text.
                for marker in ("[END FINAL RESPONSE]\n<|end|>", "[END FINAL RESPONSE]", "<|end|>"):
                    if parts[1].endswith(marker):
                        parts[1] = parts[1].removesuffix(marker)

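            # The thought bubble lives at history[-1] until the final response
            # starts; afterwards it is history[-2], with the answer at history[-1].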
            history[-1 if not completion_started else -2] = gr.ChatMessage(
                role="assistant",
                content=parts[0],
                metadata={"title": "🧠 Thought"}
            )
            if completion_started:
                history[-1] = gr.ChatMessage(
                    role="assistant",
                    content=parts[1]
                )
            elif len(parts) > 1 and not completion_started:
                completion_started = True
                history.append(gr.ChatMessage(
                    role="assistant",
                    content=parts[1]
                ))
        else:
            if output.endswith("<|end|>"):
                output = output.replace("<|end|>", "")
            history[-1] = gr.ChatMessage(
                role="assistant",
                content=output
            )

        # only yield the most recent assistant messages
        messages_to_yield = history[-1:] if not completion_started else history[-2:]
        # check_format(messages_to_yield, "messages")
        # log_message(f"Yielding messages: {messages_to_yield}")
        yield messages_to_yield

    log_message(f"Final History: {history}")
    check_format(history, "messages")


title = None
description = None

with gr.Blocks(theme=gr.themes.Default(primary_hue="green")) as demo:
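    # Inline CSS: right-align the model description, and stack the header row
    # vertically on screens narrower than 800px.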
    gr.HTML("""
    <style>
        .model-message {
            text-align: end;
        }
    
        .model-dropdown-container {
            display: flex;
            align-items: center;
            gap: 10px;
            padding: 0;
        }
        
        .chatbot {
            max-height: 1400px;
        }
        
        @media (max-width: 800px) {
            .responsive-row {
                flex-direction: column;
            }
            .model-message {
                text-align: start;
            }
            .model-dropdown-container {
                flex-direction: column;
                align-items: flex-start;
            }
            .chatbot {
                max-height: 900px;
            }
        }
    """)

    with gr.Row(variant="panel", elem_classes="responsive-row"):
        with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"):
            model_dropdown = gr.Dropdown(
                choices=[f"Model: {model}" for model in models_config.keys()],
                value=f"Model: {DEFAULT_MODEL_NAME}",
                label=None,
                interactive=True,
                container=False,
                scale=0,
                min_width=400
            )
        with gr.Column(scale=4, min_width=0):
            description_html = gr.HTML(description, elem_classes="model-message")

    chatbot = gr.Chatbot(
        type="messages",
        height="calc(100dvh - 280px)",
        elem_classes="chatbot",
    )

    chat_interface = gr.ChatInterface(
        chat_fn,
        description="",
        type="messages",
        chatbot=chatbot,
        fill_height=True,
    )

    # Reset the model (and description link) to the default on page load/reload
    demo.load(lambda: setup_model(DEFAULT_MODEL_NAME, initial=False), [], [description_html])


    def update_model_and_clear(model_name):
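        """Switch to the selected model and clear the current conversation."""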
        # Remove the "Model: " prefix to get the actual model name
        actual_model_name = model_name.replace("Model: ", "")
        desc = setup_model(actual_model_name)
        chatbot.clear()  # Clear the conversation so turns from the old model don't leak into the new one
        return desc


    model_dropdown.change(
        fn=update_model_and_clear,
        inputs=[model_dropdown],
        outputs=[description_html]
    )

demo.launch(ssr_mode=False)