import gradio as gr
import os
import requests
from datetime import datetime
from typing import List, Dict, Generator
from session_manager import SessionManager
# Initialize session manager and read the Hugging Face API key from the environment
session_manager = SessionManager()
HF_API_KEY = os.getenv("HF_API_KEY")
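
# Assumed SessionManager interface (session_manager.py is not shown in this file); these
# signatures are inferred from how the object is used below:
#   create_session() -> str                        # returns a new session id
#   load_session(session_id: str) -> dict          # returns {"history": [...]} or an empty value
#   save_session(session_id: str, session: dict)   # persists the session
if not HF_API_KEY:
    # Without a token the Inference API calls below will likely fail (401) or be heavily rate-limited.
    print("Warning: HF_API_KEY is not set; configure it as a Space secret or environment variable.")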

# Model endpoints configuration
MODEL_ENDPOINTS = {
    "Qwen2.5-72B-Instruct": "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-72B-Instruct",
    "Llama3.3-70B-Instruct": "https://api-inference.huggingface.co/models/meta-llama/Llama-3.3-70B-Instruct",
    "Qwen2.5-Coder-32B-Instruct": "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-32B-Instruct",
}
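# Note: these are serverless Inference API endpoints. meta-llama models are gated on the Hub,
# so the token must have been granted access to Llama-3.3-70B-Instruct, and models that have
# gone cold may return a 503 while they load.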

def query_model(model_name: str, messages: List[Dict[str, str]]) -> str:
    """Query a single model with the chat history."""
    endpoint = MODEL_ENDPOINTS[model_name]
    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        "Content-Type": "application/json"
    }

    # Build full conversation history for context
    conversation = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])

    # Model-specific prompt formatting
    model_prompts = {
        "Qwen2.5-72B-Instruct": (
            f"<|im_start|>system\nCollaborate with other experts. Previous discussion:\n{conversation}<|im_end|>\n"
            "<|im_start|>assistant\nMy analysis:"
        ),
        "Llama3.3-70B-Instruct": (
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            f"Build upon this discussion:\n{conversation}<|eot_id|>\n"
            "<|start_header_id|>assistant<|end_header_id|>\nMy contribution:"
        ),
        "Qwen2.5-Coder-32B-Instruct": (
            f"<|im_start|>system\nTechnical discussion context:\n{conversation}<|im_end|>\n"
            "<|im_start|>assistant\nTechnical perspective:"
        )
    }

    # Model-specific stop sequences
    stop_sequences = {
        "Qwen2.5-72B-Instruct": ["<|im_end|>", "<|endoftext|>"],
        "Llama3.3-70B-Instruct": ["<|eot_id|>", "\nuser:"],
        "Qwen2.5-Coder-32B-Instruct": ["<|im_end|>", "<|endoftext|>"]
    }

    payload = {
        "inputs": model_prompts[model_name],
        "parameters": {
            # The text-generation endpoint expects "max_new_tokens" (not "max_tokens")
            "max_new_tokens": 2048,
            "temperature": 0.7,
            # TGI-backed endpoints take the stop strings under "stop"
            "stop": stop_sequences[model_name],
            "return_full_text": False
        }
    }

    try:
        response = requests.post(endpoint, json=payload, headers=headers, timeout=120)
        response.raise_for_status()
        # Successful responses look like [{"generated_text": "..."}]
        result = response.json()[0]['generated_text']

        # Basic cleanup
        result = result.split('<|')[0]  # Trim anything after a stray special token
        result = result.replace('**', '').replace('##', '')

        return result.strip()
    except Exception as e:
        return f"{model_name} error: {str(e)}"

def respond(message: str, history: List[List[str]], session_id: str) -> Generator[str, None, None]:
    """Handle sequential model responses with context preservation."""
    # Load or initialize session
    session = session_manager.load_session(session_id)
    if not isinstance(session, dict) or "history" not in session:
        session = {"history": []}

    # Build context from session history
    messages = []
    for entry in session["history"]:
        if entry["type"] == "user":
            messages.append({"role": "user", "content": entry["content"]})
        else:
            messages.append({"role": "assistant", "content": f"{entry['model']}: {entry['content']}"})

    # Add current message
    messages.append({"role": "user", "content": message})
    session["history"].append({
        "timestamp": datetime.now().isoformat(),
        "type": "user",
        "content": message
    })

    # First model
    yield "🔵 Qwen2.5-Coder-32B-Instruct is thinking..."
    response1 = query_model("Qwen2.5-Coder-32B-Instruct", messages)
    session["history"].append({
        "timestamp": datetime.now().isoformat(),
        "type": "assistant",
        "model": "Qwen2.5-Coder-32B-Instruct",
        "content": response1
    })
    messages.append({"role": "assistant", "content": f"Qwen2.5-Coder-32B-Instruct: {response1}"})
    yield f"🔵 **Qwen2.5-Coder-32B-Instruct**\n{response1}"

    # Second model
    yield f"🔵 **Qwen2.5-Coder-32B-Instruct**\n{response1}\n\n🟣 Qwen2.5-72B-Instruct is thinking..."
    response2 = query_model("Qwen2.5-72B-Instruct", messages)
    session["history"].append({
        "timestamp": datetime.now().isoformat(),
        "type": "assistant",
        "model": "Qwen2.5-72B-Instruct",
        "content": response2
    })
    messages.append({"role": "assistant", "content": f"Qwen2.5-72B-Instruct: {response2}"})
    yield f"🔵 **Qwen2.5-Coder-32B-Instruct**\n{response1}\n\n🟣 **Qwen2.5-72B-Instruct**\n{response2}"

    # Final model
    yield f"🔵 **Qwen2.5-Coder-32B-Instruct**\n{response1}\n\n🟣 **Qwen2.5-72B-Instruct**\n{response2}\n\n🟡 Llama3.3-70B-Instruct is thinking..."
    response3 = query_model("Llama3.3-70B-Instruct", messages)
    session["history"].append({
        "timestamp": datetime.now().isoformat(),
        "type": "assistant",
        "model": "Llama3.3-70B-Instruct",
        "content": response3
    })
    messages.append({"role": "assistant", "content": f"Llama3.3-70B-Instruct: {response3}"})

    # Save session
    session_manager.save_session(session_id, session)

    # Return final combined response
    yield (
        f"🔵 **Qwen2.5-Coder-32B-Instruct**\n{response1}\n\n"
        f"🟣 **Qwen2.5-72B-Instruct**\n{response2}\n\n"
        f"🟡 **Llama3.3-70B-Instruct**\n{response3}"
    )
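
# Note: respond() yields cumulative Markdown strings, so the Chatbot cell is overwritten with a
# progressively longer transcript; each later model also receives the earlier models' answers
# through `messages`, which is what makes the run sequential rather than parallel.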

with gr.Blocks() as demo:
    # -- Include KaTeX for LaTeX rendering --
    gr.HTML("""
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex/dist/katex.min.css" />
    <script defer src="https://cdn.jsdelivr.net/npm/katex/dist/katex.min.js"></script>
    <script defer src="https://cdn.jsdelivr.net/npm/katex/dist/contrib/auto-render.min.js"></script>
    <script>
        // Re-render math whenever new content is added
        document.addEventListener("DOMContentLoaded", function() {
            const observer = new MutationObserver(function(mutations) {
                for (const mutation of mutations) {
                    if (mutation.type === 'childList') {
                        renderMathInElement(document.body, {
                            delimiters: [
                                {left: "$$", right: "$$", display: true},
                                {left: "$", right: "$", display: false},
                            ]
                        });
                    }
                }
            });
            observer.observe(document.body, { subtree: true, childList: true });
        });
    </script>
    """)

    gr.Markdown("## Multi-LLM Collaboration Chat (with LaTeX support)")

    with gr.Row():
        session_id = gr.State(session_manager.create_session)
        new_session = gr.Button("🔄 New Session")

    chatbot = gr.Chatbot(height=600)
    msg = gr.Textbox(label="Message (Use $...$ or $$...$$ for LaTeX)")

    def on_new_session():
        new_id = session_manager.create_session()
        return new_id, []

    def user(message, history, session_id):
        return "", history + [[message, None]]

    def bot(history, session_id):
        if history and history[-1][1] is None:
            message = history[-1][0]
            for response in respond(message, history[:-1], session_id):
                history[-1][1] = response
                yield history

    msg.submit(user, [msg, chatbot, session_id], [msg, chatbot]).then(
        bot, [chatbot, session_id], [chatbot]
    )
    new_session.click(on_new_session, None, [session_id, chatbot])
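
# Streaming generators rely on Gradio's queue; Gradio 4.x enables it by default, but on 3.x an
# explicit call is required, so enabling it unconditionally is a safe default here.
demo.queue()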

if __name__ == "__main__":
    demo.launch(share=True)