File size: 2,976 Bytes
77be14e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50cc7b6
77be14e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50cc7b6
77be14e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
Hugging Face's logo
Hugging Face
Search models, datasets, users...
Models
Datasets
Spaces
Posts
Docs
Pricing



Spaces:

alonsosilva
/
NextTokenPrediction


like
3
App
Files
Community
NextTokenPrediction
/
app.py

alonsosilva's picture
alonsosilva
Change reactive text
a4869ab
7 months ago
raw
history
blame
contribute
delete
No virus
2.63 kB
import html
import random

import pandas as pd
import solara
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# GPT-2 small checkpoint: the tokenizer and causal language model are loaded
# once when the module is imported and shared by every render of Page.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="gpt2")

# Reactive holding the prompt shown in the input box, seeded with a default text.
text1 = solara.reactive("Never gonna give you up, never gonna let you")
def _token_color(index):
    """Return the six-digit hex colour (without ``#``) used for token *index*.

    A dedicated ``random.Random`` seeded with *index* produces exactly the
    sequence ``random.seed(index)`` would, so each position keeps a stable
    colour across re-renders, while the process-wide ``random`` state is left
    untouched.
    """
    rng = random.Random(index)
    return "".join(rng.choice("0123456789ABCDEF") for _ in range(6))


def _token_spans(token_ids):
    """Build the two HTML strings that visualise the tokenisation.

    Args:
        token_ids: list of int token ids for the current input text.

    Returns:
        ``(boxes, ids)``. *boxes* renders each decoded token on a coloured
        background with its id as a small caption. *ids* lists the bare
        token ids in the same colours. Every span is preceded by one space.
    """
    boxes = []
    ids = []
    for i, token in enumerate(token_ids):
        color = _token_color(i)
        # The decoded text comes straight from user input; escape it so that
        # characters such as "<" or "&" are shown literally, not parsed as HTML.
        text = html.escape(tokenizer.decode([token]))
        ids.append(f"<span style='font-family: helvetica; color: #{color}'>{token}</span>")
        boxes.append(f"""<span style="
            padding: 6px;
            border-right: 3px solid white;
            line-height: 3em;
            font-family: courier;
            background-color: #{color};
            color: white;
            position: relative;
          "><span style="
          position: absolute;
          top: 5.5ch;
          line-height: 1em;
          left: -0.5px;
          font-size: 0.45em"> {token}</span>{text}</span>""")
    return "".join(" " + s for s in boxes), "".join(" " + s for s in ids)


def _predict_top_k(token_ids, k=10):
    """Return ``torch.topk`` over the next-token probability distribution.

    Args:
        token_ids: tensor of shape (1, seq_len), as returned by
            ``tokenizer.encode(..., return_tensors="pt")``.
        k: number of candidates to keep.
    """
    # The model only supports model.config.n_positions input positions; longer
    # prompts would make generate() fail, so predict from the trailing window.
    context = token_ids[:, -model.config.n_positions:]
    outputs = model.generate(
        context,
        # Single unpadded sequence: every position is real input.
        attention_mask=torch.ones_like(context),
        max_new_tokens=1,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # scores[0] holds the logits for the single generated step.
    probs = F.softmax(outputs.scores[0], dim=-1)
    return torch.topk(probs, k)


@solara.component
def Page():
    """Render the next-token prediction explorer.

    Shows an editable prompt bound to the ``text1`` reactive, its tokenisation
    as coloured HTML, and a table of the 10 most probable next tokens.
    Selecting a table row appends that token to the prompt.
    """
    with solara.Column(margin=10):
        solara.Markdown("#Next token prediction visualization")
        solara.Markdown("I built this tool to help me understand autoregressive language models. For any given text, it gives the top 10 candidates to be the next token with their respective probabilities. The language model I'm using is the smallest version of GPT-2, with 124M parameters.")
        solara.InputText("Enter text:", value=text1, continuous_update=True)
        if text1.value != "":
            tokens = tokenizer.encode(text1.value, return_tensors="pt")
            boxes, ids = _token_spans(tokens[0].tolist())
            solara.Markdown(boxes)
            solara.Markdown(ids)

            top_10 = _predict_top_k(tokens, 10)
            candidate_ids = top_10.indices[0].tolist()

            df = pd.DataFrame()
            df["probs"] = [f"{p:.2%}" for p in top_10.values[0].tolist()]
            df["next token ID"] = candidate_ids
            df["predicted next token"] = [tokenizer.decode(t) for t in candidate_ids]

            def on_action_cell(column, row_index):
                # Defined after the predictions exist, so the callback never
                # refers to a name that has not been assigned yet.
                text1.value += tokenizer.decode(candidate_ids[row_index])

            cell_actions = [solara.CellAction(icon="mdi-thumb-up", name="Select", on_click=on_action_cell)]
            solara.Markdown("###Prediction")
            solara.DataFrame(df, items_per_page=10, cell_actions=cell_actions)
# Instantiate the Page component at module level.
# NOTE(review): presumably this makes the UI display when the file runs as a
# notebook cell; confirm whether the app's serving entry point needs this call.
Page()