# Source of this app: a Hugging Face Spaces page, whose status was shown as
# "Runtime error".
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from datasets import load_dataset
from matplotlib.figure import Figure
from transformers import AutoTokenizer
# Hub repositories for the precomputed attention maps and the tokenizer of
# the model they were (per the dataset name) extracted from.
_DATASET_REPO = 'dar-tau/grammar-attention-maps-opt-350m'
_TOKENIZER_REPO = 'facebook/opt-350m'

# Loaded once at import time and shared by every request; only the 'train'
# split is used by the app.
dataset = load_dataset(_DATASET_REPO)['train']

# add_prefix_space=True is forwarded to the tokenizer; for GPT-2-style BPE
# tokenizers it makes the first word tokenize like a mid-sentence word.
tokenizer = AutoTokenizer.from_pretrained(
    _TOKENIZER_REPO,
    add_prefix_space=True,
)
def analyze_sentence(index):
    """Render the head-summed attention map of one dataset sentence.

    Args:
        index: Position of the sentence in ``dataset``. The dropdown uses
            ``type='index'``, so Gradio passes the selected choice's position.

    Returns:
        A standalone matplotlib ``Figure`` holding a heatmap. Rows are
        labelled TARGET and columns SOURCE. The first map position is
        dropped from both axes.

    Raises:
        ValueError: if the map is too small to plot, or the sentence's token
            count cannot be aligned with the attention map for tick labels.
    """
    row = dataset[index]
    text = row['text']
    # Decode token ids one by one so each tick label is a single token.
    tokens = tokenizer.batch_decode(
        tokenizer.encode(text, add_special_tokens=False)
    )

    # The stored shape's leading dimension is dropped before reshaping.
    # Presumably it is a size-1 batch axis; TODO confirm against the dataset.
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    if seq_len < 2:
        raise ValueError(
            f"Attention map for row {index} has {seq_len} position(s); "
            "nothing left to plot after dropping the first one."
        )
    attn_maps = (
        np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    )
    # Sum over axis 0 (assumed to be attention heads; TODO confirm) and drop
    # the first position on both axes.
    summed = attn_maps.sum(0)[1:, 1:]

    # Align tick labels with the seq_len - 1 plotted positions.
    if len(tokens) == seq_len:
        # Map position 0 is the first text token, which is dropped above.
        labels = tokens[1:]
    elif len(tokens) == seq_len - 1:
        # The map has one extra leading position (presumably a special/BOS
        # token) that encode(add_special_tokens=False) omits, so every text
        # token is plotted.
        labels = tokens
    else:
        raise ValueError(
            f"Cannot align {len(tokens)} tokens of row {index} with a "
            f"{seq_len}x{seq_len} attention map."
        )

    # Build a standalone Figure instead of going through pyplot. pyplot keeps
    # every figure registered globally (a leak in a long-running server) and
    # its implicit "current axes" state is shared between concurrent requests.
    fig = Figure(figsize=(6.5, 6))
    ax = fig.subplots()
    sns.heatmap(summed, ax=ax)

    # Heatmap cells span [i, i + 1), so centre each tick inside its cell.
    tick_positions = np.arange(seq_len - 1) + 0.5
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(labels, rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()
    return fig
# Sentence texts shown in the dropdown. The list is materialized once because
# each dataset['text'] access builds a fresh list from the column.
sentence_choices = dataset['text']

# Gradio UI: pick a sentence, click Run, and show its attention heatmap.
with gr.Blocks() as demo:
    with gr.Row():
        # type='index' makes Gradio pass the position of the selected choice
        # to analyze_sentence. The initial value must itself be one of the
        # choices (a display string). The previous bare 0 is not in
        # `choices`, so the index lookup failed when Run was clicked before
        # any selection was made.
        dropdown = gr.Dropdown(
            choices=sentence_choices,
            value=sentence_choices[0] if sentence_choices else None,
            min_width=750,
            type='index',
        )
        btn = gr.Button("Run")
    output = gr.Plot(label="Plot", container=False)
    btn.click(analyze_sentence, [dropdown], [output])

if __name__ == "__main__":
    demo.launch()