import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns


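# Load the precomputed OPT-350m attention maps and the matching tokenizer.
# add_prefix_space=True so words tokenize the same way they do mid-sentence.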
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)


def analyze_sentence(index):
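    """Plot the summed attention map for the dataset sentence at `index`."""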
    row = dataset[index]
    text = row['text']
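    # Re-tokenize the sentence to get one readable string per token for the axis labels.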
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
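    # The maps are stored flat; drop the leading entry of the stored shape and
    # reshape to (num_maps, seq_len, seq_len), clipping values into [0, 1].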
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
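    # Sum the maps over their leading axis and plot the result as a heatmap,
    # skipping the first token on both axes.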
    fig = plt.figure(figsize=(7, 6))
    sns.heatmap(attn_maps.sum(0)[1:, 1:])
    plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90)
    plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()
    return fig

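# Minimal Gradio UI: pick a sentence from the dataset, click Run, and view its attention heatmap.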
with gr.Blocks() as demo:
    with gr.Row():
        dropdown = gr.Dropdown(choices=dataset['text'], value=dataset['text'][0], min_width=700, type='index')
        btn = gr.Button("Run")
    output = gr.Plot(label="Plot")
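    # The Run button passes the dropdown's selected index (type='index') to analyze_sentence.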
    btn.click(analyze_sentence, [dropdown], [output])


if __name__ == "__main__":
    demo.launch()