Kseniia-Kholina committed
Commit 21208ef · verified · 1 Parent(s): 7f7b413

Create app.py

Files changed (1)
  1. app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
+ # Gradio demo: scan a user-specified domain of a protein sequence with the
+ # FusOn-pLM masked language model and report the top-N predicted residues per
+ # position, two probability heatmaps, and a downloadable zip of the outputs.
+ import gradio as gr
+ import pandas as pd
+ import torch
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ import torch.nn.functional as F
+ import logging
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from io import BytesIO
+ from PIL import Image
+ from contextlib import contextmanager
+ import sys
+ import os
+ import zipfile
+
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
+ # Load the tokenizer and model
+ model_name = "ChatterjeeLab/FusOn-pLM"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
+ model.to(device)
+ model.eval()
+
+ @contextmanager
+ def suppress_output():
+     """Temporarily redirect stdout to os.devnull."""
+     with open(os.devnull, 'w') as devnull:
+         old_stdout = sys.stdout
+         sys.stdout = devnull
+         try:
+             yield
+         finally:
+             sys.stdout = old_stdout
+
+ def process_sequence(sequence, domain_bounds, n):
+     # The 20 standard amino acids and their token indices in the model vocabulary
+     AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
+     AAs_tokens_indices = {'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14,
+                           'K': 15, 'Q': 16, 'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23}
+     # checking sequence inputs
+     if not sequence.strip():
+         raise gr.Error("Error: The sequence input is empty. Please enter a valid protein sequence.")
+     if any(char not in AAs_tokens for char in sequence):
+         raise gr.Error("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")
+
+     # checking domain bounds inputs
+     try:
+         start = int(domain_bounds['start'][0])
+         end = int(domain_bounds['end'][0])
+     except ValueError:
+         raise gr.Error("Error: Start and end indices must be integers.")
+     if start >= end:
+         raise gr.Error("Start index must be smaller than end index.")
+     if start == 0 and end != 0:
+         raise gr.Error("Indexing starts at 1. Please enter valid domain bounds.")
+     if start <= 0 or end <= 0:
+         raise gr.Error("Domain bounds must be positive integers. Please enter valid domain bounds.")
+     if start > len(sequence) or end > len(sequence):
+         raise gr.Error("Domain bounds exceed sequence length.")
+
+     # checking top N tokens input
+     if n is None:
+         raise gr.Error("Choose Top N Tokens from the dropdown menu.")
+
+     start_index = int(domain_bounds['start'][0]) - 1
+     end_index = int(domain_bounds['end'][0])
+
+     top_n_mutations = {}
+     all_logits = []
+
+     # these 2 collections are for the 2nd heatmap
+     originals_logits = []
+     conservation_likelihoods = {}
+
+     for i in range(len(sequence)):
+         # only iterate through the residues inside the domain
+         if start_index <= i <= (end_index - 1):
+             original_residue = sequence[i]
+             original_residue_index = AAs_tokens_indices[original_residue]
+             masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
+             inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
+             inputs = {k: v.to(device) for k, v in inputs.items()}
+             with torch.no_grad():
+                 logits = model(**inputs).logits
+             mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
+             mask_token_logits = logits[0, mask_token_index, :]
+
+             # Pick top N tokens
+             all_tokens_logits = mask_token_logits.squeeze(0)
+             top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
+             mutation = []
+             # make sure we don't include non-AA tokens
+             for token_index in top_tokens_indices:
+                 decoded_token = tokenizer.decode([token_index.item()])
+                 # decode all tokens, pick the top n amino acid ones
+                 if decoded_token in AAs_tokens:
+                     mutation.append(decoded_token)
+                     if len(mutation) == n:
+                         break
+             top_n_mutations[(sequence[i], i)] = mutation
+
+             # collecting logits for the heatmap
+             logits_array = mask_token_logits.cpu().numpy()
+             # filter out non-amino acid tokens (vocabulary indices 4-23)
+             filtered_indices = list(range(4, 23 + 1))
+             filtered_logits = logits_array[:, filtered_indices]
+             all_logits.append(filtered_logits)
+
+             # code for the second heatmap: probability of the original residue
+             normalized_mask_token_logits = F.softmax(mask_token_logits, dim=-1).cpu().numpy()
+             normalized_mask_token_logits = np.squeeze(normalized_mask_token_logits)
+             originals_logit = normalized_mask_token_logits[original_residue_index]
+             originals_logits.append(originals_logit)
+
+             # flag positions where the original residue is assigned > 0.7 probability
+             if originals_logit > 0.7:
+                 conservation_likelihoods[(original_residue, i)] = 1
+             else:
+                 conservation_likelihoods[(original_residue, i)] = 0
+
+     # Plotting heatmap 2
+     domain_len = end - start
+     if 500 > domain_len > 100:
+         step_size = 49
+     elif 500 <= domain_len:
+         step_size = 99
+     elif domain_len < 10:
+         step_size = 1
+     else:
+         step_size = 9
+     x_tick_positions = np.arange(start_index, end_index, step_size)
+     x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
+
+     all_logits_array = np.vstack(originals_logits)
+     transposed_logits_array = all_logits_array.T
+     conservation_likelihoods_array = np.array(list(conservation_likelihoods.values())).reshape(1, -1)
+     # combine to make a 2D heatmap
+     combined_array = np.vstack((transposed_logits_array, conservation_likelihoods_array))
+
+     plt.figure(figsize=(15, 5))
+     plt.rcParams.update({'font.size': 16.5})
+     sns.heatmap(combined_array, cmap='viridis', xticklabels=x_tick_labels, yticklabels=['Residue \nLogits', 'Residue \nConservation'], cbar=True)
+     plt.title('Original Residue Probability and Conservation')
+     plt.xlabel('Residue Index')
+     # align tick positions with residue indices (same convention as heatmap 1 below)
+     plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
+     buf = BytesIO()
+     plt.savefig(buf, format='png', dpi=300)
+     buf.seek(0)
+     plt.close()
+     img_2 = Image.open(buf)
+
+     # plotting heatmap 1
+     token_indices = torch.arange(logits.size(-1))
+     tokens = [tokenizer.decode([idx]) for idx in token_indices]
+     filtered_tokens = [tokens[i] for i in filtered_indices]
+     all_logits_array = np.vstack(all_logits)
+     normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
+     transposed_logits_array = normalized_logits_array.T
+
+     plt.figure(figsize=(15, 8))
+     plt.rcParams.update({'font.size': 16.5})
+     sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
+     plt.title('Token Probability')
+     plt.ylabel('Amino Acid')
+     plt.xlabel('Residue Index')
+     plt.yticks(rotation=0)
+     plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
+
+     buf = BytesIO()
+     plt.savefig(buf, format='png', dpi=300)
+     buf.seek(0)
+     plt.close()
+
+     img_1 = Image.open(buf)
+
+     # store the predicted mutations in a dataframe
+     original_residues = []
+     mutations = []
+     positions = []
+
+     for key, value in top_n_mutations.items():
+         original_residue, position = key
+         original_residues.append(original_residue)
+         mutations.append(value)
+         positions.append(position + 1)
+
+     df = pd.DataFrame({
+         'Original Residue': original_residues,
+         'Predicted Residues': mutations,
+         'Position': positions
+     })
+     df.to_csv("predicted_tokens.csv", index=False)
+     img_1.save("heatmap.png", dpi=(300, 300))
+     img_2.save("heatmap_2.png", dpi=(300, 300))
+     zip_path = "outputs.zip"
+     with zipfile.ZipFile(zip_path, 'w') as zipf:
+         zipf.write("predicted_tokens.csv")
+         zipf.write("heatmap.png")
+         zipf.write("heatmap_2.png")
+
+     return df, img_1, img_2, zip_path
+
+ # launch the demo
+ demo = gr.Interface(
+     fn=process_sequence,
+     inputs=[
+         gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
+         gr.Dataframe(
+             value=[[1, 1]],
+             headers=["start", "end"],
+             datatype=["number", "number"],
+             row_count=(1, "fixed"),
+             col_count=(2, "fixed"),
+             label="Domain Bounds"
+         ),
+         gr.Dropdown(list(range(1, 21)), label="Top N Tokens"),
+     ],
+     outputs=[
+         gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
+         gr.Image(type="pil", label="Probability Distribution for All Tokens"),
+         gr.Image(type="pil", label="Residue Conservation"),
+         gr.File(label="Download Outputs"),
+     ],
+ )
+ if __name__ == "__main__":
+     with suppress_output():
+         demo.launch()