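# Interactive Text Generation: a Gradio Space that generates text one token at
# a time, showing the model's top-10 next-token candidates at every step and
# letting the user pick (or type) the token to append.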
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
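
# Display name -> Hugging Face Hub id for each selectable causal LM.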
AVAILABLE_MODELS = {
"distilgpt2": "distilgpt2",
"bloomz-560m": "bigscience/bloomz-560m",
"gpt2-medium": "gpt2-medium",
"opt-350m": "facebook/opt-350m",
"pythia-160m": "EleutherAI/pythia-160m"
}
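
# Lazily loaded (model, tokenizer) pair, plus the name of the model it holds
# so a different selection triggers a reload.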
generator = None
current_model = None
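
# Load the requested model and tokenizer; returns a human-readable status.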
def load_model(model_name):
    global generator, current_model
    try:
        model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
        tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
        generator = (model, tokenizer)
        current_model = model_name
        return f"Successfully loaded {model_name}"
    except Exception as e:
        return f"Error loading model: {str(e)}"
def get_predictions(text, model_name):
    global generator
    # Reload if nothing is loaded yet or the user picked a different model.
    if generator is None or current_model != model_name:
        load_model(model_name)
    model, tokenizer = generator
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # The logits at the last position give the distribution over the next token.
    logits = outputs.logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=10)
    top_k_tokens = [tokenizer.decode([idx.item()]) for idx in top_k_indices]
    predictions = "\n".join(
        f"'{token}' : {prob.item():.4f}" for token, prob in zip(top_k_tokens, top_k_probs)
    )
    return top_k_tokens, predictions
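
# Append the chosen (or typed) token to the text, then refresh the predictions.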
def generate(model_name, text, token_choice="", custom_token=""):
    if token_choice:
        text += token_choice.strip("'")
    if custom_token:
        text += custom_token
    # An empty prompt has no last position to read logits from.
    if not text:
        return text, gr.Dropdown(choices=[], value=None), "", ""
    tokens, predictions = get_predictions(text, model_name)
    # Clear the dropdown selection and the custom-token box so stale values
    # are not appended again on the next event.
    return text, gr.Dropdown(choices=[f"'{t}'" for t in tokens], value=None), "", predictions
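
# Assemble the interface: model picker, running text, token choices, and the
# probability readout.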
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation")
    model_name = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.keys()),
        value="distilgpt2",
        label="Select Model"
    )
    text = gr.Textbox(
        lines=5,
        label="Text",
        placeholder="Type or select tokens to generate text..."
    )
    with gr.Row():
        token_choice = gr.Dropdown(
            choices=[],
            label="Select predicted token"
        )
        custom_token = gr.Textbox(
            label="Or type custom token"
        )
    predictions = gr.Textbox(
        label="Predictions",
        lines=10
    )
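    # Dropdown events fire only on user input (so programmatic resets do not
    # re-trigger them); the custom token is applied on Enter rather than on
    # every keystroke.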
    gen_inputs = [model_name, text, token_choice, custom_token]
    gen_outputs = [text, token_choice, custom_token, predictions]
    for component in [model_name, token_choice]:
        component.input(generate, inputs=gen_inputs, outputs=gen_outputs)
    custom_token.submit(generate, inputs=gen_inputs, outputs=gen_outputs)
demo.queue().launch()