File size: 1,177 Bytes
a042926
fac5648
cbd7ed1
e991bcb
cbd7ed1
 
a042926
43f6895
7e273f3
e991bcb
cbd7ed1
43f6895
9ddf875
 
392fdcd
b3d4e85
0577153
b3d4e85
 
 
43f6895
6ff782d
 
fac5648
 
 
43f6895
a042926
cbd7ed1
223968c
9ddf875
a042926
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns


# Train split of precomputed attention maps for facebook/opt-350m on
# grammar sentences (downloads from the HF Hub on first run).
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
# add_prefix_space=True so tokenization matches how the maps were produced
# for mid-sentence words; tokens then line up with the heatmap axes.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)


def analyze_sentence(index):
    """Render the summed attention map of one dataset row as a heatmap.

    Args:
        index: Integer row index into the module-level ``dataset`` (the
            Dropdown is configured with ``type='index'``, so it passes
            the position of the selected sentence).

    Returns:
        tuple: ``(text, fig)`` — the sentence string and a matplotlib
        Figure showing attention summed over the leading axis, with the
        first token's row/column dropped.
    """
    row = dataset[index]
    text = row['text']
    # Round-trip encode/decode so tick labels match the model's tokens.
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    # Drop the leading axis of the stored shape (presumably a batch axis
    # of size 1 — TODO confirm against the dataset card).
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    # Build on an explicit Figure/Axes instead of pyplot global state so
    # repeated Gradio calls don't accumulate open figures in pyplot's
    # global registry (a memory leak in a long-running server).
    fig, ax = plt.subplots(figsize=(8, 8))
    # Sum over the leading axis; drop the first token's row and column.
    sns.heatmap(attn_maps.sum(0)[1:, 1:], ax=ax)
    ax.set_xticks(np.arange(seq_len - 1) + 0.5)
    ax.set_xticklabels(tokenized[1:], rotation=90)
    ax.set_yticks(np.arange(seq_len - 1) + 0.5)
    ax.set_yticklabels(tokenized[1:], rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()
    # Detach from pyplot's registry; Gradio renders the returned Figure.
    plt.close(fig)
    return text, fig


# Dropdown lists every sentence; type='index' makes the callback receive
# the selected row's integer position, not the sentence string itself.
iface = gr.Interface(fn=analyze_sentence, inputs=[gr.Dropdown(choices=dataset['text'], value=0, type='index')], 
                     outputs=[gr.Label(), gr.Plot(label="Plot")])
# Starts the Gradio web server (blocks until the app is stopped).
iface.launch()