dar-tau commited on
Commit
d762850
·
verified ·
1 Parent(s): 0577153

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -12,7 +12,7 @@ def analyze_sentence(index):
12
  attn_map_shape = row['attention_maps_shape'][1:]
13
  seq_len = attn_map_shape[1]
14
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)
15
- plot = sns.heatmap(attn_maps.sum(1).sum(0))
16
  plt.xticks(np.arange(seq_len - 1) + 0.5,
17
  tokenizer.tokenize(text, add_special_tokens=False), rotation=90);
18
  plt.yticks(np.arange(seq_len - 1) + 0.5,
@@ -20,7 +20,7 @@ def analyze_sentence(index):
20
  plt.ylabel('TARGET')
21
  plt.xlabel('SOURCE')
22
  plt.grid()
23
- return row['text'], plot
24
 
25
 
26
  iface = gr.Interface(fn=analyze_sentence, inputs=[gr.Dropdown(choices=dataset['text'], type='index')],
 
12
  attn_map_shape = row['attention_maps_shape'][1:]
13
  seq_len = attn_map_shape[1]
14
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)
15
+ plot = sns.heatmap(attn_maps.sum(0))
16
  plt.xticks(np.arange(seq_len - 1) + 0.5,
17
  tokenizer.tokenize(text, add_special_tokens=False), rotation=90);
18
  plt.yticks(np.arange(seq_len - 1) + 0.5,
 
20
  plt.ylabel('TARGET')
21
  plt.xlabel('SOURCE')
22
  plt.grid()
23
+ return text, plot
24
 
25
 
26
  iface = gr.Interface(fn=analyze_sentence, inputs=[gr.Dropdown(choices=dataset['text'], type='index')],