Spaces:
Sleeping
Sleeping
File size: 4,361 Bytes
a042926 fac5648 872cb05 cbd7ed1 e991bcb cbd7ed1 c75ae48 7a323b8 c75ae48 a042926 c5fa8a7 43f6895 c5fa8a7 7e273f3 e991bcb cbd7ed1 43f6895 db36965 9ddf875 392fdcd b3d4e85 0577153 b3d4e85 7578fb8 c5fa8a7 c75ae48 c5fa8a7 f1b9681 376cb17 f1b9681 60b5a72 872cb05 60b5a72 0e625ad 2a70f49 36a5be5 6417c8b 96ea6a5 5fcb86a 175bc50 9758523 5fcb86a ce95ec2 aeda444 60b5a72 872cb05 3f2619b 872cb05 ca37660 872cb05 79bdc9d 3f2619b 17a2383 60b5a72 c5fa8a7 60b5a72 a042926 9758523 0ac164b 822923c c75ae48 17a2383 db11578 8d98842 376cb17 ee7b758 872cb05 ee7b758 a7e4d41 0e625ad c75ae48 9758523 cbd7ed1 822923c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import html
from enum import Enum

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from datasets import load_dataset
from spacy import displacy
from transformers import AutoTokenizer
class VisType(Enum):
    """Layer aggregations offered in the "Visualization" dropdown.

    Each member's ``value`` is the exact string listed as a dropdown choice
    and compared against the selection in ``analyze_sentence``; per-layer
    choices ("Layer #<k>") are generated separately and are not members.
    """

    # Add up the clipped attention maps of every layer.
    SUM = 'Sum over Layers'
# Hugging Face Hub repositories this demo reads from.
_ATTENTION_DATASET_REPO = 'dar-tau/grammar-attention-maps-opt-350m'
_TOKENIZER_REPO = 'facebook/opt-350m'

# Precomputed attention maps; only the 'train' split is used.
dataset = load_dataset(_ATTENTION_DATASET_REPO)['train']

# Tokenizer whose decoded tokens label the rows/columns of each attention map.
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_REPO, add_prefix_space=True)

# How many "Layer #<k>" choices the UI offers. This must not exceed the layer
# axis of the stored attention maps (presumably 24 for OPT-350m -- confirm
# against `attention_maps_shape` in the dataset).
num_layers = 24
def _select_plot_data(attn_maps, vis_type):
    """Collapse the stacked attention maps into the 2-D map to display.

    Args:
        attn_maps: numpy array indexed as ``attn_maps[layer, target, source]``.
        vis_type: ``VisType.SUM.value`` to sum over the layer axis, or a
            string of the form ``'Layer #<k>'`` to select layer ``k``.

    Returns:
        A 2-D numpy array (rows = target tokens, columns = source tokens).

    Raises:
        ValueError: if ``vis_type`` matches neither form (including ``None``).
    """
    if vis_type == VisType.SUM.value:
        return attn_maps.sum(0)
    if isinstance(vis_type, str) and vis_type.startswith('Layer #'):
        return attn_maps[int(vis_type.split('#')[1])]
    raise ValueError(f'Unknown visualization type: {vis_type!r}')


