import gradio as gr
import os
import requests
import time
from datetime import datetime
from typing import List, Dict
try:
    from session_manager import SessionManager  # local helper for session persistence
except ImportError:
    SessionManager = None  # fall back to the in-memory sketch below
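
if SessionManager is None:
    # Minimal in-memory fallback so the app still runs without session_manager.py.
    # This is an illustrative sketch, assuming only create_session / load_session /
    # save_session are used below; the real module may persist sessions differently.
    import uuid

    class SessionManager:
        def __init__(self):
            self._sessions = {}

        def create_session(self):
            session_id = str(uuid.uuid4())
            self._sessions[session_id] = {"history": []}
            return session_id

        def load_session(self, session_id):
            return self._sessions.get(session_id)

        def save_session(self, session_id, session):
            self._sessions[session_id] = session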

# Initialize session manager and get HF API key (adjust if not using sessions)
session_manager = SessionManager()
HF_API_KEY = os.getenv("HF_API_KEY")
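# HF_API_KEY should be set in the environment (e.g. as a Space secret); gated
# models such as Llama 3.3 won't respond without a valid token.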

# Model endpoints configuration
MODEL_ENDPOINTS = {
    "Qwen2.5-72B-Instruct": "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-72B-Instruct",
    "Llama3.3-70B-Instruct": "https://api-inference.huggingface.co/models/meta-llama/Llama-3.3-70B-Instruct",
    "Qwen2.5-Coder-32B-Instruct": "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-32B-Instruct",
}
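# All three URLs point at the serverless Inference API; dedicated Inference
# Endpoints accept the same payload, so only the URLs would need to change.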

def query_model(model_name: str, messages: List[Dict[str, str]]) -> str:
    """
    Query a single model with the conversation so far (list of dicts with 'role' and 'content').
    """
    endpoint = MODEL_ENDPOINTS[model_name]
    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        "Content-Type": "application/json"
    }

    # Combine conversation into a single string (simple example)
    conversation = "\n".join(f"{m['role']}: {m['content']}" for m in messages)

    # Model-specific prompt formatting
    model_prompts = {
        "Qwen2.5-72B-Instruct": (
            f"<|im_start|>system\nCollaborate with other experts:\n{conversation}<|im_end|>\n"
            "<|im_start|>assistant\nMy analysis:"
        ),
        "Llama3.3-70B-Instruct": (
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            f"Build on the conversation:\n{conversation}<|eot_id|>\n"
            "<|start_header_id|>assistant<|end_header_id|>\nMy contribution:"
        ),
        "Qwen2.5-Coder-32B-Instruct": (
            f"<|im_start|>system\nTechnical discussion context:\n{conversation}<|im_end|>\n"
            "<|im_start|>assistant\nTechnical perspective:"
        )
    }

    stop_sequences = {
        "Qwen2.5-72B-Instruct": ["<|im_end|>", "<|endoftext|>"],
        "Llama3.3-70B-Instruct": ["<|eot_id|>", "\nuser:"],
        "Qwen2.5-Coder-32B-Instruct": ["<|im_end|>", "<|endoftext|>"]
    }

    payload = {
        "inputs": model_prompts[model_name],
        "parameters": {
            "max_tokens": 1024,
            "temperature": 0.7,
            "stop_sequences": stop_sequences[model_name],
            "return_full_text": False
        }
    }

    try:
        response = requests.post(endpoint, json=payload, headers=headers, timeout=120)
        response.raise_for_status()
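        # The text-generation task responds with a list like [{"generated_text": "..."}].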
        generated = response.json()[0]["generated_text"]
        # Clean up possible leftover tokens
        generated = generated.split("<|")[0].strip()
        return generated
    except Exception as e:
        return f"{model_name} error: {str(e)}"


def on_new_session():
    """Create a new session and clear the chat."""
    new_id = session_manager.create_session()
    return new_id, []

def user_message(user_msg, history, session_id):
    """
    After the user hits enter, append the user's message to the conversation.
    Return updated conversation so the UI can display it.
    """
    if not user_msg.strip():
        return "", history  # if user didn't type anything
    # Append the new user message to the conversation
    history.append({"role": "user", "content": user_msg})
    return "", history

def bot_reply(history, session_id):
    """
    Stream the multi-model response. We rely on the *last* user message in `history`,
    then call each model in turn, appending partial updates. Yields updated conversation each time.
    """
    if not history or history[-1]["role"] != "user":
        return  # There's no new user message to respond to

    # Optionally load existing session, if you have session logic
    session = session_manager.load_session(session_id) if session_id else None
    if session is None:
        session = {"history": []}

    # 1) Qwen2.5-Coder-32B
    # Add an assistant message placeholder
    history.append({"role": "assistant", "content": "🔵 Qwen2.5-Coder-32B-Instruct is thinking..."})
    yield history

    resp1 = query_model("Qwen2.5-Coder-32B-Instruct", history)
    updated_content = f"🔵 **Qwen2.5-Coder-32B-Instruct**\n{resp1}"
    history[-1]["content"] = updated_content
    yield history

    # 2) Qwen2.5-72B
    updated_content += "\n\n🟣 Qwen2.5-72B-Instruct is thinking..."
    history[-1]["content"] = updated_content
    yield history

    resp2 = query_model("Qwen2.5-72B-Instruct", history)
    updated_content += f"\n\n🟣 **Qwen2.5-72B-Instruct**\n{resp2}"
    history[-1]["content"] = updated_content
    yield history

    # 3) Llama3.3-70B
    updated_content += "\n\n🟡 Llama3.3-70B-Instruct is thinking..."
    history[-1]["content"] = updated_content
    yield history

    resp3 = query_model("Llama3.3-70B-Instruct", history)
    updated_content += f"\n\n🟡 **Llama3.3-70B-Instruct**\n{resp3}"
    history[-1]["content"] = updated_content
    yield history

    # Save session, if needed
    session["history"] = history
    session_manager.save_session(session_id, session)

def clear_chat():
    """
    Clears the Chatbot entirely (set it to an empty list).
    """
    return []

# Build the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("## Multi-LLM Collaboration Chat (Streaming)")

    with gr.Row():
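        # Passing a callable to gr.State makes Gradio call it on each page load,
        # so every visitor starts with their own session ID.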
        session_id = gr.State(session_manager.create_session)
        new_session_btn = gr.Button("🔄 New Session")

    # Chatbot with "type='messages'" for streaming messages and LaTeX delimiters
    chatbot = gr.Chatbot(
        type="messages",
        height=550,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},  # display math (list "$$" before "$")
            {"left": "$", "right": "$", "display": False}    # inline math
        ]
    )

    msg = gr.Textbox(label="Your Message")
    clear_btn = gr.Button("Clear")

    # Wire up the events:
    # 1) On user submit:
    msg.submit(
        fn=user_message, 
        inputs=[msg, chatbot, session_id], 
        outputs=[msg, chatbot], 
        queue=False
    ).then(
        fn=bot_reply, 
        inputs=[chatbot, session_id], 
        outputs=[chatbot]
    )
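    # user_message only appends the typed text, so it skips the queue; bot_reply
    # is a generator that streams partial updates, so it stays on the queue.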

    # 2) On "Clear" click, empty the chat:
    clear_btn.click(fn=clear_chat, outputs=chatbot, queue=False)

    # 3) On "New Session" click, get a fresh session ID and clear chat:
    new_session_btn.click(fn=on_new_session, outputs=[session_id, chatbot], queue=False)

if __name__ == "__main__":
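    # On Gradio 3.x you may need demo.queue() here for the streaming generator
    # to take effect; Gradio 4+ enables queueing by default.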
    demo.launch()