# Hugging Face Spaces status at capture time: Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import logging | |
# Emit INFO-level records so model-loading progress is visible in the logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cache of model_name -> (model, tokenizer); populated by load_all_models().
loaded_models: dict = {}

# Hugging Face repo ids offered in both dropdowns. The UI uses the first two
# entries as the default selections for Model 1 and Model 2.
# NOTE(review): meta-llama repos are gated on the Hub; loading presumably
# requires an access token with the license accepted — confirm the Space
# has one configured.
models: list = [
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Llama-3.2-1B",
]
def load_all_models(model_names=None):
    """Pre-load models and their tokenizers into the ``loaded_models`` cache.

    Args:
        model_names: Iterable of Hugging Face repo ids to load. Defaults to
            the module-level ``models`` list.

    Models already present in ``loaded_models`` are skipped. A model that
    fails to load is logged with its traceback and left out of the cache,
    so the app keeps running and later requests for it report "not loaded".
    """
    if model_names is None:
        model_names = models
    # Same device for every model; no need to re-query CUDA per iteration.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for model_name in model_names:
        if model_name in loaded_models:
            continue
        try:
            logger.info(f"Loading model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # torch_dtype="auto" keeps the dtype recorded for the checkpoint
            # instead of upcasting to float32; for 16-bit checkpoints this
            # roughly halves memory, which matters for multi-billion-param models.
            model = AutoModelForCausalLM.from_pretrained(
                model_name, torch_dtype="auto"
            ).to(device)
            loaded_models[model_name] = (model, tokenizer)
            logger.info(f"Successfully loaded {model_name}")
        except Exception as e:
            # Best-effort loading: record the full traceback and continue
            # with the remaining models.
            logger.exception(f"Failed to load model {model_name}: {e}")
def get_model_response(model_name, message, max_new_tokens=512):
    """Generate a reply from a pre-loaded model for a single user message.

    Args:
        model_name: Key into ``loaded_models`` (a Hugging Face repo id).
        message: The user's text.
        max_new_tokens: Maximum number of tokens to generate, not counting
            the prompt.

    Returns:
        The decoded completion text (prompt tokens removed), or an
        ``"Error..."`` string when the model is not loaded or generation fails.
    """
    # Look the model up outside the try-block so that a KeyError raised
    # during tokenization/generation is not misreported as "not loaded".
    entry = loaded_models.get(model_name)
    if entry is None:
        logger.error(f"Model {model_name} not found in loaded_models.")
        return f"Error: Model {model_name} not loaded."
    model, tokenizer = entry

    try:
        if getattr(tokenizer, "chat_template", None):
            # Tokenizers that ship a chat template (instruct checkpoints)
            # expect the prompt wrapped in it, plus a generation prompt.
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": message}],
                tokenize=False,
                add_generation_prompt=True,
            )
            # The rendered template already carries its special tokens.
            inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        else:
            inputs = tokenizer(message, return_tensors="pt")
        inputs = inputs.to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            # Tokenizers without a pad token: fall back to EOS explicitly.
            pad_token_id = tokenizer.eos_token_id

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # max_new_tokens bounds only the reply; max_length would also
                # count the prompt and yield nothing for long prompts.
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=0.95,
                top_k=50,
                pad_token_id=pad_token_id,
            )
        # generate() returns prompt + continuation; decode only the new part.
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        logger.exception(f"Error generating response from {model_name}: {e}")
        return f"Error generating response: {e}"
def chat(message, history1, history2, model1, model2):
    """Send one message to both selected models and record each reply.

    Args:
        message: Text the user submitted.
        history1: Chat history for the Model 1 pane, as a list of
            ``(user_message, bot_message)`` pairs; ``None`` means empty.
        history2: Same as ``history1``, for the Model 2 pane.
        model1: Repo id selected in the Model 1 dropdown.
        model2: Repo id selected in the Model 2 dropdown.

    Returns:
        ``(history1, history2)``, each extended by one ``(message, reply)``
        pair. Empty or whitespace-only messages are ignored and the histories
        are returned unchanged.
    """
    history1 = list(history1 or [])
    history2 = list(history2 or [])

    # Nothing meaningful to send; an empty prompt may also leave generate()
    # with no input tokens at all.
    if not message or not message.strip():
        return history1, history2

    response1 = get_model_response(model1, message)
    response2 = get_model_response(model2, message)

    # A tuple-format Chatbot history holds one (user_text, bot_text) pair per
    # exchange, so each exchange is appended as a single pair.
    history1.append((message, response1))
    history2.append((message, response2))
    return history1, history2
