# llms-demo / app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
AVAILABLE_MODELS = {
    "distilgpt2": "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m",
}
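
# Note: all of the above are small (under ~600M-parameter) causal LM
# checkpoints from the Hugging Face Hub, so the demo remains usable on CPU.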

# Cache the currently loaded (model, tokenizer) pair and its display name so
# the model is only reloaded when the user actually switches it.
generator = None
current_model = None


def load_model(model_name):
    """Load the selected model and tokenizer, replacing any cached pair."""
    global generator, current_model
    try:
        model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
        tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
        generator = (model, tokenizer)
        current_model = model_name
        return f"Successfully loaded {model_name}"
    except Exception as e:
        return f"Error loading model: {str(e)}"


def get_predictions(text, model_name):
    """Return the top-10 next-token candidates and a printable summary."""
    # Load on first use, or reload if the user selected a different model.
    if generator is None or current_model != model_name:
        load_model(model_name)
    model, tokenizer = generator
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Logits at the last position predict the token that follows the prompt.
    logits = outputs.logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=10)
    top_k_tokens = [tokenizer.decode([idx.item()]) for idx in top_k_indices]
    predictions = "\n".join(
        f"'{token}' : {prob:.4f}" for token, prob in zip(top_k_tokens, top_k_probs)
    )
    return top_k_tokens, predictions
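
# Illustrative example (hypothetical values; actual tokens and probabilities
# depend on the loaded model):
#   tokens, preds = get_predictions("The capital of France is", "distilgpt2")
#   tokens -> [' the', ' a', ' Paris', ...]
#   preds  -> "' the' : 0.1432\n' a' : 0.0821\n' Paris' : 0.0617\n..."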


def generate(model_name, text, token_choice="", custom_token=""):
    """Append the chosen token to the prompt and recompute predictions."""
    if token_choice:
        # Dropdown entries are wrapped in quotes; strip them before appending.
        text += token_choice.strip("'")
    if custom_token:
        text += custom_token
    tokens, predictions = get_predictions(text, model_name)
    # Return the extended text, a refreshed dropdown, and the summary string.
    return text, gr.Dropdown(choices=[f"'{t}'" for t in tokens]), predictions
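
# Example interaction (illustrative): picking "' Paris'" from the dropdown
# appends " Paris" to the prompt, after which the predictions box shows the
# top-10 candidates for the next token of the extended prompt.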


with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation")
    model_name = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.keys()),
        value="distilgpt2",
        label="Select Model",
    )
    text = gr.Textbox(
        lines=5,
        label="Text",
        placeholder="Type or select tokens to generate text...",
    )
    with gr.Row():
        token_choice = gr.Dropdown(
            choices=[],
            label="Select predicted token",
        )
        custom_token = gr.Textbox(
            label="Or type custom token",
        )
    predictions = gr.Textbox(
        label="Predictions",
        lines=10,
    )
    # Re-run `generate` whenever the model selection, the predicted-token
    # choice, or the custom token changes.
    for component in [model_name, token_choice, custom_token]:
        component.change(
            generate,
            inputs=[model_name, text, token_choice, custom_token],
            outputs=[text, token_choice, predictions],
        )
demo.queue().launch()