Update app.py
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
|
|
2 |
import pandas as pd
|
3 |
import torch
|
4 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
|
|
5 |
import logging
|
6 |
import numpy as np
|
7 |
import matplotlib.pyplot as plt
|
@@ -52,15 +53,15 @@ def process_sequence(sequence, domain_bounds, n):
|
|
52 |
filtered_tokens = [tokens[i] for i in filtered_indices]
|
53 |
|
54 |
all_logits_array = np.vstack(all_logits)
|
55 |
-
normalized_logits_array = (
|
56 |
transposed_logits_array = normalized_logits_array.T
|
57 |
|
58 |
# Plotting the heatmap
|
59 |
-
|
60 |
-
|
61 |
|
62 |
plt.figure(figsize=(15, 8))
|
63 |
-
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=
|
64 |
plt.title('Logits for masked per residue tokens')
|
65 |
plt.ylabel('Token')
|
66 |
plt.xlabel('Residue Index')
|
|
|
2 |
import pandas as pd
|
3 |
import torch
|
4 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
5 |
+
import torch.nn.functional as F
|
6 |
import logging
|
7 |
import numpy as np
|
8 |
import matplotlib.pyplot as plt
|
|
|
53 |
filtered_tokens = [tokens[i] for i in filtered_indices]
|
54 |
|
55 |
all_logits_array = np.vstack(all_logits)
|
56 |
+
normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
|
57 |
transposed_logits_array = normalized_logits_array.T
|
58 |
|
59 |
# Plotting the heatmap
|
60 |
+
x_tick_positions = np.arange(start_index, end_index)
|
61 |
+
x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
|
62 |
|
63 |
plt.figure(figsize=(15, 8))
|
64 |
+
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
|
65 |
plt.title('Logits for masked per residue tokens')
|
66 |
plt.ylabel('Token')
|
67 |
plt.xlabel('Residue Index')
|