File size: 3,248 Bytes
c1965a3
39209e4
4c07c1e
ab0aa8d
4c07c1e
c1965a3
39209e4
 
 
 
 
 
 
f873ce7
4c07c1e
a846510
c1965a3
f873ce7
a846510
 
c1965a3
f873ce7
a846510
f873ce7
 
c1965a3
f873ce7
 
a846510
39209e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c07c1e
39209e4
 
c1965a3
39209e4
 
 
c1965a3
 
 
39209e4
c2ec273
c1965a3
f31f69d
39209e4
 
c1965a3
 
39209e4
 
 
 
 
4c07c1e
39209e4
 
c1965a3
4c07c1e
c1965a3
39209e4
c1965a3
 
 
 
a846510
f873ce7
c1965a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# app.py
import json, requests, gradio as gr

# OpenAI-compatible chat-completions endpoint of the local vLLM server
# (see the UI description below); adjust host/port for your deployment.
API_URL = "http://0.0.0.0:8000/v1/chat/completions"

def stream_completion(message, history, max_tokens, temperature, top_p, beta):
    """Gradio callback: stream the assistant's reply token-by-token.

    Parameters
    ----------
    message : str
        The new user message to send.
    history : list[tuple[str, str]]
        Prior (user, assistant) turn pairs from the Gradio Chatbot.
    max_tokens : float
        Maximum number of new tokens to generate (cast to int for the API).
    temperature, top_p : float
        Sampling controls forwarded unchanged in the request payload.
    beta : float
        MoI blending coefficient, sent via the X-MIXINPUTS-BETA header.

    Yields
    ------
    list[tuple[str, str]]
        Updated chat history including the partially streamed reply, so
        Gradio re-renders the chat after every received token.
    """
    # -------- build OpenAI-style message list (no system prompt) -------------
    # History arrives as (user, assistant) pairs; flatten into alternating
    # role dicts, skipping empty strings (e.g. a turn with no reply yet).
    # NOTE: the previous comprehension tuple-unpacked each flattened STRING
    # as (u, _), which raises ValueError for any message not exactly 2 chars
    # long (and truncated 2-char messages to their first character).
    messages = []
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    payload = {
        "model": "Qwen/Qwen3-4B",
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": int(max_tokens),
        "stream": True,
    }
    headers = {
        "Content-Type": "application/json",
        "X-MIXINPUTS-BETA": str(beta),  # server-side MoI blending knob
    }

    try:
        # (10 s connect timeout, no read timeout): a streaming response can
        # legitimately stay open for a long generation.
        with requests.post(API_URL,
                           json=payload,
                           stream=True,
                           headers=headers,
                           timeout=(10, None)) as resp:
            resp.raise_for_status()

            assistant = ""
            # iterate over SSE lines; with decode_unicode=True the chunks
            # are str, so the delimiter must be str too (the previous bytes
            # delimiter b"\n" raised TypeError inside requests).
            for raw in resp.iter_lines(decode_unicode=True, delimiter="\n"):
                if not raw:
                    continue
                # SSE frames look like "data: {...}"; strip the prefix.
                data = raw[6:] if raw.startswith("data: ") else raw

                if data.strip() == "[DONE]":
                    break

                delta = json.loads(data)["choices"][0]["delta"].get("content", "")
                assistant += delta
                yield history + [(message, assistant)]  # live update in Gradio
    except Exception as err:
        # Surface network/HTTP/parse failures in the chat window rather than
        # crashing the Gradio worker.
        yield history + [(message, f"[ERROR] {err}")]

# ---------------------------- UI --------------------------------------------
# Declarative Gradio layout: components render in the order they are created
# inside the Blocks context, so statement order here defines the page.
with gr.Blocks(title="🎨 Mixture of Inputs (MoI) Demo") as demo:
    gr.Markdown(
        "## 🎨 Mixture of Inputs (MoI) Demo  \n"
        "Streaming vLLM demo with dynamic **beta** adjustment in MoI "
        "(higher beta → less blending)."
    )

    with gr.Row():  # sliders first
        # Sampling / MoI controls; passed positionally to stream_completion
        # via the `inputs` list on submit below.
        beta         = gr.Slider(0.0, 10.0, value=1.0,  step=0.1,  label="MoI β")
        temperature  = gr.Slider(0.1, 1.0,  value=0.6,  step=0.1,  label="Temperature")
        top_p        = gr.Slider(0.1, 1.0,  value=0.80, step=0.05, label="Top-p")
        max_tokens   = gr.Slider(1,   2048, value=512,  step=1,    label="Max new tokens")

    chatbot   = gr.Chatbot(height=450)   # conversation display / history state
    user_box  = gr.Textbox(placeholder="Type a message and press Enter…", show_label=False)
    clear_btn = gr.Button("Clear chat")

    # Enter in the textbox streams the reply: stream_completion is a
    # generator, so every yielded history re-renders the chatbot live.
    user_box.submit(
        fn=stream_completion,
        inputs=[user_box, chatbot, max_tokens, temperature, top_p, beta],
        outputs=chatbot,
    )
    # Reset the chat to empty; queue=False bypasses the request queue.
    clear_btn.click(lambda: None, None, chatbot, queue=False)

# Launch the web app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()