Spaces:
Runtime error
Runtime error
File size: 1,305 Bytes
a042926 fac5648 cbd7ed1 e991bcb cbd7ed1 a042926 43f6895 7e273f3 e991bcb cbd7ed1 43f6895 9ddf875 392fdcd b3d4e85 0577153 b3d4e85 c79dcc0 43f6895 6ff782d fac5648 822923c a042926 822923c c79dcc0 822923c c79dcc0 3095cb6 cbd7ed1 822923c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
# Hugging Face Hub identifiers: the precomputed attention-map dataset and the
# tokenizer of the model it is presumably derived from (per the dataset name).
DATASET_REPO = 'dar-tau/grammar-attention-maps-opt-350m'
TOKENIZER_REPO = 'facebook/opt-350m'

# Only the 'train' split is used by the app.
_all_splits = load_dataset(DATASET_REPO)
dataset = _all_splits['train']
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO, add_prefix_space=True)
def analyze_sentence(index):
    """Plot the attention maps of one dataset row as a token-by-token heatmap.

    Args:
        index: Position of the row in ``dataset`` (the UI dropdown uses
            ``type='index'``). ``None`` -- nothing selected -- yields no plot.

    Returns:
        A matplotlib ``Figure`` showing ``attention_maps.sum(0)`` with the
        first row and column dropped. Values are clipped to [0, 1] before
        summing. Returns ``None`` when ``index`` is ``None``.
    """
    if index is None:
        return None

    row = dataset[index]
    text = row['text']
    # Decode token by token so every axis label is exactly one token.
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))

    # The first stored dimension is skipped; the remaining ones must describe
    # the flat list. sns.heatmap needs 2-D data, so after `.sum(0)` the maps
    # must be 3-D. NOTE(review): what axis 0 indexes (heads? layers?) is not
    # visible here -- confirm against the dataset.
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)

    # Draw on an explicit Axes rather than pyplot's implicit "current" axes,
    # which is global state shared by every request the server handles.
    fig, ax = plt.subplots(figsize=(6.5, 6))
    # Position 0 is dropped on both axes, and the first token is dropped from
    # the labels to match. NOTE(review): presumably position 0 is a start /
    # attention-sink token -- confirm against how the maps were produced.
    sns.heatmap(attn_maps.sum(0)[1:, 1:], ax=ax)
    ticks = np.arange(seq_len - 1) + 0.5  # centre of each heatmap cell
    labels = tokenized[1:]
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels, rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()
    # Unregister the figure from pyplot. Otherwise every request leaks one
    # figure, because pyplot keeps figures alive until they are closed. The
    # Figure object can still be rendered/saved by the caller afterwards.
    plt.close(fig)
    return fig
# Materialize the sentence column once; it feeds both the dropdown choices
# and its default selection.
_sentences = list(dataset['text'])

# UI: a sentence picker and a Run button in one row, with the heatmap below.
with gr.Blocks() as demo:
    with gr.Row():
        # type='index' makes the callback receive the row position rather than
        # the text. `value` must be one of `choices` (a sentence, not an
        # index), so preselect the first sentence explicitly.
        dropdown = gr.Dropdown(
            choices=_sentences,
            value=_sentences[0] if _sentences else None,
            min_width=750,
            type='index',
        )
        btn = gr.Button("Run")
    output = gr.Plot(label="Plot", container=False)
    btn.click(analyze_sentence, [dropdown], [output])
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()