Spaces:
Sleeping
Sleeping
File size: 2,646 Bytes
691f4df eef83d0 691f4df 47b2af6 691f4df eef83d0 691f4df eef83d0 691f4df 199bc94 eef83d0 47b2af6 eef83d0 88904be 05ca641 691f4df 06dcb2f 691f4df c90b054 691f4df 06dcb2f eef83d0 c90b054 2cb6115 691f4df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import html
import random

import pandas as pd
import solara
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
# Checkpoint shared by the tokenizer and the model; both are loaded once,
# when the module is imported.
_CHECKPOINT = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

# Reactive holder for the text shown in the input box (starts as "I").
text1 = solara.reactive("I")
# Number of highest-probability candidates listed in the prediction table.
_NUM_CANDIDATES = 1000


def _token_color(position):
    """Return a reproducible six-digit hex colour for the token at *position*."""
    # A private generator seeded with the position yields the same digits as
    # seeding the module-level generator, without resetting the global RNG
    # state on every render.
    rng = random.Random(position)
    return "".join(rng.choice("0123456789ABCDEF") for _ in range(6))


def _token_spans(token_ids):
    """Build the HTML used to visualise the tokenised input.

    Args:
        token_ids: token IDs of the input text, in order.

    Returns:
        A ``(boxes, ids)`` pair of HTML strings: ``boxes`` shows each decoded
        token on a coloured background with its ID underneath, ``ids`` shows
        the bare IDs in the matching colours.
    """
    boxes = []
    ids = []
    for position, token_id in enumerate(token_ids):
        color = _token_color(position)
        # The decoded text originates from the user's input and is placed
        # inside raw HTML, so it must be escaped before interpolation.
        token_text = html.escape(tokenizer.decode([token_id]))
        ids.append(
            f"<span style='font-family: helvetica; color: #{color}'>{token_id}</span>"
        )
        boxes.append(f"""<span style="
padding: 6px;
border-right: 3px solid white;
line-height: 3em;
font-family: courier;
background-color: #{color};
color: white;
position: relative;
"><span style="
position: absolute;
top: 5.5ch;
bottom: 20px;
line-height: 1em;
left: -0.5px;
font-size: 0.45em"> {token_id}</span>{token_text}</span>""")
    # Every span is preceded by a single space, as in the rendered output.
    return "".join(f" {span}" for span in boxes), "".join(f" {span}" for span in ids)


def _next_token_candidates(tokens):
    """Return the most probable next tokens for the encoded prompt.

    Args:
        tokens: tensor of input IDs as returned by
            ``tokenizer.encode(..., return_tensors="pt")``.

    Returns:
        A ``(token_ids, probabilities)`` pair of equally long lists, ordered
        from most to least probable, with at most ``_NUM_CANDIDATES`` entries.
    """
    outputs = model.generate(
        tokens,
        # Single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(tokens),
        # Only the scores of the first generated position are read below.
        max_new_tokens=1,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    probabilities = F.softmax(outputs.scores[0], dim=-1)
    k = min(_NUM_CANDIDATES, probabilities.shape[-1])
    top = torch.topk(probabilities, k)
    return top.indices[0].tolist(), top.values[0].tolist()


@solara.component
def Page():
    """Render the next-token prediction page.

    Shows a text input bound to ``text1``. When the text is non-empty, it
    displays the tokenised text, then a paginated table of the most probable
    next tokens. The table's "Select" cell action appends the chosen token to
    the text.
    """
    with solara.Column(margin=10):
        solara.Markdown("#Next token prediction visualization")
        solara.Markdown("I built this tool to help me understand autoregressive language models. For any given text, it gives the top 10 candidates to be the next token with their respective probabilities. The language model I'm using is the smallest version of GPT-2, with 124M parameters.")
        solara.InputText("Enter text:", value=text1, continuous_update=True)
        if text1.value != "":
            tokens = tokenizer.encode(text1.value, return_tensors="pt")

            boxes_html, ids_html = _token_spans(tokens[0].tolist())
            solara.Markdown(boxes_html)
            solara.Markdown(ids_html)

            candidate_ids, candidate_probs = _next_token_candidates(tokens)
            df = pd.DataFrame(
                {
                    "probs": [f"{prob:.2%}" for prob in candidate_probs],
                    "next token ID": candidate_ids,
                    "predicted next token": [
                        tokenizer.decode([token_id]) for token_id in candidate_ids
                    ],
                }
            )

            def on_action_cell(column, row_index):
                # Append the token of the clicked row to the input text.
                text1.value += tokenizer.decode([candidate_ids[row_index]])

            cell_actions = [solara.CellAction(icon="mdi-thumb-up", name="Select", on_click=on_action_cell)]
            solara.Markdown("###Prediction")
            solara.DataFrame(df, items_per_page=10, cell_actions=cell_actions)
            solara.Markdown('-----')
# Instantiate the page component when the module is imported.
# NOTE(review): presumably kept for notebook-style display; confirm whether the app runner needs it.
Page()
|