# Running vote tallies, keyed by dropdown slot ("model1"/"model2"), not by
# model name. Module-level, so all sessions served by this process share them.
# NOTE(review): switching a dropdown keeps the slot's count, so earlier votes
# are then shown against the newly selected model — confirm this is intended.
vote_counts = {"model1": 0, "model2": 0}


def _format_votes(model1, model2):
    """Render both tallies, labelled with the short (after "/") model names."""
    return (
        f"Votes - {model1.split('/')[-1]}: {vote_counts['model1']}, "
        f"{model2.split('/')[-1]}: {vote_counts['model2']}"
    )


def upvote_vote(model1, model2):
    """Add one vote to the Model 1 slot and return the updated tally text.

    Args:
        model1: Repo id currently selected for Model 1 (used for the label).
        model2: Repo id currently selected for Model 2 (used for the label).
    """
    vote_counts["model1"] += 1
    return _format_votes(model1, model2)


def downvote_vote(model1, model2):
    """Add one vote to the Model 2 slot and return the updated tally text.

    Despite its name this increments Model 2's count; nothing is subtracted.

    Args:
        model1: Repo id currently selected for Model 1 (used for the label).
        model2: Repo id currently selected for Model 2 (used for the label).
    """
    vote_counts["model2"] += 1
    return _format_votes(model1, model2)


def clear_chat():
    """Empty both chat histories and reset the vote tallies to zero.

    Returns:
        ``([], [], "Votes - 0, 0")`` for the two chatbots and the tally box.
    """
    # Reset in place rather than rebinding the global, so every existing
    # reference to ``vote_counts`` observes the reset.
    vote_counts.update(model1=0, model2=0)
    return [], [], "Votes - 0, 0"
# Load every model up front, at import time, so request handlers only read
# from the loaded_models cache.
load_all_models()


def _history_label(model_name):
    """Chatbot label for a repo id, e.g. "org/Name" -> "Name Chat History"."""
    return f"{(model_name or '').split('/')[-1]} Chat History"


with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Model Comparison Space")

    # Model pickers; the first two entries of ``models`` are the defaults.
    with gr.Row():
        model1_dropdown = gr.Dropdown(choices=models, label="Model 1", value=models[0])
        model2_dropdown = gr.Dropdown(choices=models, label="Model 2", value=models[1])

    # Side-by-side chat panes, one per selected model.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model 1 Chat")
            chatbot1 = gr.Chatbot(label=_history_label(models[0]))
        with gr.Column():
            gr.Markdown("### Model 2 Chat")
            chatbot2 = gr.Chatbot(label=_history_label(models[1]))

    # Shared input box: each submitted message goes to both models.
    msg = gr.Textbox(label="💬 Your Message", placeholder="Type your message here...")

    # Vote buttons (each adds one vote to its model's slot) and a reset button.
    with gr.Row():
        upvote = gr.Button("👍 Upvote Model 1")
        # downvote_vote increments Model 2's count, so label it as a vote for it.
        downvote = gr.Button("👍 Upvote Model 2")
        clear = gr.Button("🧹 Clear Chat")

    # Read-only tally display, written by the vote and clear handlers.
    vote_text = gr.Textbox(label="📊 Vote Counts", value="Votes - 0, 0", interactive=False)

    # Submitting updates both chat panes, then empties the input box.
    msg.submit(
        chat,
        inputs=[msg, chatbot1, chatbot2, model1_dropdown, model2_dropdown],
        outputs=[chatbot1, chatbot2],
    ).then(lambda: "", inputs=None, outputs=msg)

    # Keep each pane's label in step with the model selected for it; without
    # this the labels keep naming the default models after a switch.
    model1_dropdown.change(
        lambda name: gr.update(label=_history_label(name)),
        inputs=model1_dropdown,
        outputs=chatbot1,
    )
    model2_dropdown.change(
        lambda name: gr.update(label=_history_label(name)),
        inputs=model2_dropdown,
        outputs=chatbot2,
    )

    upvote.click(
        upvote_vote,
        inputs=[model1_dropdown, model2_dropdown],
        outputs=vote_text,
    )
    downvote.click(
        downvote_vote,
        inputs=[model1_dropdown, model2_dropdown],
        outputs=vote_text,
    )
    clear.click(
        clear_chat,
        outputs=[chatbot1, chatbot2, vote_text],
    )

if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run locally.
    demo.launch(share=True)