Update app.py
Browse files
app.py
CHANGED
@@ -47,53 +47,53 @@ def process_sequence(sequence, domain_bounds, n):
|
|
47 |
filtered_logits = logits_array[:, filtered_indices]
|
48 |
all_logits.append(filtered_logits)
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
plt.figure(figsize=(15, 8))
|
64 |
-
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=y_tick_labels, yticklabels=filtered_tokens)
|
65 |
-
plt.title('Logits for masked per residue tokens')
|
66 |
-
plt.ylabel('Token')
|
67 |
-
plt.xlabel('Residue Index')
|
68 |
-
plt.yticks(rotation=0)
|
69 |
-
plt.xticks(y_tick_positions, y_tick_labels, rotation = 0)
|
70 |
-
|
71 |
-
# Save the figure to a BytesIO object
|
72 |
-
buf = BytesIO()
|
73 |
-
plt.savefig(buf, format='png')
|
74 |
-
buf.seek(0)
|
75 |
-
plt.close()
|
76 |
-
|
77 |
-
# Convert BytesIO object to an image
|
78 |
-
img = Image.open(buf)
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
df = pd.DataFrame({
|
91 |
-
'Original Residue': original_residues,
|
92 |
-
'Predicted Residues (in order of decreasing likelihood)': mutations,
|
93 |
-
'Position': positions
|
94 |
-
})
|
95 |
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
return df, img
|
99 |
|
|
|
47 |
filtered_logits = logits_array[:, filtered_indices]
|
48 |
all_logits.append(filtered_logits)
|
49 |
|
50 |
+
token_indices = torch.arange(logits.size(-1))
|
51 |
+
tokens = [tokenizer.decode([idx]) for idx in token_indices]
|
52 |
+
filtered_tokens = [tokens[i] for i in filtered_indices]
|
53 |
+
|
54 |
+
all_logits_array = np.vstack(all_logits)
|
55 |
+
normalized_logits_array = (all_logits_array - all_logits_array.min()) / (all_logits_array.max() - all_logits_array.min())
|
56 |
+
transposed_logits_array = normalized_logits_array.T
|
57 |
|
58 |
+
# Plotting the heatmap
|
59 |
+
step = 50
|
60 |
+
y_tick_positions = np.arange(0, len(sequence), step)
|
61 |
+
y_tick_labels = [str(pos) for pos in y_tick_positions]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
plt.figure(figsize=(15, 8))
|
64 |
+
sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=y_tick_labels, yticklabels=filtered_tokens)
|
65 |
+
plt.title('Logits for masked per residue tokens')
|
66 |
+
plt.ylabel('Token')
|
67 |
+
plt.xlabel('Residue Index')
|
68 |
+
plt.yticks(rotation=0)
|
69 |
+
plt.xticks(y_tick_positions, y_tick_labels, rotation = 0)
|
70 |
|
71 |
+
# Save the figure to a BytesIO object
|
72 |
+
buf = BytesIO()
|
73 |
+
plt.savefig(buf, format='png')
|
74 |
+
buf.seek(0)
|
75 |
+
plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
+
# Convert BytesIO object to an image
|
78 |
+
img = Image.open(buf)
|
79 |
+
|
80 |
+
original_residues = []
|
81 |
+
mutations = []
|
82 |
+
positions = []
|
83 |
+
|
84 |
+
for key, value in top_n_mutations.items():
|
85 |
+
original_residue, position = key
|
86 |
+
original_residues.append(original_residue)
|
87 |
+
mutations.append(value)
|
88 |
+
positions.append(position + 1)
|
89 |
+
|
90 |
+
df = pd.DataFrame({
|
91 |
+
'Original Residue': original_residues,
|
92 |
+
'Predicted Residues (in order of decreasing likelihood)': mutations,
|
93 |
+
'Position': positions
|
94 |
+
})
|
95 |
+
|
96 |
+
df = df[start_index:end_index]
|
97 |
|
98 |
return df, img
|
99 |
|