import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
from enum import Enum


class VisType(Enum):
    ALL = 'ALL'


dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)


def analyze_sentence(index, vis_type):
    # Fetch the sentence and its precomputed attention maps from the dataset.
    row = dataset[index]
    text = row['text']
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))

    # Restore the flattened attention maps to their stored shape (drop the leading batch dim).
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)

    # Exclude the first token's row and column from the plot (its tick label is also skipped below).
    attn_maps = attn_maps[:, 1:, 1:]

    fig = plt.figure(figsize=(0.5 + 0.5 * len(tokenized), 0.4 * len(tokenized)))
    if vis_type == VisType.ALL:
        # Aggregate attention across all layers.
        plot_data = attn_maps.sum(0)
    else:
        raise ValueError(f'Unsupported visualization type: {vis_type}')

    sns.heatmap(plot_data)
    plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90)
    plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()

    # Report every remaining dataset field (anything that is not the text or the maps) as a metric.
    metrics = {k: v for k, v in row.items()
               if k not in ['text', 'attention_maps', 'attention_maps_shape']}
    return fig, metrics


demo = gr.Blocks()
with demo:
    with gr.Row():
        # The dropdown shows one word per sentence; with type='index' the selected
        # position is passed to analyze_sentence as the dataset index.
        sentence_dropdown = gr.Dropdown(label="Sentence",
                                        choices=[x.split(' ')[1] for x in dataset['text']],
                                        value=0, min_width=500, type='index')
        vis_dropdown = gr.Dropdown(label="Visualization", choices=list(VisType),
                                   value=VisType.ALL, min_width=150, type='value')
        btn = gr.Button("Run", min_width=30)
    output = gr.Plot(label="Plot", container=True)
    metrics = gr.Label(label="Metrics")
    btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown], [output, metrics])


if __name__ == "__main__":
    demo.launch()