dar-tau committed on
Commit b3d4e85 · verified · 1 Parent(s): 43f6895

Update app.py

Files changed (1): app.py (+4, -4)
app.py CHANGED
@@ -13,11 +13,11 @@ tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=
 def analyze_sentence(index):
     row = dataset[index]
     text = row['text']
-    tokenized = tokenizer.tokenize(text, add_special_tokens=False)
+    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
     attn_map_shape = row['attention_maps_shape'][1:]
-    seq_len = attn_map_shape[1]
-    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)
-    fig = plt.figure()
+    seq_len = attn_map_shape[1]
+    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
+    fig = plt.figure(figsize=(8, 8))
     sns.heatmap(attn_maps.sum(0)[1:, 1:])
     plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90);
     plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0);
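
The substantive fix is the label source: `tokenizer.tokenize` on OPT's GPT-2-style byte-level BPE returns raw pieces carrying encoding markers (such as `Ġ` for a leading space), whereas encoding the text and then decoding each id separately yields plain, readable strings for the heatmap ticks. A minimal sketch of the difference, assuming only the model name from the hunk header (the sample sentence is made up):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)

text = "The quick brown fox"  # hypothetical sample sentence

# Old approach: raw byte-level BPE pieces with encoding markers,
# e.g. ['ĠThe', 'Ġquick', 'Ġbrown', 'Ġfox']
print(tokenizer.tokenize(text, add_special_tokens=False))

# New approach: decode each id on its own, giving plain strings,
# e.g. [' The', ' quick', ' brown', ' fox']
ids = tokenizer.encode(text, add_special_tokens=False)
print(tokenizer.batch_decode(ids))

Decoding per id also keeps the list the same length as the id sequence, so the labels stay aligned one-to-one with the attention-map axes.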
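The other two changes are smaller: `.clip(0, 1)` bounds the stored attention values before the sum over the first axis, guarding the heatmap color scale against any out-of-range values in the serialized maps, and `figsize=(8, 8)` enlarges the figure so the rotated token labels stay legible. A self-contained sketch of the plotting step with synthetic data (the shapes and token labels are hypothetical stand-ins for the dataset row):

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Hypothetical stand-in for row['attention_maps']: (n_maps, seq_len, seq_len),
# each row normalized to sum to 1 like real attention weights.
rng = np.random.default_rng(0)
n_maps, seq_len = 4, 6
attn_maps = rng.random((n_maps, seq_len, seq_len))
attn_maps /= attn_maps.sum(-1, keepdims=True)

tokenized = ['</s>', 'The', ' quick', ' brown', ' fox', ' jumps']  # hypothetical labels

fig = plt.figure(figsize=(8, 8))
# Clip to [0, 1] before summing over maps; drop the first row/column as the app does.
sns.heatmap(attn_maps.clip(0, 1).sum(0)[1:, 1:])
plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90)
plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0)
plt.show()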