File size: 2,638 Bytes
617bd81
 
 
 
 
 
 
 
 
 
 
 
6c99f7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617bd81
6c99f7c
 
 
 
 
 
 
617bd81
6c99f7c
 
 
 
 
617bd81
6c99f7c
 
 
 
 
 
 
 
f586a0d
6c99f7c
 
617bd81
6c99f7c
 
 
 
 
617bd81
6c99f7c
 
 
 
 
617bd81
6c99f7c
 
 
 
 
 
 
 
617bd81
6c99f7c
 
 
 
617bd81
6c99f7c
 
 
 
 
 
175fea5
6c99f7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Maps the name shown in the "Select Model" dropdown to the Hugging Face Hub
# repository id handed to ``from_pretrained``. Insertion order is the order in
# which the names appear in that dropdown.
AVAILABLE_MODELS = {
    "distilgpt2":  "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m":    "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m",
}

# The active ``(model, tokenizer)`` pair; ``None`` until ``load_model`` succeeds.
generator = None

def load_model(model_name):
    """Load ``model_name`` from the Hub and make it the active generator.

    Args:
        model_name: A key of ``AVAILABLE_MODELS``.

    Returns:
        A human-readable status string: ``"Successfully loaded <name>"`` on
        success, or ``"Error loading model: <reason>"`` on any failure
        (including an unknown ``model_name``). Errors are reported through
        this string rather than raised.

    Side effects:
        On success, sets the module-level ``generator`` to ``(model, tokenizer)``.
        On failure, resets it to ``None`` so that a previously loaded, different
        model is never mistaken for the one that was just requested.
    """
    global generator
    try:
        repo_id = AVAILABLE_MODELS[model_name]
        model = AutoModelForCausalLM.from_pretrained(repo_id)
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
    except Exception as e:  # UI boundary: surface any failure as a status message
        generator = None
        return f"Error loading model: {e}"
    generator = (model, tokenizer)
    return f"Successfully loaded {model_name}"

def _loaded_repo_id():
    """Return the Hub id the active generator's tokenizer was loaded from, or ``None``."""
    if generator is None:
        return None
    # transformers tokenizers expose the id passed to ``from_pretrained`` as
    # ``name_or_path``; getattr keeps this safe if it is ever absent (that only
    # causes a reload, never a wrong model).
    return getattr(generator[1], "name_or_path", None)


def get_predictions(text, model_name, top_k=10):
    """Return the ``top_k`` most likely next tokens for ``text`` under ``model_name``.

    The requested model is (re)loaded first if no model is active or the active
    one came from a different repository, so switching models in the UI takes
    effect.

    Args:
        text: Prompt to continue. ``None``/empty input is treated as "start of
            text": the tokenizer's BOS token is used as the prompt if it has one.
        model_name: A key of ``AVAILABLE_MODELS``.
        top_k: Number of candidates to return (capped at the vocabulary size).

    Returns:
        ``(tokens, predictions)``. ``tokens`` is the list of decoded candidate
        strings, most likely first. ``predictions`` is a newline-separated
        ``'token' : probability`` listing for display. Both are empty if the
        prompt encodes to no tokens and the tokenizer has no BOS token.

    Raises:
        gr.Error: If the requested model could not be loaded; the message is
            the status string returned by ``load_model``.
    """
    global generator
    repo_id = AVAILABLE_MODELS[model_name]
    if _loaded_repo_id() != repo_id:
        previous = generator
        status = load_model(model_name)
        # A successful load always installs a fresh tuple; anything else means
        # the load failed (generator reset to None or left on the old model).
        if generator is None or generator is previous:
            raise gr.Error(status)

    model, tokenizer = generator
    input_ids = tokenizer(text or "", return_tensors="pt")["input_ids"]
    if input_ids.shape[-1] == 0:
        # Nothing to condition on (e.g. the initially empty textbox): predict
        # the first token from BOS when the tokenizer defines one.
        if tokenizer.bos_token_id is None:
            return [], ""
        input_ids = torch.tensor([[tokenizer.bos_token_id]])

    with torch.no_grad():
        # Only the distribution after the last position matters for "next token".
        logits = model(input_ids=input_ids).logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)

    k = min(top_k, probs.shape[-1])
    top_probs, top_ids = torch.topk(probs, k=k)
    tokens = [tokenizer.decode([token_id]) for token_id in top_ids.tolist()]
    predictions = "\n".join(
        f"'{token}' : {prob:.4f}" for token, prob in zip(tokens, top_probs.tolist())
    )
    return tokens, predictions

