from typing import List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Name shown in the UI dropdown -> model identifier passed to `from_pretrained`.
AVAILABLE_MODELS = {
    "distilgpt2": "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m",
}

# Prefix of every failure message returned by `TextGenerator.load_model`.
_LOAD_ERROR_PREFIX = "Error loading model"


class TextGenerator:
    """Hold one causal language model plus its tokenizer and rank next tokens."""

    def __init__(self):
        # Both stay None until `load_model` succeeds.
        self.model = None
        self.tokenizer = None

    def load_model(self, model_name: str) -> str:
        """Load the model and tokenizer registered under *model_name*.

        Args:
            model_name: A key of ``AVAILABLE_MODELS``.

        Returns:
            ``"Successfully loaded <model_name>"`` on success, otherwise a
            message starting with ``"Error loading model: "``. Errors are
            reported through the return value rather than raised.
        """
        try:
            model_id = AVAILABLE_MODELS[model_name]
            model = AutoModelForCausalLM.from_pretrained(model_id)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
        except Exception as e:
            return f"{_LOAD_ERROR_PREFIX}: {str(e)}"

        # Commit both together so a failed tokenizer load can never leave a
        # new model paired with the previous (or a missing) tokenizer.
        self.model = model
        self.tokenizer = tokenizer
        return f"Successfully loaded {model_name}"

    def get_next_token_predictions(
        self, text: str, top_k: int = 10
    ) -> Tuple[List[str], List[float]]:
        """Return the most probable next tokens after *text*.

        Args:
            text: Prompt to continue.
            top_k: Maximum number of candidates to return.

        Returns:
            ``(tokens, probabilities)``: decoded token strings and their
            softmax probabilities, in the order produced by ``torch.topk``.
            Both lists are empty when no model is loaded, when the prompt
            tokenizes to nothing, or when ``top_k`` is not positive.
        """
        if self.model is None or self.tokenizer is None:
            return [], []
        if top_k <= 0:
            return [], []

        inputs = self.tokenizer(text, return_tensors="pt")
        # Some tokenizers emit zero tokens for an empty prompt; there is then
        # no last position to read logits from.
        if inputs["input_ids"].numel() == 0:
            return [], []

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Logits of the last position of the first (only) sequence.
        logits = outputs.logits[0, -1, :]
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # torch.topk raises if k exceeds the size of the dimension.
        k = min(top_k, probs.shape[-1])
        top_k_probs, top_k_indices = torch.topk(probs, k)
        top_k_tokens = [
            self.tokenizer.decode([token_id])
            for token_id in top_k_indices.tolist()
        ]
        return top_k_tokens, top_k_probs.tolist()


generator = TextGenerator()


def format_predictions(tokens, probs):
    """Render tokens and probabilities as a multi-line report.

    Returns ``"No predictions available"`` when either sequence is empty;
    otherwise a header followed by one ``'<token>' : <prob>`` line per pair,
    with the probability shown to four decimals.
    """
    if not tokens or not probs:
        return "No predictions available"
    lines = [f"'{token}' : {prob:.4f}\n" for token, prob in zip(tokens, probs)]
    return "Predicted next tokens:\n\n" + "".join(lines)


def _unquote_token(label: str) -> str:
    """Remove the single pair of quotes that `update_output` wraps a token in."""
    # str.strip("'") would also eat apostrophes that belong to the token
    # itself (e.g. "'s"), so drop exactly one character from each end.
    if len(label) >= 2 and label[0] == "'" and label[-1] == "'":
        return label[1:-1]
    return label


def update_output(model_name, text, custom_token, selected_token):
    """Append the chosen token to the text and predict the next ones.

    Loads the requested model first if it is not the one currently held by
    ``generator``. A non-empty ``custom_token`` takes precedence over
    ``selected_token``.

    Args:
        model_name: Key of ``AVAILABLE_MODELS`` chosen in the UI.
        text: Text generated so far.
        custom_token: Free-form text to append, if any.
        selected_token: Quoted token label picked from the predictions, if any.

    Returns:
        A 5-tuple matching the interface outputs: updated text, two empty
        strings that reset the token fields, a dropdown update carrying the
        new quoted token choices, and the formatted predictions (or the load
        error message).
    """
    output = text or ""

    wanted_id = AVAILABLE_MODELS.get(model_name)
    current = generator.model
    if current is None or wanted_id is None or current.name_or_path != wanted_id:
        load_message = generator.load_model(model_name)
        if load_message.startswith(_LOAD_ERROR_PREFIX):
            return output, "", "", gr.update(choices=[]), load_message

    if custom_token:
        output += custom_token
    elif selected_token:
        output += _unquote_token(selected_token)

    tokens, probs = generator.get_next_token_predictions(output)
    predictions = format_predictions(tokens, probs)
    # Quote each token so leading/trailing whitespace is visible in the list.
    token_choices = [f"'{token}'" for token in tokens]
    return output, "", "", gr.update(choices=token_choices), predictions


# NOTE(review): the input dropdown below is created with choices=[] and the
# `gr.update(choices=...)` returned by `update_output` is mapped to the
# separate output dropdown, so the input one appears never to be populated.
# Confirm, and consider a gr.Blocks layout that shares a single component.
demo = gr.Interface(
    fn=update_output,
    inputs=[
        gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="distilgpt2",
            label="Select Model",
        ),
        gr.Textbox(
            lines=5,
            label="Generated Text",
            placeholder="Start typing or select a token...",
        ),
        gr.Textbox(label="Custom Token", placeholder="Type your own token..."),
        gr.Dropdown(choices=[], label="Select from predicted tokens"),
    ],
    outputs=[
        gr.Textbox(lines=5, label="Generated Text"),
        gr.Textbox(label="Custom Token"),
        gr.Textbox(label="Selected Token"),
        gr.Dropdown(label="Predicted Tokens"),
        gr.Textbox(lines=12, label="Predictions"),
    ],
    title="Interactive Text Generation",
    description="Generate text by selecting predicted tokens or writing your own.",
)