Spaces:
Sleeping
Sleeping
Commit
·
eef83d0
1
Parent(s):
8a16818
Update app
Browse files
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import solara
|
|
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
import pandas as pd
|
@@ -9,20 +10,45 @@ model = AutoModelForCausalLM.from_pretrained('gpt2')
|
|
9 |
text1 = solara.reactive("Alan Turing theorized that computers would one day become")
|
10 |
@solara.component
|
11 |
def Page():
|
12 |
-
with solara.
|
13 |
solara.Markdown("#Next token prediction visualization")
|
|
|
14 |
def on_action_cell(column, row_index):
|
15 |
text1.value += tokenizer.decode(top_10.indices[0][row_index])
|
16 |
cell_actions = [solara.CellAction(icon="mdi-thumb-up", name="Select", on_click=on_action_cell)]
|
17 |
solara.InputText("Enter text:", value=text1, continuous_update=True)
|
18 |
if text1.value != "":
|
19 |
tokens = tokenizer.encode(text1.value, return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
outputs = model.generate(tokens, max_new_tokens=1, output_scores=True, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id)
|
21 |
scores = F.softmax(outputs.scores[0], dim=-1)
|
22 |
top_10 = torch.topk(scores, 10)
|
23 |
df = pd.DataFrame()
|
24 |
df["probs"] = top_10.values[0]
|
25 |
df["probs"] = [f"{value:.2%}" for value in df["probs"].values]
|
|
|
26 |
df["predicted next token"] = [tokenizer.decode(top_10.indices[0][i]) for i in range(10)]
|
|
|
27 |
solara.DataFrame(df, items_per_page=10, cell_actions=cell_actions)
|
28 |
Page()
|
|
|
1 |
import solara
|
2 |
+
import random
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
5 |
import pandas as pd
|
|
|
10 |
text1 = solara.reactive("Alan Turing theorized that computers would one day become")
|
11 |
@solara.component
|
12 |
def Page():
|
13 |
+
with solara.Column(margin=10):
|
14 |
solara.Markdown("#Next token prediction visualization")
|
15 |
+
solara.Markdown("I built this tool to help me understand autoregressive language models. For any given text, it gives the top 10 candidates to be the next token with their respective probabilities. The language model I'm using is the smallest version of GPT-2, with 124M parameters.")
|
16 |
def on_action_cell(column, row_index):
|
17 |
text1.value += tokenizer.decode(top_10.indices[0][row_index])
|
18 |
cell_actions = [solara.CellAction(icon="mdi-thumb-up", name="Select", on_click=on_action_cell)]
|
19 |
solara.InputText("Enter text:", value=text1, continuous_update=True)
|
20 |
if text1.value != "":
|
21 |
tokens = tokenizer.encode(text1.value, return_tensors="pt")
|
22 |
+
spans1 = ""
|
23 |
+
spans2 = ""
|
24 |
+
for i, token in enumerate(tokens[0]):
|
25 |
+
random.seed(i)
|
26 |
+
random_color = ''.join([random.choice('0123456789ABCDEF') for k in range(6)])
|
27 |
+
spans1 += " " + f"<span style='font-family: helvetica; color: #{random_color}'>{token}</span>"
|
28 |
+
spans2 += " " + f"""<span style="
|
29 |
+
padding: 6px;
|
30 |
+
border-right: 3px solid white;
|
31 |
+
line-height: 3em;
|
32 |
+
font-family: courier;
|
33 |
+
background-color: #{random_color};
|
34 |
+
color: white;
|
35 |
+
position: relative;
|
36 |
+
"><span style="
|
37 |
+
position: absolute;
|
38 |
+
top: 5.5ch;
|
39 |
+
line-height: 1em;
|
40 |
+
left: -0.5px;
|
41 |
+
font-size: 0.45em"> {token}</span>{tokenizer.decode([token])}</span>"""
|
42 |
+
solara.Markdown(f'{spans2}')
|
43 |
+
solara.Markdown(f'{spans2}')
|
44 |
outputs = model.generate(tokens, max_new_tokens=1, output_scores=True, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id)
|
45 |
scores = F.softmax(outputs.scores[0], dim=-1)
|
46 |
top_10 = torch.topk(scores, 10)
|
47 |
df = pd.DataFrame()
|
48 |
df["probs"] = top_10.values[0]
|
49 |
df["probs"] = [f"{value:.2%}" for value in df["probs"].values]
|
50 |
+
df["next token ID"] = [top_10.indices[0][i].numpy() for i in range(10)]
|
51 |
df["predicted next token"] = [tokenizer.decode(top_10.indices[0][i]) for i in range(10)]
|
52 |
+
solara.Markdown("###Prediction")
|
53 |
solara.DataFrame(df, items_per_page=10, cell_actions=cell_actions)
|
54 |
Page()
|