File size: 4,361 Bytes
a042926
fac5648
872cb05
cbd7ed1
e991bcb
cbd7ed1
 
c75ae48
7a323b8
c75ae48
a042926
c5fa8a7
 
 
43f6895
c5fa8a7
7e273f3
e991bcb
cbd7ed1
43f6895
db36965
9ddf875
392fdcd
b3d4e85
0577153
b3d4e85
 
7578fb8
c5fa8a7
c75ae48
c5fa8a7
 
 
f1b9681
376cb17
f1b9681
60b5a72
 
 
 
 
 
 
 
 
872cb05
60b5a72
0e625ad
2a70f49
36a5be5
6417c8b
96ea6a5
5fcb86a
175bc50
9758523
5fcb86a
ce95ec2
aeda444
 
 
60b5a72
872cb05
 
3f2619b
872cb05
 
 
ca37660
 
872cb05
79bdc9d
3f2619b
17a2383
60b5a72
c5fa8a7
60b5a72
a042926
9758523
0ac164b
822923c
 
c75ae48
 
17a2383
db11578
 
8d98842
376cb17
ee7b758
872cb05
ee7b758
a7e4d41
0e625ad
c75ae48
9758523
 
 
cbd7ed1
822923c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
from enum import Enum
from spacy import displacy


class VisType(Enum):
    """Aggregate visualization modes, keyed by the label shown to the user."""

    # Collapse the per-layer attention maps into one matrix by summing them.
    SUM = "Sum over Layers"


# Number of per-layer entries ('Layer #0' .. 'Layer #23') offered in the
# visualization dropdown -- presumably the layer count of opt-350m; confirm.
num_layers = 24
# Training split of the attention-map dataset; rows are read for their
# 'text', 'attention_maps' and 'attention_maps_shape' fields.
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
# Tokenizer used to split each sentence into the tokens shown as labels.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)


def _select_attention(attn_maps, vis_type):
    """Return the 2-D attention matrix that `vis_type` asks for.

    Raises:
        ValueError: if `vis_type` is neither the sum mode nor 'Layer #<n>'.
    """
    # Accept the enum member itself as well as its display string.
    if isinstance(vis_type, VisType):
        vis_type = vis_type.value
    if vis_type == VisType.SUM.value:
        return attn_maps.sum(0)
    if isinstance(vis_type, str) and vis_type.startswith('Layer #'):
        return attn_maps[int(vis_type.split('#')[1])]
    raise ValueError(f'Unsupported visualization type: {vis_type!r}')


