File size: 1,919 Bytes
a042926
fac5648
cbd7ed1
e991bcb
cbd7ed1
 
c75ae48
 
 
 
a042926
43f6895
7e273f3
e991bcb
cbd7ed1
43f6895
c75ae48
9ddf875
392fdcd
b3d4e85
0577153
b3d4e85
 
f7076c5
c75ae48
 
 
 
6ff782d
 
fac5648
 
 
c75ae48
822923c
a042926
822923c
 
 
c75ae48
 
 
 
822923c
a7e4d41
c75ae48
 
cbd7ed1
822923c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
from enum import Enum

class VisType(Enum):
    """Visualization modes selectable in the app's "Visualization" dropdown.

    Each member's value is the plain string shown to the user.
    """

    # Sum the stored attention maps over their leading axis into one heatmap.
    ALL = 'ALL'


# Hugging Face Hub identifiers: the dataset holding the precomputed attention
# maps, and the tokenizer analyze_sentence uses to label token positions.
_DATASET_ID = 'dar-tau/grammar-attention-maps-opt-350m'
_TOKENIZER_ID = 'facebook/opt-350m'

# The app only ever reads the 'train' split.
dataset = load_dataset(_DATASET_ID)['train']
tokenizer = AutoTokenizer.from_pretrained(
    _TOKENIZER_ID,
    add_prefix_space=True,
)


def analyze_sentence(index, vis_type):
    """Plot the stored attention maps of one dataset row as a heatmap.

    Args:
        index: Row index into ``dataset``. The sentence dropdown is built with
            ``type='index'``, so Gradio passes the selected position here.
        vis_type: A ``VisType`` member, or its integer position in
            ``list(VisType)`` (which is what an index-typed dropdown passes).

    Returns:
        A ``(fig, metrics)`` tuple: the matplotlib figure with the heatmap,
        and a dict of every row column except the text and the raw
        attention-map payload.

    Raises:
        ValueError: if ``vis_type`` is not a supported visualization.
    """
    row = dataset[index]
    text = row['text']
    # Token strings used as the axis labels.
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    # The maps are stored flattened: rebuild them and clip values to [0, 1].
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)

    # An index-typed Gradio dropdown delivers an int, not a VisType member.
    if isinstance(vis_type, int):
        vis_type = list(VisType)[vis_type]

    # Drop position 0 from both token axes so the matrix lines up with the
    # tokenized[1:] tick labels below (presumably position 0 is the '</s>'
    # token that prefixes every text - TODO confirm against the dataset).
    plot_data = attn_maps[:, 1:, 1:]
    if vis_type == VisType.ALL:
        # Collapse the leading axis: sns.heatmap needs 2-D data.
        # NOTE(review): assumes the stored maps are 3-D - confirm.
        plot_data = plot_data.sum(0)
    else:
        raise ValueError(f'Unsupported visualization type: {vis_type!r}')

    fig = plt.figure(figsize=(0.5 + 0.5 * len(tokenized), 0.4 * len(tokenized)))
    sns.heatmap(plot_data)
    ticks = np.arange(seq_len - 1) + 0.5  # centre of each heatmap cell
    plt.xticks(ticks, tokenized[1:], rotation=90)
    plt.yticks(ticks, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()

    # Every remaining column is reported in the Metrics output.
    # NOTE(review): gr.Label renders a dict as label -> confidence, so this
    # assumes the remaining columns are numeric - confirm against the dataset.
    excluded = {'text', 'attention_maps', 'attention_maps_shape'}
    metrics = {k: v for k, v in row.items() if k not in excluded}
    return fig, metrics

# Dropdown entries: each stored text with its leading '</s> ' marker stripped.
# split(..., 1)[-1] keeps the whole remainder and does not raise when a text
# lacks the marker.
sentence_choices = [x.split('</s> ', 1)[-1] for x in dataset['text']]
# Gradio choices must be plain strings, not Enum members.
vis_choices = [v.value for v in VisType]

with gr.Blocks() as demo:
    with gr.Row():
        # type='index' makes Gradio pass the position of the selected entry,
        # which analyze_sentence uses directly as the dataset row index.
        # NOTE(review): with duplicate sentences, Gradio may resolve the
        # index to the first occurrence - confirm texts are unique.
        sentence_dropdown = gr.Dropdown(
            label="Sentence",
            choices=sentence_choices,
            # `value` must be one of the choices, even for index-typed dropdowns.
            value=sentence_choices[0] if sentence_choices else None,
            min_width=500,
            type='index',
        )
        vis_dropdown = gr.Dropdown(
            label="Visualization",
            choices=vis_choices,
            value=vis_choices[0],
            type='index',
        )
        btn = gr.Button("Run")
    output = gr.Plot(label="Plot", container=True)
    # The first positional argument of gr.Label is its value, so the title
    # must be passed as label=.
    metrics = gr.Label(label="Metrics")
    # analyze_sentence returns (figure, metrics dict) for these two outputs.
    btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown], [output, metrics])


if __name__ == "__main__":
    demo.launch()