# test_comparison / app.py
# Author: hanzla javaid -- commit 02caa8d ("updates"), 5.34 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
# Route INFO-and-above log records through the root logger (basicConfig
# attaches a stderr handler) so model-loading progress is visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cache filled by load_all_models(): Hub model ID -> (model, tokenizer).
loaded_models: dict = {}

# Hub model IDs offered in both dropdowns. Entry 0 is Model 1's default and
# entry 1 is Model 2's default, so the list needs at least two items.
# NOTE(review): "meta-llama/Llama-3.2-1B" lacks an "-Instruct" suffix, so it is
# presumably a base checkpoint being compared against an instruct model --
# confirm that is intended; meta-llama repos are also typically gated.
models = ["Qwen/Qwen2.5-7B-Instruct", "meta-llama/Llama-3.2-1B"]
def load_all_models(model_names=None, torch_dtype=None):
    """
    Pre-load models and their tokenizers into the ``loaded_models`` cache.

    Models already present in the cache are skipped, so repeated calls only
    load what is missing. A model that fails to load is logged together with
    its traceback and left out of the cache. The remaining models are still
    loaded, so one bad checkpoint does not stop the app from starting.

    Args:
        model_names: Hub model IDs to load. Defaults to the module-level
            ``models`` list.
        torch_dtype: Optional dtype forwarded to ``from_pretrained``. When
            omitted, nothing is passed and transformers' default applies
            (float32 weights). For a 7B-parameter model that is roughly twice
            the memory of a half-precision load. Pass ``"auto"`` to keep the
            dtype the checkpoint was saved in.
    """
    names = models if model_names is None else model_names
    # Resolved once: every model is placed on the same device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Only forward torch_dtype when the caller asked for one, so the default
    # call is identical to a plain from_pretrained(model_name).
    load_kwargs = {} if torch_dtype is None else {"torch_dtype": torch_dtype}
    for model_name in names:
        if model_name in loaded_models:
            continue
        try:
            logger.info("Loading model: %s", model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name, **load_kwargs
            ).to(device)
        except Exception as e:
            # Best-effort by design: log with traceback and move on.
            logger.exception("Failed to load model %s: %s", model_name, e)
            continue
        loaded_models[model_name] = (model, tokenizer)
        logger.info("Successfully loaded %s", model_name)
def get_model_response(model_name, message, max_new_tokens=512, top_p=0.95, top_k=50):
    """
    Generate a sampled reply from one pre-loaded model.

    The message is tokenized as raw text (no chat template is applied).

    Args:
        model_name: Key into ``loaded_models`` (a Hub model ID).
        message: User text to continue.
        max_new_tokens: Upper bound on generated tokens, not counting the
            prompt.
        top_p: Nucleus-sampling probability mass passed to ``generate``.
        top_k: Top-k sampling cutoff passed to ``generate``.

    Returns:
        The decoded continuation only, without the prompt. Failures are
        returned as an ``"Error: ..."`` / ``"Error generating response: ..."``
        string rather than raised, so the UI always has text to display.
    """
    # Look the model up outside the generation try-block so a KeyError raised
    # during generation is not misreported as "model not loaded".
    entry = loaded_models.get(model_name)
    if entry is None:
        logger.error("Model %s not found in loaded_models.", model_name)
        return f"Error: Model {model_name} not loaded."
    model, tokenizer = entry
    try:
        inputs = tokenizer(message, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # max_new_tokens bounds only the reply; a max_length budget
                # would also count the prompt and starve long prompts.
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=top_p,
                top_k=top_k,
            )
        # generate() returns prompt + continuation for causal LMs; drop the
        # prompt tokens so the reply does not repeat the user's text.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        logger.exception("Error generating response from %s: %s", model_name, e)
        return f"Error generating response: {e}"
