import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Display name -> Hugging Face Hub model ID
AVAILABLE_MODELS = {
    "distilgpt2": "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m"
}

class TextGenerator:
    """Holds the currently loaded model/tokenizer and computes next-token distributions."""

    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load_model(self, model_name: str) -> str:
        """Load a model/tokenizer pair from the Hub and return a status message."""
        try:
            self.model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
            self.tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
            return f"Successfully loaded {model_name}"
        except Exception as e:
            return f"Error loading model: {str(e)}"

    def get_next_token_predictions(self, text: str, top_k: int = 10):
        """Return the top_k most likely next tokens and their probabilities."""
        if not self.model or not self.tokenizer:
            return [], []
        if not text:
            # Guard: nothing to condition on yet (the UI starts with an empty prompt)
            return [], []
        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Logits at the last position; softmax turns them into a probability distribution
        logits = outputs.logits[0, -1, :]
        probs = torch.nn.functional.softmax(logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(probs, top_k)
        top_k_tokens = [self.tokenizer.decode([idx.item()]) for idx in top_k_indices]
        return top_k_tokens, top_k_probs.tolist()

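# Quick standalone check (hypothetical usage, outside the Gradio UI):
#   gen = TextGenerator()
#   gen.load_model("distilgpt2")
#   tokens, probs = gen.get_next_token_predictions("The capital of France is")
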
generator = TextGenerator()

def format_predictions(tokens, probs):
    """Render the token/probability pairs as a readable text block."""
    if not tokens or not probs:
        return "No predictions available"
    formatted = "Predicted next tokens:\n\n"
    for token, prob in zip(tokens, probs):
        formatted += f"'{token}' : {prob:.4f}\n"
    return formatted

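# Example of the formatted block (probabilities are illustrative):
#   Predicted next tokens:
#
#   ' the' : 0.1012
#   ' a' : 0.0473
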
def update_output(model_name, text, custom_token, selected_token):
    output = text
    # (Re)load when no model is loaded yet or a different one was selected
    if not generator.model or generator.model.name_or_path != AVAILABLE_MODELS[model_name]:
        load_message = generator.load_model(model_name)
        if "Error" in load_message:
            return text, "", "", gr.update(choices=[]), load_message
    # A custom token takes precedence over a selected prediction
    if custom_token:
        output += custom_token
    elif selected_token:
        output += selected_token.strip("'")
    tokens, probs = generator.get_next_token_predictions(output)
    predictions = format_predictions(tokens, probs)
    token_choices = [f"'{token}'" for token in tokens]
    # Returns: updated text, cleared token boxes, refreshed dropdown choices, prediction list
    return output, "", "", gr.update(choices=token_choices), predictions

demo = gr.Interface(
    fn=update_output,
    inputs=[
        gr.Dropdown(choices=list(AVAILABLE_MODELS.keys()), value="distilgpt2", label="Select Model"),
        gr.Textbox(lines=5, label="Generated Text", placeholder="Start typing or select a token..."),
        gr.Textbox(label="Custom Token", placeholder="Type your own token..."),
        gr.Dropdown(choices=[], label="Select from predicted tokens")
    ],
    outputs=[
        gr.Textbox(lines=5, label="Generated Text"),
        gr.Textbox(label="Custom Token"),
        gr.Textbox(label="Selected Token"),
        gr.Dropdown(label="Predicted Tokens"),
        gr.Textbox(lines=12, label="Predictions")
    ],
    title="Interactive Text Generation",
    description="Generate text by selecting predicted tokens or writing your own."
)
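
# Serve the UI (the standard Spaces template calls launch() at the top level)
demo.launch()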