Spaces:
Runtime error
Runtime error
File size: 1,286 Bytes
a042926 fac5648 cbd7ed1 e991bcb cbd7ed1 a042926 43f6895 7e273f3 e991bcb cbd7ed1 43f6895 9ddf875 392fdcd b3d4e85 0577153 b3d4e85 b941cea 43f6895 6ff782d fac5648 822923c a042926 822923c b941cea 822923c 3095cb6 cbd7ed1 822923c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
# Repository identifiers handed to the loaders below.
_DATASET_REPO = 'dar-tau/grammar-attention-maps-opt-350m'
_TOKENIZER_REPO = 'facebook/opt-350m'

# Only the 'train' split of the dataset is used by this script.
dataset = load_dataset(_DATASET_REPO)['train']
# The tokenizer is used to split each sentence into the tokens shown as
# axis labels on the heatmap.
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_REPO, add_prefix_space=True)
def analyze_sentence(index):
    """Draw the summed attention map of one dataset sentence as a heatmap.

    Args:
        index: Row position in the module-level ``dataset``. ``None`` means
            that nothing was selected in the UI.

    Returns:
        The Matplotlib figure containing the heatmap.

    Raises:
        gr.Error: If ``index`` is ``None``.
    """
    if index is None:
        # Surface a readable message instead of failing inside dataset[None].
        raise gr.Error("Please select a sentence first.")

    row = dataset[index]
    text = row['text']
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))

    # The first entry of the stored shape is dropped; the remaining entries
    # are used as (maps, rows, cols) below -- presumably rows == cols ==
    # sequence length; TODO confirm against the dataset card.
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)

    # Draw on an explicit Axes rather than pyplot's implicit "current figure",
    # so concurrent requests cannot draw onto each other's figure.
    fig, ax = plt.subplots(figsize=(7, 6))
    # Sum over the first axis, then skip the first row/column (first token).
    sns.heatmap(attn_maps.sum(0)[1:, 1:], ax=ax)

    # Heatmap cells are centred at half-integer coordinates.
    tick_positions = np.arange(seq_len - 1) + 0.5
    tick_labels = tokenized[1:]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels, rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()

    # Unregister the figure from pyplot so repeated requests do not accumulate
    # open figures; the returned Figure object can still be rendered.
    plt.close(fig)
    return fig
# Materialise the sentence column once so it can serve both as the dropdown
# choices and as the source of the default selection.
sentences = list(dataset['text'])

demo = gr.Blocks()
with demo:
    with gr.Row():
        # `value` must be one of the choices (a sentence), not a position;
        # type='index' only controls what the callback receives.
        dropdown = gr.Dropdown(
            choices=sentences,
            value=sentences[0] if sentences else None,
            min_width=700,
            type='index',
        )
        btn = gr.Button("Run")
    output = gr.Plot(label="Plot")
    # Clicking "Run" passes the selected sentence's index to analyze_sentence
    # and shows the returned figure in the plot component.
    btn.click(analyze_sentence, [dropdown], [output])
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()