Kseniia-Kholina commited on
Commit
257341a
·
verified ·
1 Parent(s): a18e3ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import pandas as pd
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForMaskedLM
 
5
  import logging
6
  import numpy as np
7
  import matplotlib.pyplot as plt
@@ -52,15 +53,15 @@ def process_sequence(sequence, domain_bounds, n):
52
  filtered_tokens = [tokens[i] for i in filtered_indices]
53
 
54
  all_logits_array = np.vstack(all_logits)
55
- normalized_logits_array = (all_logits_array - all_logits_array.min()) / (all_logits_array.max() - all_logits_array.min())
56
  transposed_logits_array = normalized_logits_array.T
57
 
58
  # Plotting the heatmap
59
- y_tick_positions = np.arange((start_index+1), end_index + 1, 10)
60
- y_tick_labels = [str(pos) for pos in y_tick_positions]
61
 
62
  plt.figure(figsize=(15, 8))
63
- sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=y_tick_labels, yticklabels=filtered_tokens)
64
  plt.title('Logits for masked per residue tokens')
65
  plt.ylabel('Token')
66
  plt.xlabel('Residue Index')
 
2
  import pandas as pd
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForMaskedLM
5
+ import torch.nn.functional as F
6
  import logging
7
  import numpy as np
8
  import matplotlib.pyplot as plt
 
53
  filtered_tokens = [tokens[i] for i in filtered_indices]
54
 
55
  all_logits_array = np.vstack(all_logits)
56
+ normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
57
  transposed_logits_array = normalized_logits_array.T
58
 
59
  # Plotting the heatmap
60
+ x_tick_positions = np.arange(start_index, end_index)
61
+ x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
62
 
63
  plt.figure(figsize=(15, 8))
64
+ sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
65
  plt.title('Logits for masked per residue tokens')
66
  plt.ylabel('Token')
67
  plt.xlabel('Residue Index')