# app.py
import json
import threading

import requests
import sseclient  # SSE parsing for the streamed response (the "sseclient-py" package)
import gradio as gr

from server import setup_mixinputs, launch_vllm_server

API_URL = "http://localhost:8000/v1/chat/completions"
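# OpenAI-compatible chat endpoint exposed by the local vLLM server (default port 8000)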


def stream_completion(message, history, max_tokens, temperature, top_p, beta):
    """
    Gradio callback: takes the newest user message + full chat history,
    returns an updated history while streaming assistant tokens.
    """
    # ------- build OpenAI-style message list (no system prompt) -------------
    messages = []
    for usr, bot in history:
        if usr:
            messages.append({"role": "user", "content": usr})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": message})

    payload = {
        "model": "Qwen/Qwen3-4B",
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": int(max_tokens),
        "stream": True,
    }
    headers = {
        "Content-Type": "application/json",
        "X-MIXINPUTS-BETA": str(beta),
    }

    try:
        resp = requests.post(API_URL, json=payload, stream=True, headers=headers, timeout=60)
        resp.raise_for_status()
        client = sseclient.SSEClient(resp)

        assistant = ""
        for event in client.events():
            if event.data.strip() == "[DONE]":
                break
            delta = json.loads(event.data)["choices"][0]["delta"].get("content") or ""
            assistant += delta
            yield history + [(message, assistant)]  # update the chat box live

    except Exception as err:
        yield history + [(message, f"[ERROR] {err}")]


# ----------------------- UI ---------------------------------------------
with gr.Blocks(title="🎨 Mixture of Inputs (MoI) Demo") as demo:
    gr.Markdown(
        "## 🎨 Mixture of Inputs (MoI) Demo  \n"
        "Streaming vLLM demo with dynamic **beta** adjustment in MoI, higher beta means less blending."
    )

    # sliders first – all on one row
    with gr.Row():
        beta = gr.Slider(0.0, 10.0, value=1.0, step=0.1, label="MoI Beta")
        temperature = gr.Slider(0.1, 1.0, value=0.6, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.80, step=0.05, label="Top-p")
        max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max new tokens")


    chatbot = gr.Chatbot(height=450)
    user_box = gr.Textbox(placeholder="Type a message and press Enter…", show_label=False)
    clear_btn = gr.Button("Clear chat")

    # wiring
    user_box.submit(
        stream_completion,
        inputs=[user_box, chatbot, max_tokens, temperature, top_p, beta],
        outputs=chatbot,
    )
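    # optional: clear the input box once the message has been submitted
    user_box.submit(lambda: "", None, user_box, queue=False)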
    clear_btn.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    setup_mixinputs()
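    # run the vLLM server in a background daemon thread so the Gradio UI can start immediately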
    threading.Thread(target=launch_vllm_server, daemon=True).start()
    demo.launch()