import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Display name -> Hugging Face Hub model id.
AVAILABLE_MODELS = {
    "distilgpt2": "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m",
}

# Cached (model, tokenizer) pair, plus the name it was loaded under so that
# switching models in the UI actually reloads the weights instead of silently
# reusing the previously cached model.
generator = None
loaded_model_name = None


def load_model(model_name):
    global generator, loaded_model_name
    try:
        model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
        tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
        model.eval()
        generator = (model, tokenizer)
        loaded_model_name = model_name
        return f"Successfully loaded {model_name}"
    except Exception as e:
        generator = None
        loaded_model_name = None
        return f"Error loading model: {e}"


def get_predictions(text, model_name):
    """Return the top-10 next-token candidates and a printable probability table."""
    if generator is None or loaded_model_name != model_name:
        status = load_model(model_name)
        if generator is None:
            # Surface the load failure in the UI instead of crashing on the
            # unpack below.
            raise gr.Error(status)
    model, tokenizer = generator

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Distribution over the vocabulary for the token *after* the input.
    logits = outputs.logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=10)
    top_k_tokens = [tokenizer.decode([idx.item()]) for idx in top_k_indices]

    predictions = "\n".join(
        f"'{token}' : {prob:.4f}" for token, prob in zip(top_k_tokens, top_k_probs)
    )
    return top_k_tokens, predictions


def generate(model_name, text, token_choice="", custom_token=""):
    # Append whichever token the user supplied, then recompute predictions.
    if token_choice:
        text += token_choice.strip("'")
    if custom_token:
        text += custom_token
    if not text:
        # Nothing to predict from yet; keep the UI empty rather than running
        # the model on a zero-length input sequence.
        return text, gr.Dropdown(choices=[], value=None), "", ""
    tokens, predictions = get_predictions(text, model_name)
    # Clear both token inputs so a later event does not re-append them.
    return (
        text,
        gr.Dropdown(choices=[f"'{t}'" for t in tokens], value=None),
        "",
        predictions,
    )


with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation")
    model_name = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.keys()),
        value="distilgpt2",
        label="Select Model",
    )
    text = gr.Textbox(
        lines=5,
        label="Text",
        placeholder="Type or select tokens to generate text...",
    )
    with gr.Row():
        token_choice = gr.Dropdown(choices=[], label="Select predicted token")
        custom_token = gr.Textbox(label="Or type custom token")
    predictions = gr.Textbox(label="Predictions", lines=10)

    # `.input` (Gradio 4.x) fires only on user edits, so the programmatic
    # updates made by `generate` (clearing the two token inputs) cannot
    # retrigger the handler and loop.
    for component in [model_name, token_choice, custom_token]:
        component.input(
            generate,
            inputs=[model_name, text, token_choice, custom_token],
            outputs=[text, token_choice, custom_token, predictions],
        )

if __name__ == "__main__":
    demo.queue().launch()
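# ---------------------------------------------------------------------------
# A minimal usage sketch, kept commented out: it assumes this file is saved
# as a module and imported (the `__main__` guard above keeps the UI from
# launching on import). The prompt string is illustrative only; the actual
# tokens and probabilities depend on the selected model's weights.
#
#   tokens, table = get_predictions("The quick brown fox", "distilgpt2")
#   print(table)  # ten lines of the form "'<token>' : <probability>"
# ---------------------------------------------------------------------------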