def quote_choice(token):
    """Wrap ``token`` in single quotes so whitespace tokens are visible in the dropdown."""
    return f"'{token}'"


def unquote_choice(choice):
    """Invert ``quote_choice``: remove exactly one surrounding pair of single quotes.

    Unlike ``str.strip("'")``, apostrophes that belong to the token itself
    (e.g. the GPT-2 token ``'s``) are preserved.
    """
    if len(choice) >= 2 and choice[0] == choice[-1] == "'":
        return choice[1:-1]
    return choice


def generate(model_name, text, token_choice="", custom_token=""):
    """Append the selected and/or typed text, then refresh next-token predictions.

    Args:
        model_name: A key of ``AVAILABLE_MODELS``.
        text: Current text; ``None`` is treated as ``""``.
        token_choice: A dropdown choice as produced by ``quote_choice``, or
            empty/``None`` for no selection.
        custom_token: Free text appended verbatim after ``token_choice``.

    Returns:
        ``(new_text, dropdown_update, predictions)``. ``dropdown_update`` carries
        the new quoted candidates and clears the current selection, so a stale
        choice is not appended again by a later event.

    Raises:
        Whatever ``get_predictions`` raises (e.g. a model that fails to load).
    """
    text = text or ""
    if token_choice:
        text += unquote_choice(token_choice)
    if custom_token:
        text += custom_token

    tokens, predictions = get_predictions(text, model_name)
    choices = [quote_choice(token) for token in tokens]
    return text, gr.Dropdown(choices=choices, value=None), predictions

def _refresh_predictions(model_name, text):
    """Recompute predictions for ``text`` without appending anything to it."""
    return generate(model_name, text)


def _append_selected(model_name, text, token_choice):
    """Append the dropdown's selected token to ``text`` and refresh predictions."""
    return generate(model_name, text, token_choice=token_choice)


def _append_custom(model_name, text, custom_token):
    """Append the typed custom text, refresh predictions, and clear the input box."""
    return (*generate(model_name, text, custom_token=custom_token), "")


with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation")

    model_name = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.keys()),
        value="distilgpt2",
        label="Select Model",
    )

    text = gr.Textbox(
        lines=5,
        label="Text",
        placeholder="Type or select tokens to generate text...",
    )
    predict_button = gr.Button("Predict next token")

    with gr.Row():
        token_choice = gr.Dropdown(
            choices=[],
            label="Select predicted token",
        )
        custom_token = gr.Textbox(
            label="Or type custom token",
            placeholder="Press Enter to append",
        )

    predictions = gr.Textbox(
        label="Predictions",
        lines=10,
    )

    prediction_outputs = [text, token_choice, predictions]

    # Switching models, or pressing the button after editing the text, only
    # refreshes the predictions; nothing is appended.
    model_name.change(
        _refresh_predictions,
        inputs=[model_name, text],
        outputs=prediction_outputs,
    )
    predict_button.click(
        _refresh_predictions,
        inputs=[model_name, text],
        outputs=prediction_outputs,
    )
    # ``.input`` fires only on user interaction, so the dropdown update that
    # ``generate`` returns does not re-trigger this handler.
    token_choice.input(
        _append_selected,
        inputs=[model_name, text, token_choice],
        outputs=prediction_outputs,
    )
    # ``.submit`` fires once on Enter. A ``.change`` listener would fire on every
    # keystroke and append each partial prefix of the custom text.
    custom_token.submit(
        _append_custom,
        inputs=[model_name, text, custom_token],
        outputs=[*prediction_outputs, custom_token],
    )

demo.queue().launch()