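# Gradio Space: chat with two Hugging Face causal language models side by side
# and vote on the combined response.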
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces  # ZeroGPU helper: provides the @spaces.GPU decorator used below

# Dictionary to store loaded models and tokenizers
loaded_models = {}

# List of available models (update with your preferred models)
models = ["gpt2", "gpt2-medium", "gpt2-large", "EleutherAI/gpt-neo-1.3B"]
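
# Load a model/tokenizer pair on first use and reuse the cached copy afterwards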
def load_model(model_name):
    if model_name not in loaded_models:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name).to(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        loaded_models[model_name] = (model, tokenizer)
    return loaded_models[model_name]

# Generate a completion from a single model; @spaces.GPU lets ZeroGPU hardware
# attach a GPU only for the duration of this call (the reason `spaces` is imported)
@spaces.GPU
def get_model_response(model_name, message):
    model, tokenizer = load_model(model_name)
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=100, num_return_sequences=1,
                                 do_sample=True, temperature=0.7,
                                 pad_token_id=tokenizer.eos_token_id)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
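
# Ask both selected models the same question and show their answers together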
def chat(message, history, model1, model2):
    response1 = get_model_response(model1, message)
    response2 = get_model_response(model2, message)
    # Append the new exchange so earlier turns stay in the chat window
    return history + [(message, f"{model1}: {response1}\n\n{model2}: {response2}")]
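
# Annotate the most recent exchange with the user's vote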
def vote(direction, history):
    if history:
        last_interaction = history[-1]
        vote_text = f"\n\nUser voted: {'👍' if direction == 'up' else '👎'}"
        updated_interaction = (last_interaction[0], last_interaction[1] + vote_text)
        return history[:-1] + [updated_interaction]
    return history
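
# UI: two model selectors, a shared chat window, and voting buttons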
with gr.Blocks() as demo:
    gr.Markdown("# Hugging Face Model Comparison Chat")
    with gr.Row():
        model1_dropdown = gr.Dropdown(choices=models, label="Model 1", value=models[0])
        model2_dropdown = gr.Dropdown(choices=models, label="Model 2", value=models[1])
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your message")
    clear = gr.Button("Clear")
    with gr.Row():
        upvote = gr.Button("👍 Upvote")
        downvote = gr.Button("👎 Downvote")
    msg.submit(chat, [msg, chatbot, model1_dropdown, model2_dropdown], chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)
    # Gradio event inputs must be components, so bind the vote direction with a lambda
    upvote.click(lambda history: vote("up", history), chatbot, chatbot)
    downvote.click(lambda history: vote("down", history), chatbot, chatbot)

if __name__ == "__main__":
    demo.launch()