# Hugging Face file-view header (scraped page residue): last commit 17a2383
# "Update app.py" (verified) by dar-tau; file size 2.33 kB.
from enum import Enum

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from datasets import load_dataset
from matplotlib.figure import Figure
from transformers import AutoTokenizer
class VisType(Enum):
    """Aggregate visualizations offered alongside the per-layer options.

    Each member's value is the exact label listed in the "Visualization"
    dropdown and compared against by ``analyze_sentence``.
    """

    SUM = 'Sum over Layers'  # sum the attention maps over the layer axis
# Number of 'Layer #i' entries offered in the visualization dropdown.
# Presumably equal to the layer count of facebook/opt-350m -- TODO confirm
# against the stored attention-map shapes.
num_layers = 24

# Precomputed attention maps; only the 'train' split is used by the app.
dataset = load_dataset(
    'dar-tau/grammar-attention-maps-opt-350m',
)['train']

# Tokenizer of the model the maps came from, used to label heatmap ticks.
tokenizer = AutoTokenizer.from_pretrained(
    'facebook/opt-350m',
    add_prefix_space=True,
)
def analyze_sentence(index, vis_type):
    """Render one sentence's attention heatmap and collect its metrics.

    Args:
        index: Row position in ``dataset``. The sentence dropdown uses
            ``type='index'``, so Gradio passes the position of the selected
            entry.
        vis_type: ``VisType.SUM.value`` to sum the maps over their first
            (layer) axis, or a string ``'Layer #<i>'`` to show entry ``i`` of
            that axis. A ``VisType`` member is also accepted and treated as
            its value.

    Returns:
        ``(fig, metrics)``: a matplotlib ``Figure`` holding the heatmap (rows
        labelled TARGET, columns SOURCE), and a dict for ``gr.Label`` with a
        ``'Metrics': 1`` entry plus every dataset field except ``text``,
        ``attention_maps`` and ``attention_maps_shape``.

    Raises:
        ValueError: If ``vis_type`` is not a recognised visualization.
    """
    # Normalise an enum member (e.g. a component default of VisType.SUM) to
    # the plain string the dropdown normally sends.
    if isinstance(vis_type, VisType):
        vis_type = vis_type.value

    row = dataset[index]
    text = row['text']
    # batch_decode over a flat id list decodes each id on its own, giving one
    # label per token for the heatmap ticks.
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))

    # The stored shape has a leading dimension that is dropped before the
    # reshape (presumably a batch axis of size 1 -- TODO confirm). The rest is
    # indexed below as (layer, target, source).
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    # Drop the first token position on both axes. The tick labels below skip
    # tokenized[0] to match.
    attn_maps = attn_maps[:, 1:, 1:]

    if vis_type == VisType.SUM.value:
        plot_data = attn_maps.sum(0)
    elif isinstance(vis_type, str) and vis_type.startswith('Layer #'):
        layer_to_inspect = int(vis_type.split('#')[1])
        plot_data = attn_maps[layer_to_inspect]
    else:
        raise ValueError(f'Unknown visualization type: {vis_type!r}')

    # Draw on a standalone Figure with the object-oriented API rather than
    # pyplot. pyplot keeps every figure it creates alive until explicitly
    # closed, so memory would grow with each request. Its "current figure"
    # state is also shared between concurrent Gradio requests.
    n_tokens = len(tokenized)
    fig = Figure(figsize=(0.3 + 0.35 * n_tokens, 0.3 * n_tokens))
    ax = fig.subplots()
    sns.heatmap(plot_data, ax=ax)
    # NOTE(review): assumes the re-tokenized text has exactly seq_len tokens;
    # otherwise matplotlib rejects the tick/label count mismatch.
    tick_positions = np.arange(seq_len - 1) + 0.5  # centre of each heatmap cell
    tick_labels = tokenized[1:]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels, rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()

    # The 'Metrics': 1 entry presumably acts as the header gr.Label displays
    # as its top label -- TODO confirm.
    metrics = {'Metrics': 1}
    metrics.update({k: v for k, v in row.items()
                    if k not in ['text', 'attention_maps', 'attention_maps_shape']})
    return fig, metrics
# Gradio UI: choose a sentence and a visualization, then press Run to render
# the attention heatmap and the row's metrics via analyze_sentence.
demo = gr.Blocks()

# Dropdown entries show only the text after the first '</s> ' marker of each
# dataset sentence.
sentence_choices = [x.split('</s> ')[1] for x in dataset['text']]

with demo:
    with gr.Row():
        # With type='index' Gradio passes the position of the chosen entry to
        # the handler. `value` must still be one of the choices, not an index.
        sentence_dropdown = gr.Dropdown(
            label="Sentence",
            choices=sentence_choices,
            value=sentence_choices[0] if sentence_choices else None,
            min_width=300,
            type='index',
        )
        # The default must be the string value listed in `choices`, not the
        # enum member itself.
        vis_dropdown = gr.Dropdown(
            label="Visualization",
            choices=[x.value for x in VisType] + [f'Layer #{i}' for i in range(num_layers)],
            min_width=100,
            value=VisType.SUM.value,
            type='value',
        )
        btn = gr.Button("Run", min_width=30)
    output = gr.Plot(label="Plot", container=True)
    metrics = gr.Label("Metrics")
    btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown], [output, metrics])

if __name__ == "__main__":
    demo.launch()