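"""Interactive token-by-token text generation demo.

Pick one of several small causal language models, see the ten most probable
next tokens for the current prompt, and grow the prompt by selecting a
predicted token or typing a custom one.
"""
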
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Display name -> Hugging Face checkpoint id for the selectable models.
AVAILABLE_MODELS = {
    "distilgpt2": "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m",
}

# Currently loaded (model, tokenizer) pair and the key it was loaded for,
# so a checkpoint is loaded once and swapped out when the user picks another.
generator = None
loaded_model = None

def load_model(model_name):
    """Load the selected checkpoint and cache it in the globals above."""
    global generator, loaded_model
    try:
        model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
        tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
        generator = (model, tokenizer)
        loaded_model = model_name
        return f"Successfully loaded {model_name}"
    except Exception as e:
        return f"Error loading model: {e}"

def get_predictions(text, model_name):
    """Return the top-10 next-token candidates and a printable probability table."""
    # (Re)load when nothing is cached yet or the user switched models.
    if generator is None or loaded_model != model_name:
        load_model(model_name)
    model, tokenizer = generator
    # Fall back to the BOS/EOS token so an empty prompt still yields a prediction.
    if not text:
        text = tokenizer.bos_token or tokenizer.eos_token or " "
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # The logits at the last position give the distribution over the next token.
    logits = outputs.logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=10)
    top_k_tokens = [tokenizer.decode([idx.item()]) for idx in top_k_indices]
    predictions = "\n".join(
        f"'{token}' : {prob:.4f}"
        for token, prob in zip(top_k_tokens, top_k_probs.tolist())
    )
    return top_k_tokens, predictions

def generate(model_name, text, token_choice="", custom_token=""):
    """Append the chosen token to the prompt, then recompute the predictions."""
    if token_choice:
        # Dropdown entries are displayed quoted (e.g. "' the'"); strip the quotes.
        text += token_choice.strip("'")
    if custom_token:
        text += custom_token
    tokens, predictions = get_predictions(text, model_name)
    # Return the grown prompt, refreshed dropdown choices, and the probability table.
    return text, gr.Dropdown(choices=[f"'{t}'" for t in tokens]), predictions

with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation")
    model_name = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.keys()),
        value="distilgpt2",
        label="Select Model"
    )
    text = gr.Textbox(
        lines=5,
        label="Text",
        placeholder="Type or select tokens to generate text..."
    )
    with gr.Row():
        token_choice = gr.Dropdown(
            choices=[],
            label="Select predicted token"
        )
        custom_token = gr.Textbox(
            label="Or type custom token"
        )
    predictions = gr.Textbox(
        label="Predictions",
        lines=10
    )
    # Any change to the model, the predicted-token dropdown, or the custom-token
    # box re-runs generate() and refreshes the prompt, choices, and predictions.
    for component in [model_name, token_choice, custom_token]:
        component.change(
            generate,
            inputs=[model_name, text, token_choice, custom_token],
            outputs=[text, token_choice, predictions],
        )

if __name__ == "__main__":
    demo.queue().launch()
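
# A standalone sketch of the same next-token computation used in get_predictions()
# (assumed prompt and checkpoint), handy for quick experimentation in a REPL:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     import torch
#
#     tok = AutoTokenizer.from_pretrained("distilgpt2")
#     lm = AutoModelForCausalLM.from_pretrained("distilgpt2")
#     enc = tok("The capital of France is", return_tensors="pt")
#     with torch.no_grad():
#         next_logits = lm(**enc).logits[0, -1, :]
#     probs = torch.softmax(next_logits, dim=-1)
#     top = torch.topk(probs, k=5)
#     print([(tok.decode([i.item()]), p) for i, p in zip(top.indices, top.values.tolist())])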