File size: 1,305 Bytes
a042926
fac5648
cbd7ed1
e991bcb
cbd7ed1
 
a042926
43f6895
7e273f3
e991bcb
cbd7ed1
43f6895
9ddf875
 
392fdcd
b3d4e85
0577153
b3d4e85
 
c79dcc0
43f6895
6ff782d
 
fac5648
 
 
822923c
a042926
822923c
 
 
c79dcc0
822923c
c79dcc0
3095cb6
cbd7ed1
822923c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns


# Rows with pre-computed attention maps (presumably for OPT-350m, per the repo
# name). Only the 'train' split is used by this app.
# NOTE(review): the row schema ('text', 'attention_maps',
# 'attention_maps_shape') is inferred from analyze_sentence — confirm against
# the dataset card.
dataset = load_dataset(
    path='dar-tau/grammar-attention-maps-opt-350m',
    split='train',
)
# Tokenizer for the matching checkpoint; in this file it is only used to turn
# a row's text back into per-token strings for the heatmap axis labels.
tokenizer = AutoTokenizer.from_pretrained(
    'facebook/opt-350m',
    add_prefix_space=True,
)


def analyze_sentence(index):
    """Render the summed attention map of one dataset row as a heatmap.

    Args:
        index: Row position in ``dataset``. The Gradio dropdown is built with
            ``type='index'``, so Gradio passes the selected choice's position.

    Returns:
        A ``matplotlib.figure.Figure`` holding a heatmap of the row's attention
        maps. Values are clipped to [0, 1] and summed over the leading axis,
        and position 0 is removed from both axes. The y axis is labelled
        TARGET and the x axis SOURCE. Tick labels are the decoded tokens of
        the row's text.

    Raises:
        gr.Error: If ``index`` is None (no sentence selected).
    """
    # Local import so the file's import block stays unchanged. A bare Figure
    # (unlike plt.figure) is not registered with pyplot's global figure
    # manager. It is therefore freed once Gradio has rendered it instead of
    # accumulating on every click. Concurrent requests also cannot draw onto
    # each other's "current" axes.
    from matplotlib.figure import Figure

    if index is None:
        # A type='index' dropdown whose value matches no choice can yield
        # None; fail with a readable message rather than inside dataset[None].
        raise gr.Error("Please select a sentence first.")

    row = dataset[index]
    tokens = tokenizer.batch_decode(
        tokenizer.encode(row['text'], add_special_tokens=False)
    )

    # The stored shape's leading entry is dropped before reshaping the flat
    # attention values. NOTE(review): the remaining shape is presumably
    # (n_maps, seq_len, seq_len) — confirm against the dataset card.
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)

    fig = Figure(figsize=(6.5, 6))
    ax = fig.subplots()
    # Sum over the leading axis and drop position 0 from both axes.
    sns.heatmap(attn_maps.sum(0)[1:, 1:], ax=ax)

    # Place one tick at the centre of each remaining heatmap cell.
    # NOTE(review): this assumes len(tokens) == seq_len. If the stored maps
    # include a position that encode(add_special_tokens=False) omits, the
    # label count will not match the tick count — confirm.
    ticks = np.arange(seq_len - 1) + 0.5
    labels = tokens[1:]
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels, rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()
    return fig

# Sentence strings offered in the dropdown, materialized once as a plain list.
sentences = list(dataset['text'])

# UI: a sentence picker plus a Run button that renders the attention heatmap.
demo = gr.Blocks()
with demo:
    with gr.Row():
        # type='index' makes Gradio hand analyze_sentence the position of the
        # selected sentence. The default value must be one of the choice
        # strings; an integer such as 0 matches no choice, so the default
        # selection would not map to any index.
        dropdown = gr.Dropdown(
            choices=sentences,
            value=sentences[0] if sentences else None,
            min_width=750,
            type='index',
        )
        btn = gr.Button("Run")
    output = gr.Plot(label="Plot", container=False)
    btn.click(analyze_sentence, [dropdown], [output])


if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()