dar-tau commited on
Commit
fac5648
·
verified ·
1 Parent(s): 9ddf875

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from datasets import load_dataset
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
@@ -8,7 +9,15 @@ dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
8
  def analyze_sentence(index):
9
  row = dataset[index]
10
  attn_maps = np.array(row['attention_maps']).reshape(*row['attention_maps_shape'])
11
- return row['text'], sns.heatmap(attn_maps.sum(1).sum(0))
 
 
 
 
 
 
 
 
12
 
13
 
14
  iface = gr.Interface(fn=analyze_sentence, inputs=[gr.Dropdown(choices=dataset['text'], type='index')],
 
1
  import gradio as gr
2
+ import numpy as np
3
  from datasets import load_dataset
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
 
9
  def analyze_sentence(index):
10
  row = dataset[index]
11
  attn_maps = np.array(row['attention_maps']).reshape(*row['attention_maps_shape'])
12
+ plot = sns.heatmap(attn_maps.sum(1).sum(0))
13
+ plt.xticks(np.arange(len(tokenized)-1) + 0.5,
14
+ tokenizer.tokenize(text, add_special_tokens=False), rotation=90);
15
+ plt.yticks(np.arange(len(tokenized)-1) + 0.5,
16
+ tokenizer.tokenize(text, add_special_tokens=False), rotation=0);
17
+ plt.ylabel('TARGET')
18
+ plt.xlabel('SOURCE')
19
+ plt.grid()
20
+ return row['text'], plot
21
 
22
 
23
  iface = gr.Interface(fn=analyze_sentence, inputs=[gr.Dropdown(choices=dataset['text'], type='index')],