# Spaces: Sleeping
# (Page-status banner captured together with the source; not part of the program.)
import datasets
import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling
# Endless source of example sentences: the deduplicated English OSCAR 21.09
# train split, opened in streaming mode, passed through a 1000-element shuffle
# buffer, and wrapped in an iterator so the UI can pull one record at a time
# with next(ds).
# NOTE(review): use_auth_token=True presumably relies on a Hugging Face token
# being available in the runtime environment -- confirm for the deployment.
ds = iter(
    datasets.load_dataset(
        "oscar-corpus/OSCAR-2109",
        "deduplicated_en",
        split="train",
        streaming=True,
        use_auth_token=True,
    ).shuffle(buffer_size=1000)
)

# Masked-language model under demonstration and its matching tokenizer.
model_name = "RomanCast/roberta-en-100k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Collator used to randomly mask tokens of the input sentence (library
# default masking settings).
collate_fn = DataCollatorForLanguageModeling(tokenizer)
# Layout of the prediction grid: NUM_ROWS x NUM_COLUMNS slots, one slot per
# masked token, each slot showing the true token and the model's TOP_K guesses.
NUM_ROWS = 2
NUM_COLUMNS = 4
TOP_K = 5

with gr.Blocks() as demo:
    inputs_oscar = gr.TextArea(
        placeholder="Type a sentence or click the button below to get a random sentence from the English OSCAR corpus",
        label="Input",
        num_lines=6,
        interactive=True,
    )
    next_button = gr.Button("Random OSCAR sentence")
    # Pull the next record from the shuffled OSCAR stream into the input box.
    next_button.click(fn=lambda: next(ds)["text"], outputs=inputs_oscar)

    masked_text = gr.Textbox(label="Masked sentence")

    # Flat list alternating [true-token Textbox, prediction Label, ...];
    # the pair at indices (2*i, 2*i + 1) displays the i-th masked token.
    labels_and_outputs = []
    for _ in range(NUM_ROWS):
        with gr.Row():
            for _ in range(NUM_COLUMNS):
                with gr.Column():
                    labels_and_outputs.append(gr.Textbox(label="Label"))
                    labels_and_outputs.append(gr.Label(num_top_classes=TOP_K, show_label=False))

    def model_inputs_and_outputs(example):
        """Randomly mask tokens of ``example`` and predict them with the model.

        Args:
            example: The sentence typed in (or loaded into) the input box.

        Returns:
            A dict mapping output components to their new values:
            ``masked_text`` gets the decoded sentence after masking, and each
            (Textbox, Label) pair in ``labels_and_outputs`` gets the original
            token and a ``{token: probability}`` dict with the model's top
            ``TOP_K`` candidates. Pairs beyond the number of masked tokens are
            cleared (``""`` / ``{}``); masked tokens beyond the number of
            pairs are not displayed.
        """
        # Hand the collator a list containing one un-batched example (plain
        # Python lists) so that it does the batching itself and returns
        # (1, seq_len) tensors. The special-tokens mask lets the collator
        # leave special tokens out of the random masking.
        encoding = tokenizer(
            example,
            truncation=True,
            max_length=128,
            return_special_tokens_mask=True,
        )
        batch = collate_fn([encoding])

        input_ids = batch["input_ids"][0]
        labels = batch["labels"][0]
        # The collator marks positions it did NOT select with -100.
        is_masked = labels != -100

        masked_tokens = tokenizer.decode(input_ids)
        original_labels = tokenizer.convert_ids_to_tokens(labels[is_masked].tolist())

        # Inference only: no autograd graph, and no labels so that the model
        # does not compute a loss nobody reads.
        forward_kwargs = {name: tensor for name, tensor in batch.items() if name != "labels"}
        with torch.no_grad():
            logits = model(**forward_kwargs).logits[0]

        # Softmax over the vocabulary, then keep only the candidates the Label
        # widget can actually show instead of materialising the full vocabulary.
        probabilities = logits[is_masked].softmax(-1)
        top = probabilities.topk(min(TOP_K, probabilities.shape[-1]), dim=-1)
        all_outputs = [
            dict(zip(tokenizer.convert_ids_to_tokens(candidate_ids), candidate_probs))
            for candidate_ids, candidate_probs in zip(top.indices.tolist(), top.values.tolist())
        ]

        out_dict = {masked_text: masked_tokens}
        slots = zip(labels_and_outputs[0::2], labels_and_outputs[1::2])
        for position, (label_box, prediction_box) in enumerate(slots):
            if position < len(original_labels):
                out_dict[label_box] = original_labels[position]
                out_dict[prediction_box] = all_outputs[position]
            else:
                # Fewer masked tokens than slots: blank the leftover widgets.
                out_dict[label_box] = ""
                out_dict[prediction_box] = {}
        return out_dict

    button = gr.Button("Predict tokens")
    button.click(fn=model_inputs_and_outputs, inputs=inputs_oscar, outputs=[masked_text] + labels_and_outputs)

demo.launch()