dar-tau committed on
Commit
43f6895
·
verified ·
1 Parent(s): 6ff782d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -5,9 +5,11 @@ from transformers import AutoTokenizer
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
 
 
8
  dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
9
  tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
10
 
 
11
  def analyze_sentence(index):
12
  row = dataset[index]
13
  text = row['text']
@@ -15,13 +17,14 @@ def analyze_sentence(index):
15
  attn_map_shape = row['attention_maps_shape'][1:]
16
  seq_len = attn_map_shape[1]
17
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)
18
- plot = sns.heatmap(attn_maps.sum(0)[1:, 1:])
 
19
  plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90);
20
  plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0);
21
  plt.ylabel('TARGET')
22
  plt.xlabel('SOURCE')
23
  plt.grid()
24
- return text, plot
25
 
26
 
27
  iface = gr.Interface(fn=analyze_sentence, inputs=[gr.Dropdown(choices=dataset['text'], value=0, type='index')],
 
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
 
8
+
9
# Load the precomputed attention maps for OPT-350m (grammar probe); only the
# 'train' split is used by this demo.
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
# NOTE(review): add_prefix_space presumably matches how the maps were
# tokenized when they were generated -- confirm against the dataset card.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
11
 
12
+
13
def analyze_sentence(index):
    """Render the summed attention heatmap for the dataset row at *index*.

    Returns:
        (text, fig): the sentence text and a matplotlib Figure containing a
        heatmap of the head-summed attention, with target tokens on the
        y-axis and source tokens on the x-axis.

    Creating an explicit Figure (instead of returning the seaborn Axes, as
    the pre-commit version did) is what lets Gradio display the plot.
    """
    row = dataset[index]
    text = row['text']
    # NOTE(review): this binding is elided in the visible diff hunk (new-file
    # line 16). `tokenized` must hold the token strings used as tick labels;
    # reconstructed as tokenizer.tokenize(text) -- confirm against the full file.
    tokenized = tokenizer.tokenize(text)
    # Drop the leading dimension of the stored shape; what remains is used as
    # (heads, seq, seq). NOTE(review): axis semantics inferred from the
    # sum(0) / [1:, 1:] usage below -- confirm.
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)
    fig = plt.figure()
    # Sum attention over heads; skip the first token on both axes (the [1:]
    # on the tick labels below matches this slice).
    sns.heatmap(attn_maps.sum(0)[1:, 1:])
    plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90)
    plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()
    return text, fig
28
 
29
 
30
  iface = gr.Interface(fn=analyze_sentence, inputs=[gr.Dropdown(choices=dataset['text'], value=0, type='index')],