import gradio as gr
import os
import requests
import threading
from typing import Dict, List

# Get the Hugging Face API key from Spaces secrets
HF_API_KEY = os.getenv("HF_API_KEY")

# Model endpoints configuration
MODEL_ENDPOINTS = {
    "Qwen2.5-72B-Instruct": "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-72B-Instruct",
    "Llama3.3-70B-Instruct": "https://api-inference.huggingface.co/models/meta-llama/Llama-3.3-70B-Instruct",
    "Qwen2.5-Coder-32B-Instruct": "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-32B-Instruct",
}
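# Any other text-generation model served by the serverless Inference API could be
# added here, together with matching entries in the prompt and stop-sequence
# dicts below. A hypothetical extra entry for the dict above:
#     "Mixtral-8x7B-Instruct":
#         "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",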


def query_model(model_name: str, messages: List[Dict[str, str]]) -> str:
    """Query a single model with the chat history"""
    endpoint = MODEL_ENDPOINTS[model_name]
    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        "Content-Type": "application/json"
    }

    # Model-specific prompt formatting
    model_prompts = {
        "Qwen2.5-72B-Instruct": (
            f"<|im_start|>user\n{messages[-1]['content']}<|im_end|>\n<|im_start|>assistant\n"
        ),
        "Llama3.3-70B-Instruct": (
            "<|begin_of_text|>"
            "<|start_header_id|>user<|end_header_id|>\n\n"
            f"{messages[-1]['content']}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        ),
        "Qwen2.5-Coder-32B-Instruct": (
            f"<|im_start|>user\n{messages[-1]['content']}<|im_end|>\n<|im_start|>assistant\n"
        )
    }
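    # For illustration: with the user message "Hi", the Qwen templates above render as
    #   <|im_start|>user
    #   Hi<|im_end|>
    #   <|im_start|>assistant
    # Note these hand-rolled templates send only the latest user message, with no
    # system prompt and no prior turns.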

    # Model-specific stop sequences
    stop_sequences = {
        "Qwen2.5-72B-Instruct": ["<|im_end|>", "<|endoftext|>"],
        "Llama3.3-70B-Instruct": ["<|eot_id|>", "\nuser:"],
        "Qwen2.5-Coder-32B-Instruct": ["<|im_end|>", "<|endoftext|>"]
    }

    payload = {
        "inputs": model_prompts[model_name],
        "parameters": {
            "max_new_tokens": 1024,  # the Inference API expects max_new_tokens, not max_tokens
            "temperature": 0.7,
            "stop": stop_sequences[model_name],  # TGI's stop-sequence parameter is "stop"
            "return_full_text": False
        }
    }

    try:
        # A timeout keeps a stalled endpoint from hanging its worker thread forever
        response = requests.post(endpoint, json=payload, headers=headers, timeout=120)
        response.raise_for_status()
        result = response.json()[0]['generated_text']
        # Clean up response formatting
        result = result.split('<|')[0]  # Remove any remaining special tokens
        result = result.replace('**', '').replace('##', '')  # Remove markdown
        result = result.strip()  # Remove leading/trailing whitespace
        return result.split('\n\n')[0]  # Return only first paragraph
    except Exception as e:
        return f"{model_name} error: {str(e)}"


def respond(message: str, history: List[List[str]]) -> str:
    """Handle chat responses from all models"""
    # Prepare messages in OpenAI format
    messages = [{"role": "user", "content": message}]

    # Create threads for concurrent model queries
    threads = []
    results = {}

    def get_model_response(model_name):
        results[model_name] = query_model(model_name, messages)

    for model_name in MODEL_ENDPOINTS:
        thread = threading.Thread(target=get_model_response, args=(model_name,))
        thread.start()
        threads.append(thread)

    # Wait for all threads to complete
    for thread in threads:
        thread.join()

    # Format responses from all models
    responses = []
    for model_name, response in results.items():
        responses.append(f"**{model_name}**:\n{response}")

    # Format responses with clear separation
    return "\n\n----------------------------------------\n\n".join(responses)

# Create the Gradio interface
chat_interface = gr.ChatInterface(
    respond,
    title="Multi-LLM Collaboration Chat",
    description="A group chat with Qwen2.5-72B, Llama3.3-70B, and Qwen2.5-Coder-32B",
    examples=["How can I optimize Python code?", "Explain quantum computing basics"],
    theme="soft"
)

if __name__ == "__main__":
    chat_interface.launch(share=True)