def _heatmap_figure(plot_data, tokens):
    """Draw ``plot_data`` as a seaborn heatmap labelled with ``tokens``.

    Rows are labelled TARGET and columns SOURCE. The figure is released from
    pyplot's global registry before being returned, so repeated requests do
    not pile up open figures; the returned ``Figure`` can still be saved.
    """
    n_tokens = len(tokens)
    fig, ax = plt.subplots(figsize=(0.4 + 0.3 * n_tokens, 0.25 + 0.25 * n_tokens))
    sns.heatmap(plot_data, ax=ax)
    n_rows, n_cols = plot_data.shape
    # Centre each label on its heatmap cell.
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(tokens, rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()
    plt.close(fig)
    return fig


def _dependency_graph_html(plot_data, tokens):
    """Render strong, non-adjacent attention links as a displaCy arc diagram.

    An arc from source ``j`` to target ``i`` is drawn whenever
    ``plot_data[i, j] > 0.5`` and ``j <= i - 2``. The diagonal, adjacent
    tokens and the upper triangle are ignored.
    """
    words = [{"text": tok, "tag": ""} for tok in tokens]
    arcs = [
        {"start": source, "end": target, "label": "", "dir": "right"}
        for target in range(len(tokens))
        for source in range(target - 1)  # excludes source == target - 1
        if plot_data[target, source] > 0.5
    ]
    svg = displacy.render([{"words": words, "arcs": arcs}], style="dep", jupyter=False, manual=True,
                          options={"compact": True, "offset_x": 20, "distance": 130})
    return ("<div class='displacy_container' style='max-width:100%; max-height:500px; overflow:auto'>"
            + svg
            + "</div>")


def _top_pairs_html(plot_data, tokens, text, k=30, threshold=0.5):
    """List the strongest (target, source) token pairs as an HTML snippet.

    Diagonal (token-to-itself) scores are ignored. At most ``k`` pairs are
    listed in descending score order, stopping at the first score below
    ``threshold``. ``text`` is shown in bold above the list. All text is
    HTML-escaped so tokens such as ``</s>`` are shown rather than parsed.
    """
    scores = plot_data.copy()  # leave the caller's array untouched
    np.fill_diagonal(scores, 0.)
    flat = torch.tensor(scores).flatten()
    # topk raises when k exceeds the element count (very short sentences).
    top_values, top_indices = flat.topk(min(k, flat.numel()))
    pairs = []
    for value, flat_index in zip(top_values.tolist(), top_indices.tolist()):
        if value < threshold:
            break
        target, source = np.unravel_index(flat_index, scores.shape)
        pairs.append(str((tokens[target], tokens[source])))
    body = '<br/>'.join(html.escape(pair) for pair in pairs)
    return '<div><b>' + html.escape(text) + '</b><br/>' + body + '</div>'


def analyze_sentence(index, vis_type, vis_format):
    """Visualize the stored attention maps of one dataset sentence.

    Args:
        index: row index into ``dataset``.
        vis_type: ``VisType.SUM.value`` or ``'Layer #<k>'``.
        vis_format: ``'Plot'`` for a heatmap, ``'Graph'`` for displaCy arcs;
            any other value (e.g. ``'Text'`` or ``None``) lists top token pairs.

    Returns:
        ``(figure_or_None, html, metrics)`` for the Plot, HTML and Label
        outputs. ``metrics`` holds a constant ``'Metrics': 1`` entry plus every
        row field except the text and the attention data.

    Raises:
        ValueError: if ``vis_type`` is not recognised.
    """
    row = dataset[index]
    text = row['text']
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    # The first decoded token and row/column 0 of every map are dropped
    # together (presumably the leading '</s>' of each stored text -- confirm).
    tokens = tokenized[1:]

    # The stored shape's leading axis is dropped before reshaping, so it must
    # have size 1 (presumably a batch axis).
    attn_map_shape = row['attention_maps_shape'][1:]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    attn_maps = attn_maps[:, 1:, 1:]
    plot_data = _select_plot_data(attn_maps, vis_type)

    fig_output = None
    if vis_format == 'Plot':
        fig_output = _heatmap_figure(plot_data, tokens)
        graph_output = ""
    elif vis_format == 'Graph':
        graph_output = _dependency_graph_html(plot_data, tokens)
    else:
        graph_output = _top_pairs_html(plot_data, tokens, text)

    metrics = {'Metrics': 1}
    metrics.update({k: v for k, v in row.items()
                    if k not in ('text', 'attention_maps', 'attention_maps_shape')})
    return fig_output, graph_output, metrics
# Gradio UI: choose a sentence, a visualization and an output format, then
# press Run to render `analyze_sentence` into the Plot / HTML / Label outputs.
demo = gr.Blocks(css=".displacy_container svg{height:500px !important; margin-top:-100px; transform: scale(0.5)}")

# Dropdown labels: the part of each stored text after its first '</s> '.
_sentence_choices = [x.split('</s> ')[1] for x in dataset['text']]

with demo:
    with gr.Row():
        sentence_dropdown = gr.Dropdown(
            label="Sentence",
            choices=_sentence_choices,
            # The default must be one of the choices even with type='index';
            # the handler still receives the selected row index.
            value=_sentence_choices[0] if _sentence_choices else None,
            min_width=300,
            type='index',
        )
        vis_dropdown = gr.Dropdown(
            label="Visualization",
            choices=[x.value for x in VisType] + [f'Layer #{i}' for i in range(num_layers)],
            # The choices are enum *values*; the member itself is not a choice.
            value=VisType.SUM.value,
            min_width=70,
            type='value',
        )
        btn = gr.Button("Run", min_width=30)
    # Preselect 'Plot' so the highlighted option matches what Run renders.
    vis_format_checkbox = gr.Radio(['Plot', 'Graph', 'Text'], value='Plot', label="Format")
    output = gr.Plot(label="Plot", container=True)
    graph_output = gr.HTML()
    # `label=` names the component; a positional string would set its value.
    metrics = gr.Label(label="Metrics")
    btn.click(analyze_sentence,
              [sentence_dropdown, vis_dropdown, vis_format_checkbox],
              [output, graph_output, metrics])
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()