def _heatmap_figure(plot_data, tokenized, seq_len):
    """Draw `plot_data` as a token-labelled heatmap and return the figure."""
    fig = plt.figure(figsize=(0.1 + 0.3 * len(tokenized), 0.25 * len(tokenized)))
    sns.heatmap(plot_data)
    # Ticks sit at cell centres; the first token is skipped to match the
    # row/column that was dropped from the attention maps.
    tick_positions = np.arange(seq_len - 1) + 0.5
    plt.xticks(tick_positions, tokenized[1:], rotation=90)
    plt.yticks(tick_positions, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()
    # Unregister the figure from pyplot so repeated requests do not pile up
    # open figures; the Figure object itself can still be rendered.
    plt.close(fig)
    return fig


def _arc_graph_html(plot_data, tokenized):
    """Render attention above 0.5 between non-adjacent tokens as displaCy arcs."""
    words = tokenized[1:]
    arcs = [
        {"start": j, "end": i, "label": "", "dir": "right"}
        for i in range(len(words))
        for j in range(i)
        if plot_data[i, j] > 0.5 and abs(i - j) > 1
    ]
    parse = [{
        "words": [{"text": word, "tag": ""} for word in words],
        "arcs": arcs,
    }]
    svg = displacy.render(
        parse, style="dep", jupyter=False, manual=True,
        options={"compact": True, "offset_x": 20, "distance": 130},
    )
    return (
        "<div class='displacy_container' style='max-width:100%; max-height:500px; overflow:auto'>"
        + svg
        + "</div>"
    )


def _top_pairs_html(plot_data, tokenized, text, max_pairs=30, threshold=0.5):
    """List the strongest off-diagonal (target, source) token pairs as HTML."""
    scores = plot_data.copy()
    np.fill_diagonal(scores, 0.)  # ignore each token's attention to itself
    # topk raises if k exceeds the number of elements, so cap it for short
    # sentences.
    k = min(max_pairs, scores.size)
    top_values, top_indices = torch.tensor(scores).flatten().topk(k)
    pairs = []
    for val, flat_idx in zip(top_values.tolist(), top_indices.tolist()):
        if val < threshold:
            break  # values are sorted in descending order
        target, source = np.unravel_index(flat_idx, scores.shape)
        pairs.append(str((tokenized[1 + target], tokenized[1 + source])))
    # NOTE(review): `text` and the tokens are inserted into HTML unescaped;
    # consider html.escape if the dataset text is not fully trusted.
    return '<div><b>' + text + '</b><br/>' + '<br/>'.join(pairs) + '</div>'


def analyze_sentence(index, vis_type, vis_format):
    """Build the visualization outputs for one dataset row.

    The row's flat 'attention_maps' are reshaped to 'attention_maps_shape'
    (minus its leading entry), clipped to [0, 1], and stripped of the first
    row and column of every map before being visualized.

    Args:
        index: Position of the sentence in `dataset`.
        vis_type: 'Sum over Layers' (or `VisType.SUM`) or 'Layer #<n>'.
        vis_format: 'Plot' for a heatmap, 'Graph' for a displaCy arc diagram;
            any other value (including None) yields the textual pair list.

    Returns:
        A `(figure_or_None, html_string, metrics_dict)` tuple. The metrics
        dict holds a 'Metrics' entry plus every row field other than the
        text and attention-map fields.

    Raises:
        ValueError: if `vis_type` is not a recognized visualization.
    """
    row = dataset[index]
    text = row['text']
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    attn_maps = attn_maps[:, 1:, 1:]
    plot_data = _select_attention(attn_maps, vis_type)

    if vis_format == 'Plot':
        fig_output = _heatmap_figure(plot_data, tokenized, seq_len)
        graph_output = ""
    elif vis_format == 'Graph':
        fig_output = None
        graph_output = _arc_graph_html(plot_data, tokenized)
    else:
        fig_output = None
        graph_output = _top_pairs_html(plot_data, tokenized, text)

    metrics = {'Metrics': 1}
    excluded = ('text', 'attention_maps', 'attention_maps_shape')
    metrics.update({k: v for k, v in row.items() if k not in excluded})
    return fig_output, graph_output, metrics


# UI layout: a row of controls (sentence, visualization, run button), a
# format selector, and three outputs fed by analyze_sentence.
demo = gr.Blocks(css=".displacy_container svg{height:500px !important; margin-top:-100px; transform: scale(0.5)}")
with demo:
    with gr.Row():
        # Show each sentence without its leading '</s> ' marker.
        sentence_choices = [x.split('</s> ')[1] for x in dataset['text']]
        # `value` must be one of the choices (not an index) even though the
        # callback receives the selected index via type='index'.
        sentence_dropdown = gr.Dropdown(label="Sentence",
                                        choices=sentence_choices,
                                        value=sentence_choices[0] if sentence_choices else None,
                                        min_width=300, type='index')
        # The default must be the display string, not the enum member, so it
        # matches an entry in `choices`.
        vis_dropdown = gr.Dropdown(label="Visualization", choices=[x.value for x in VisType] +
                                   [f'Layer #{i}' for i in range(num_layers)],
                                   min_width=70, value=VisType.SUM.value, type='value')
        btn = gr.Button("Run", min_width=30)

    # No default is selected; analyze_sentence treats an unset format like 'Text'.
    vis_format_checkbox = gr.Radio(['Plot', 'Graph', 'Text'])

    output = gr.Plot(label="Plot", container=True)
    graph_output = gr.HTML()
    metrics = gr.Label("Metrics")
    btn.click(analyze_sentence,
              [sentence_dropdown, vis_dropdown, vis_format_checkbox],
              [output, graph_output, metrics])


if __name__ == "__main__":
    demo.launch()