import logging

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maps model name -> (model, tokenizer) for every model that loaded successfully.
loaded_models = {}

# List of available models (ensure these are correct and accessible)
models = [
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-0.5B-Instruct",
]

# Maximum number of tokens generated per reply. Unlike ``max_length`` this
# does not include the prompt, so long prompts cannot starve the reply.
MAX_NEW_TOKENS = 512


def _short_name(model_name):
    """Return the part of a model id after the last "/" (used for display)."""
    return (model_name or "").split("/")[-1]


def load_all_models():
    """
    Pre-loads all models and their tokenizers into memory.

    Models are moved to CUDA when available, otherwise to CPU. A model that
    fails to load is logged and skipped, so it is simply absent from
    ``loaded_models``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for model_name in models:
        if model_name in loaded_models:
            continue
        try:
            logger.info(f"Loading model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # On GPU keep the checkpoint's native dtype instead of upcasting to
            # fp32 (which doubles memory); on CPU keep the fp32 default.
            load_kwargs = {"torch_dtype": "auto"} if device == "cuda" else {}
            model = AutoModelForCausalLM.from_pretrained(
                model_name, **load_kwargs
            ).to(device)
            loaded_models[model_name] = (model, tokenizer)
            logger.info(f"Successfully loaded {model_name}")
        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")


def get_model_response(model_name, message):
    """
    Generates a response from the specified model given a user message.

    Args:
        model_name: Key into ``loaded_models``.
        message: The user's message text.

    Returns:
        The decoded reply (only the newly generated text, not the prompt), or
        a string beginning with "Error" if the model is not loaded or
        generation fails.
    """
    try:
        model, tokenizer = loaded_models[model_name]
    except KeyError:
        logger.error(f"Model {model_name} not found in loaded_models.")
        return f"Error: Model {model_name} not loaded."

    try:
        if getattr(tokenizer, "chat_template", None):
            # Instruct checkpoints expect chat-formatted input.
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": message}],
                tokenize=False,
                add_generation_prompt=True,
            )
            # The rendered template already contains its special tokens.
            add_special_tokens = False
        else:
            prompt = message
            add_special_tokens = True

        inputs = tokenizer(
            prompt, return_tensors="pt", add_special_tokens=add_special_tokens
        ).to(model.device)

        # Generate response with appropriate parameters
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                top_p=0.95,
                top_k=50,
            )

        # generate() returns prompt + continuation; decode only the new tokens
        # so the reply does not echo the user's message.
        prompt_length = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:
        logger.exception(f"Error generating response from {model_name}: {e}")
        return f"Error generating response: {e}"


def chat(message, history1, history2, model1, model2):
    """
    Handles the chat interaction by getting responses from both models and
    updating their respective histories.

    Each history is a list of ``(user_message, bot_message)`` pairs, the
    pair format ``gr.Chatbot`` renders by default.

    Returns:
        The updated ``(history1, history2)``.
    """
    history1 = history1 or []
    history2 = history2 or []

    if not message or not message.strip():
        # Nothing to send; leave both conversations unchanged.
        return history1, history2

    response1 = get_model_response(model1, message)
    response2 = get_model_response(model2, message)

    history1.append((message, response1))
    history2.append((message, response2))
    return history1, history2


# Vote tallies. NOTE: this is module-level state, so the counts are shared by
# every session connected to the app.
vote_counts = {"model1": 0, "model2": 0}


def _format_votes(model1, model2):
    """Render the current vote tallies labelled with each model's short name."""
    return (
        f"Votes - {_short_name(model1)}: {vote_counts['model1']}, "
        f"{_short_name(model2)}: {vote_counts['model2']}"
    )


def upvote_vote(model1, model2):
    """
    Increments the vote count for Model 1 and returns updated counts.
    """
    vote_counts["model1"] += 1
    return _format_votes(model1, model2)


def downvote_vote(model1, model2):
    """
    Increments the vote count for Model 2 and returns updated counts.

    This is a vote *for* Model 2 (wired to the "Upvote Model 2" button).
    """
    vote_counts["model2"] += 1
    return _format_votes(model1, model2)


def clear_chat(model1=None, model2=None):
    """
    Clears both chat histories and resets vote counts.

    Args:
        model1, model2: Optional selected model names. When both are given,
            the reset vote text uses the same format as the vote buttons.

    Returns:
        ``([], [], vote_text)`` for the two chatbots and the vote textbox.
    """
    # Reset in place so other references to ``vote_counts`` stay valid.
    vote_counts["model1"] = 0
    vote_counts["model2"] = 0
    if model1 and model2:
        return [], [], _format_votes(model1, model2)
    return [], [], "Votes - 0, 0"


# Pre-load all models before building the Gradio interface
load_all_models()

with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Model Comparison Space")

    # Dropdowns for selecting models
    with gr.Row():
        model1_dropdown = gr.Dropdown(choices=models, label="Model 1", value=models[0])
        model2_dropdown = gr.Dropdown(choices=models, label="Model 2", value=models[1])

    # Separate chatboxes for each model
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model 1 Chat")
            chatbot1 = gr.Chatbot(label=f"{_short_name(models[0])} Chat History")
        with gr.Column():
            gr.Markdown("### Model 2 Chat")
            chatbot2 = gr.Chatbot(label=f"{_short_name(models[1])} Chat History")

    # Input textbox for user message
    msg = gr.Textbox(label="💬 Your Message", placeholder="Type your message here...")

    # Buttons for voting for either model and clearing the chat
    with gr.Row():
        upvote = gr.Button("👍 Upvote Model 1")
        downvote = gr.Button("👍 Upvote Model 2")
        clear = gr.Button("🧹 Clear Chat")

    # Textbox to display vote counts
    vote_text = gr.Textbox(
        label="🏆 Vote Counts",
        value=_format_votes(models[0], models[1]),
        interactive=False,
    )

    # Keep each chatbot's label in sync with the model currently selected for it.
    model1_dropdown.change(
        lambda name: gr.update(label=f"{_short_name(name)} Chat History"),
        inputs=model1_dropdown,
        outputs=chatbot1,
    )
    model2_dropdown.change(
        lambda name: gr.update(label=f"{_short_name(name)} Chat History"),
        inputs=model2_dropdown,
        outputs=chatbot2,
    )

    # Define interactions
    msg.submit(
        chat,
        inputs=[msg, chatbot1, chatbot2, model1_dropdown, model2_dropdown],
        outputs=[chatbot1, chatbot2],
    )
    upvote.click(
        upvote_vote,
        inputs=[model1_dropdown, model2_dropdown],
        outputs=vote_text,
    )
    downvote.click(
        downvote_vote,
        inputs=[model1_dropdown, model2_dropdown],
        outputs=vote_text,
    )
    clear.click(
        clear_chat,
        inputs=[model1_dropdown, model2_dropdown],
        outputs=[chatbot1, chatbot2, vote_text],
    )

if __name__ == "__main__":
    demo.launch(share=True)