import random
from functools import partial

import gradio as gr
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("adamcasson/ul2-tinystories")
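
# Note (inferred from the code below, not stated in the checkpoint itself): the
# tokenizer appears to be a GPT-2-style byte-level BPE (hence the "Ġ" and "Ċ"
# markers stripped in mask_viz) with [R], [S], and [X] mode tokens in its vocab.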


def mask_spans(
    tokens,
    mu,
    r,
    vocab_size,
    eos_id,
    prepend_id=None,
    prefix_lm=False,
):
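    """Corrupt a token id sequence with UL2-style span masking.

    Args:
        tokens: list of token ids to corrupt.
        mu: mean span length; span lengths are drawn uniformly from [1, 2 * mu].
        r: corruption rate — the probability that a given span is masked
            (or, for prefix_lm, the target fraction of the sequence to mask).
        vocab_size: vocabulary size; sentinel ids are assigned from the top of
            the vocab downwards, following T5.
        eos_id: id appended to the targets and used to start the decoder input.
        prepend_id: optional mode-token id ([R]/[S]/[X]) prepended to the
            encoder input.
        prefix_lm: if True, mask a single span at the end of the sequence
            (the S-denoiser / prefix-LM objective) instead of random spans.

    Returns:
        (encoder_inputs, encoder_mask, decoder_inputs, decoder_mask, targets,
        targets_mask, masked_tokens), where masked_tokens mirrors `tokens`
        with corrupted positions set to -1.
    """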
    masked_tokens = tokens[:]
    encoder_inputs = [prepend_id] if prepend_id is not None else []
    encoder_mask = [1] if prepend_id is not None else []
    targets = []
    targets_mask = []
    # The original T5 code reused token ids at the end of the vocab as sentinels:
    # https://github.com/google-research/text-to-text-transfer-transformer/blob/258fd30687e6c60d18b7204d009dc5c753142987/t5/data/preprocessors.py#L3106C6-L3106C6
    sentinel_id = vocab_size - 1
    if prefix_lm:
        # n = 1: a single masked span covering the tail of the sequence
        mu = max(1, int(len(tokens) * r))
        # max() guards against a drawn length longer than the sequence (start < 0)
        start = max(0, len(tokens) - random.randint(1, int(2 * mu)))
        encoder_inputs += tokens[:start] + [sentinel_id]
        encoder_mask += ([1] * len(tokens[:start])) + [0]
        targets += [sentinel_id] + tokens[start:]
        targets_mask += [0] + ([1] * len(tokens[start:]))
        for i in range(start, len(tokens)):
            masked_tokens[i] = -1
    else:
        # n = ceil(len(tokens) / mu) spans in expectation
        prev_span_unmasked = False
        start = 0
        end = 0
        while start < len(tokens):
            # uniform random span length with mean mu
            length = random.randint(1, int(2 * mu))
            end = min(start + length, len(tokens))
            # randomly decide whether this span should be masked
            if np.random.binomial(1, p=r):
                encoder_inputs.append(sentinel_id)
                encoder_mask.append(0)
                targets += tokens[start:end]
                targets_mask += [1] * len(tokens[start:end])
                for i in range(start, end):
                    masked_tokens[i] = -1
                prev_span_unmasked = False
                sentinel_id -= 1
            else:
                encoder_inputs += tokens[start:end]
                encoder_mask += [1] * len(tokens[start:end])
                # if the previous span was also unmasked, we don't need to
                # add another sentinel token
                if not prev_span_unmasked:
                    targets.append(sentinel_id)
                    targets_mask.append(0)
                prev_span_unmasked = True
            start = end

    targets.append(eos_id)
    targets_mask.append(1)
    # Teacher forcing: the decoder input is the target sequence shifted right
    # by one, with EOS as the start token.
    decoder_inputs = [eos_id] + targets[:-1]
    decoder_mask = [1] + targets_mask[:-1]

    return (
        encoder_inputs,
        encoder_mask,
        decoder_inputs,
        decoder_mask,
        targets,
        targets_mask,
        masked_tokens,
    )
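
# A worked illustration of mask_spans (hypothetical toy values, with the random
# draws fixed by hand): given tokens = [10, 11, 12, 13, 14], mu = 2, r = 0.5,
# vocab_size = 100, eos_id = 0, and draws that split the sequence into spans
# [10, 11] (kept), [12, 13] (masked), [14] (kept):
#
#   masked_tokens  = [10, 11, -1, -1, 14]   # -1 marks corrupted positions
#   encoder_inputs = [10, 11, 99, 14]       # sentinel 99 replaces the masked span
#   targets        = [99, 12, 13, 98, 0]    # sentinels delimit spans, then EOS
#   decoder_inputs = [0, 99, 12, 13, 98]    # targets shifted right, EOS first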


# Create the mixture of denoisers
denoiser_map = {
    "R (µ = 3, r = 0.15)": partial(
        mask_spans,
        mu=3,
        r=0.15,
        vocab_size=tokenizer.vocab_size,
        eos_id=tokenizer.eos_token_id,
        prepend_id=tokenizer.vocab["[R]"],
    ),
    "R (µ = 8, r = 0.15)": partial(
        mask_spans,
        mu=8,
        r=0.15,
        vocab_size=tokenizer.vocab_size,
        eos_id=tokenizer.eos_token_id,
        prepend_id=tokenizer.vocab["[R]"],
    ),
    "S (r = 0.25)": partial(
        mask_spans,
        mu=None,  # unused for the S-denoiser: prefix_lm derives mu from r
        r=0.25,
        vocab_size=tokenizer.vocab_size,
        eos_id=tokenizer.eos_token_id,
        prefix_lm=True,
        prepend_id=tokenizer.vocab["[S]"],
    ),
    "X (µ = 3, r = 0.5)": partial(
        mask_spans,
        mu=3,
        r=0.5,
        vocab_size=tokenizer.vocab_size,
        eos_id=tokenizer.eos_token_id,
        prepend_id=tokenizer.vocab["[X]"],
    ),
    "X (µ = 8, r = 0.5)": partial(
        mask_spans,
        mu=8,
        r=0.5,
        vocab_size=tokenizer.vocab_size,
        eos_id=tokenizer.eos_token_id,
        prepend_id=tokenizer.vocab["[X]"],
    ),
    "X (µ = 32, r = 0.15)": partial(
        mask_spans,
        mu=32,
        r=0.15,
        vocab_size=tokenizer.vocab_size,
        eos_id=tokenizer.eos_token_id,
        prepend_id=tokenizer.vocab["[X]"],
    ),
    "X (µ = 32, r = 0.5)": partial(
        mask_spans,
        mu=32,
        r=0.5,
        vocab_size=tokenizer.vocab_size,
        eos_id=tokenizer.eos_token_id,
        prepend_id=tokenizer.vocab["[X]"],
    ),
}
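
# In UL2 terms: R-denoisers do regular T5-style span corruption (short spans,
# low corruption rate), the S-denoiser does sequential / prefix-LM denoising
# (a single span at the end of the sequence), and X-denoisers are "extreme"
# (long spans and/or high corruption rates).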


def mask_viz(denoiser, text):
    seq = tokenizer.encode(text)
    tokens = tokenizer.tokenize(text)
    enc_in, enc_mask, dec_in, dec_mask, targets, targets_mask, mask = denoiser_map[denoiser](seq)

    # Original tokens, labelled by whether their position was corrupted
    highlight_tok = []
    for tok, tok_mask in zip(tokens, mask):
        highlight_tok.append((tok.replace("Ġ", " ").replace("Ċ", "\n"), "masked" if tok_mask == -1 else "unmasked"))

    # Encoder input: kept tokens rendered as text, sentinels as their raw id
    highlight_enc = []
    enc_tok = tokenizer.convert_ids_to_tokens(enc_in)
    for tok_id, tok, tok_mask in zip(enc_in, enc_tok, enc_mask):
        highlight_enc.append((tok.replace("Ġ", " ").replace("Ċ", "\n") if tok_mask == 1 else str(tok_id), "masked" if tok_mask == 0 else "unmasked"))

    # Decoder input: same convention as the encoder input
    highlight_dec = []
    dec_tok = tokenizer.convert_ids_to_tokens(dec_in)
    for tok_id, tok, tok_mask in zip(dec_in, dec_tok, dec_mask):
        highlight_dec.append((tok.replace("Ġ", " ").replace("Ċ", "\n") if tok_mask == 1 else str(tok_id), "masked" if tok_mask == 0 else "unmasked"))

    return highlight_tok, highlight_enc, highlight_dec
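
# Each returned list is a sequence of (text, label) pairs, the format
# gr.HighlightedText accepts; the color_map below renders "unmasked" spans
# green and "masked" spans red.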


iface = gr.Interface(
    fn=mask_viz,
    inputs=[
        gr.Dropdown(
            label="Denoiser",
            choices=[
                "R (µ = 3, r = 0.15)",
                "R (µ = 8, r = 0.15)",
                "S (r = 0.25)",
                "X (µ = 3, r = 0.5)",
                "X (µ = 8, r = 0.5)",
                "X (µ = 32, r = 0.15)",
                "X (µ = 32, r = 0.5)",
            ],
            value="R (µ = 3, r = 0.15)",
        ),
        gr.Textbox(
            value='Once upon a time, there was a family with a little boy. His name was Jack.\nOne day, Jack had a thought. He wanted to go to the park and play. His parents were worried because it was getting dark and the park was far away.\n"Mom, I want to play in the park," Jack said.\nHis mother thought for a moment. "It\'s too late to go to the park now. We\'d better stay at home," she said. \nJack was sad, but he understood why his parents were worried. Together they decided to play games at home instead. \nJack was so happy to get to play games with his family. He thought it was the best time ever.'
        ),
    ],
    outputs=[
        gr.HighlightedText(
            label="Corrupted spans",
            combine_adjacent=True,
            show_legend=True,
            color_map={"unmasked": "green", "masked": "red"},
        ),
        gr.HighlightedText(
            label="Encoder input",
            combine_adjacent=True,
            show_legend=True,
            color_map={"unmasked": "green", "masked": "red"},
        ),
        gr.HighlightedText(
            label="Decoder input",
            combine_adjacent=True,
            show_legend=True,
            color_map={"unmasked": "green", "masked": "red"},
        ),
    ],
)

iface.launch()