Kseniia-Kholina commited on
Commit
c578c17
·
verified ·
1 Parent(s): afdb4a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -43
app.py CHANGED
@@ -47,53 +47,53 @@ def process_sequence(sequence, domain_bounds, n):
47
  filtered_logits = logits_array[:, filtered_indices]
48
  all_logits.append(filtered_logits)
49
 
50
- token_indices = torch.arange(logits.size(-1))
51
- tokens = [tokenizer.decode([idx]) for idx in token_indices]
52
- filtered_tokens = [tokens[i] for i in filtered_indices]
53
-
54
- all_logits_array = np.vstack(all_logits)
55
- normalized_logits_array = (all_logits_array - all_logits_array.min()) / (all_logits_array.max() - all_logits_array.min())
56
- transposed_logits_array = normalized_logits_array.T
57
 
58
- # Plotting the heatmap
59
- step = 50
60
- y_tick_positions = np.arange(0, len(sequence), step)
61
- y_tick_labels = [str(pos) for pos in y_tick_positions]
62
-
63
- plt.figure(figsize=(15, 8))
64
- sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=y_tick_labels, yticklabels=filtered_tokens)
65
- plt.title('Logits for masked per residue tokens')
66
- plt.ylabel('Token')
67
- plt.xlabel('Residue Index')
68
- plt.yticks(rotation=0)
69
- plt.xticks(y_tick_positions, y_tick_labels, rotation = 0)
70
-
71
- # Save the figure to a BytesIO object
72
- buf = BytesIO()
73
- plt.savefig(buf, format='png')
74
- buf.seek(0)
75
- plt.close()
76
-
77
- # Convert BytesIO object to an image
78
- img = Image.open(buf)
79
 
80
- original_residues = []
81
- mutations = []
82
- positions = []
 
 
 
 
83
 
84
- for key, value in top_n_mutations.items():
85
- original_residue, position = key
86
- original_residues.append(original_residue)
87
- mutations.append(value)
88
- positions.append(position + 1)
89
-
90
- df = pd.DataFrame({
91
- 'Original Residue': original_residues,
92
- 'Predicted Residues (in order of decreasing likelihood)': mutations,
93
- 'Position': positions
94
- })
95
 
96
- df = df[start_index:end_index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  return df, img
99
 
 
47
  filtered_logits = logits_array[:, filtered_indices]
48
  all_logits.append(filtered_logits)
49
 
50
+ token_indices = torch.arange(logits.size(-1))
51
+ tokens = [tokenizer.decode([idx]) for idx in token_indices]
52
+ filtered_tokens = [tokens[i] for i in filtered_indices]
53
+
54
+ all_logits_array = np.vstack(all_logits)
55
+ normalized_logits_array = (all_logits_array - all_logits_array.min()) / (all_logits_array.max() - all_logits_array.min())
56
+ transposed_logits_array = normalized_logits_array.T
57
 
58
+ # Plotting the heatmap
59
+ step = 50
60
+ y_tick_positions = np.arange(0, len(sequence), step)
61
+ y_tick_labels = [str(pos) for pos in y_tick_positions]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ plt.figure(figsize=(15, 8))
64
+ sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=y_tick_labels, yticklabels=filtered_tokens)
65
+ plt.title('Logits for masked per residue tokens')
66
+ plt.ylabel('Token')
67
+ plt.xlabel('Residue Index')
68
+ plt.yticks(rotation=0)
69
+ plt.xticks(y_tick_positions, y_tick_labels, rotation = 0)
70
 
71
+ # Save the figure to a BytesIO object
72
+ buf = BytesIO()
73
+ plt.savefig(buf, format='png')
74
+ buf.seek(0)
75
+ plt.close()
 
 
 
 
 
 
76
 
77
+ # Convert BytesIO object to an image
78
+ img = Image.open(buf)
79
+
80
+ original_residues = []
81
+ mutations = []
82
+ positions = []
83
+
84
+ for key, value in top_n_mutations.items():
85
+ original_residue, position = key
86
+ original_residues.append(original_residue)
87
+ mutations.append(value)
88
+ positions.append(position + 1)
89
+
90
+ df = pd.DataFrame({
91
+ 'Original Residue': original_residues,
92
+ 'Predicted Residues (in order of decreasing likelihood)': mutations,
93
+ 'Position': positions
94
+ })
95
+
96
+ df = df[start_index:end_index]
97
 
98
  return df, img
99