File size: 2,646 Bytes
691f4df
eef83d0
691f4df
 
 
 
 
 
 
47b2af6
691f4df
 
eef83d0
691f4df
eef83d0
691f4df
 
 
 
 
 
199bc94
eef83d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47b2af6
eef83d0
 
 
 
88904be
05ca641
691f4df
06dcb2f
691f4df
c90b054
691f4df
06dcb2f
 
eef83d0
c90b054
2cb6115
691f4df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import solara
import random
import torch
import torch.nn.functional as F
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the smallest GPT-2 checkpoint (124M parameters) once at import time.
# Both objects are module-level singletons shared by every render of Page().
# NOTE(review): from_pretrained downloads weights on first run — network access
# is assumed here; verify deployment environment has the HF cache warmed.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
# Reactive state holding the current prompt text; "I" is the initial value
# shown in the input box. Mutated both by the InputText widget and by the
# row-action callback inside Page().
text1 = solara.reactive("I")
@solara.component
def Page():
  """Next-token prediction visualizer for GPT-2.

  Renders an input box bound to the module-level reactive ``text1``. For the
  current text it shows the tokenization as colored spans (token IDs and the
  decoded pieces) and a table of the top candidate next tokens with their
  softmax probabilities. Clicking a row's "Select" action appends that
  candidate token to the text, which re-renders the component.
  """
  with solara.Column(margin=10):
    # ATX headings need a space after '#' (CommonMark); without it some
    # renderers show the literal '#...' instead of a heading.
    solara.Markdown("# Next token prediction visualization")
    solara.Markdown("I built this tool to help me understand autoregressive language models. For any given text, it gives the top 10 candidates to be the next token with their respective probabilities. The language model I'm using is the smallest version of GPT-2, with 124M parameters.")

    def on_action_cell(column, row_index):
      # Late-binding closure: top_candidates is assigned further down, but the
      # callback only fires from the DataFrame, which is rendered after it.
      text1.value += tokenizer.decode(top_candidates.indices[0][row_index])

    cell_actions = [solara.CellAction(icon="mdi-thumb-up", name="Select", on_click=on_action_cell)]
    solara.InputText("Enter text:", value=text1, continuous_update=True)

    if text1.value != "":
      tokens = tokenizer.encode(text1.value, return_tensors="pt")
      id_spans = []
      text_spans = []
      for i, token in enumerate(tokens[0]):
        # Seed by position so each token slot keeps a stable color across
        # re-renders; the color itself is arbitrary.
        random.seed(i)
        random_color = ''.join(random.choice('0123456789ABCDEF') for _ in range(6))
        id_spans.append(" " + f"<span style='font-family: helvetica; color: #{random_color}'>{token}</span>")
        text_spans.append(" " + f"""<span style="
            padding: 6px;
            border-right: 3px solid white;
            line-height: 3em;
            font-family: courier;
            background-color: #{random_color};
            color: white;
            position: relative;
          "><span style="
          position: absolute;
          top: 5.5ch;
          bottom: 20px;
          line-height: 1em;
          left: -0.5px;
          font-size: 0.45em"> {token}</span>{tokenizer.decode([token])}</span>""")
      solara.Markdown(''.join(text_spans))
      solara.Markdown(''.join(id_spans))

      # Only the distribution for the FIRST generated position is used, so
      # one new token is enough (the original generated 2 and discarded one).
      outputs = model.generate(tokens, max_new_tokens=1, output_scores=True, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id)
      probs = F.softmax(outputs.scores[0], dim=-1)
      n_candidates = 1000  # rows available to the paginated table (10 per page)
      top_candidates = torch.topk(probs, n_candidates)
      df = pd.DataFrame()
      df["probs"] = top_candidates.values[0]
      df["probs"] = [f"{value:.2%}" for value in df["probs"].values]
      # tolist() converts the whole index tensor in one call instead of a
      # per-element .numpy() loop.
      df["next token ID"] = top_candidates.indices[0].tolist()
      df["predicted next token"] = [tokenizer.decode([tok_id]) for tok_id in top_candidates.indices[0]]
      solara.Markdown("### Prediction")
      solara.DataFrame(df, items_per_page=10, cell_actions=cell_actions)
    solara.Markdown('-----')
Page()