def chat(message, history1, history2, model1, model2):
    """
    Send one user message to both selected models and extend each chat log.

    Args:
        message: Text from the input box.
        history1: Current contents of the Model 1 chatbot (``None`` or a list
            of ``(user_message, bot_message)`` pairs).
        history2: Same, for the Model 2 chatbot.
        model1: Hub model ID selected for Model 1.
        model2: Hub model ID selected for Model 2.

    Returns:
        ``(history1, history2)`` as new lists, each with one
        ``(message, reply)`` pair appended. A blank message returns copies of
        the histories unchanged and does not run generation.
    """
    # Copy so the lists handed in by the caller are never mutated in place.
    history1 = list(history1 or [])
    history2 = list(history2 or [])
    if not message or not message.strip():
        return history1, history2
    response1 = get_model_response(model1, message)
    response2 = get_model_response(model2, message)
    # One turn per exchange: (user_message, bot_message). Assumes gr.Chatbot's
    # pair-based ("tuples") history format -- confirm against the installed
    # Gradio version.
    history1.append((message, response1))
    history2.append((message, response2))
    return history1, history2
# Running vote tally. It lives at module level, so every user of this app
# process shares the same counts.
vote_counts = {"model1": 0, "model2": 0}


def _format_votes(model1, model2):
    """Render the tally as 'Votes - <name1>: <n1>, <name2>: <n2>' using the last path segment of each model ID."""
    return (
        f"Votes - {model1.split('/')[-1]}: {vote_counts['model1']}, "
        f"{model2.split('/')[-1]}: {vote_counts['model2']}"
    )


def upvote_vote(model1, model2):
    """
    Record one vote for Model 1 and return the updated tally text.

    Args:
        model1: Hub model ID selected for Model 1 (used only for display).
        model2: Hub model ID selected for Model 2 (used only for display).
    """
    vote_counts["model1"] += 1
    return _format_votes(model1, model2)


def downvote_vote(model1, model2):
    """
    Record one vote for Model 2 and return the updated tally text.

    Despite its name, this adds a vote in Model 2's favour; it never
    subtracts from either count.

    Args:
        model1: Hub model ID selected for Model 1 (used only for display).
        model2: Hub model ID selected for Model 2 (used only for display).
    """
    vote_counts["model2"] += 1
    return _format_votes(model1, model2)
def clear_chat():
    """
    Reset the UI state: empty both chat logs and zero the vote tally.

    Returns:
        A 3-tuple matching the outputs (chatbot1, chatbot2, vote_text).
    """
    global vote_counts
    vote_counts = dict(model1=0, model2=0)
    fresh_log_1, fresh_log_2 = [], []
    reset_label = "Votes - 0, 0"
    return fresh_log_1, fresh_log_2, reset_label
# Pre-load all models before building the Gradio interface:
# get_model_response() only reads the loaded_models cache and never loads
# anything itself.
load_all_models()
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Model Comparison Space")
    # Dropdowns for selecting models
    with gr.Row():
        model1_dropdown = gr.Dropdown(choices=models, label="Model 1", value=models[0])
        model2_dropdown = gr.Dropdown(choices=models, label="Model 2", value=models[1])
    # Separate chatboxes for each model. The labels reflect the default
    # selections only and are not updated when a dropdown changes.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model 1 Chat")
            chatbot1 = gr.Chatbot(label=f"{models[0].split('/')[-1]} Chat History")
        with gr.Column():
            gr.Markdown("### Model 2 Chat")
            chatbot2 = gr.Chatbot(label=f"{models[1].split('/')[-1]} Chat History")
    # Input textbox for user message
    msg = gr.Textbox(label="💬 Your Message", placeholder="Type your message here...")
    # Voting buttons and chat reset
    with gr.Row():
        upvote = gr.Button("👍 Upvote Model 1")
        # Wired to downvote_vote, which adds a vote *for* Model 2, so the
        # label says so rather than "Downvote".
        downvote = gr.Button("👍 Upvote Model 2")
        clear = gr.Button("🧹 Clear Chat")
    # Textbox to display vote counts
    vote_text = gr.Textbox(label="🏆 Vote Counts", value="Votes - 0, 0", interactive=False)
    # Define interactions
    msg.submit(
        chat,
        inputs=[msg, chatbot1, chatbot2, model1_dropdown, model2_dropdown],
        outputs=[chatbot1, chatbot2]
    )
    upvote.click(
        upvote_vote,
        inputs=[model1_dropdown, model2_dropdown],
        outputs=vote_text
    )
    downvote.click(
        downvote_vote,
        inputs=[model1_dropdown, model2_dropdown],
        outputs=vote_text
    )
    clear.click(
        clear_chat,
        outputs=[chatbot1, chatbot2, vote_text]
    )
if __name__ == "__main__":
    demo.launch(share=True)