Kseniia-Kholina committed on
Commit
968e490
·
verified ·
1 Parent(s): f9ddf41

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForMaskedLM
6
+
7
# Only let ERROR-and-above records through the transformers.modeling_utils logger.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the tokenizer and the masked-language model once at import time.
# NOTE(review): trust_remote_code=True is passed to both loaders; per the
# transformers documentation this permits code shipped in the Hub repository
# to run locally, so the repository must be trusted.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
# Move to the selected device and switch to inference mode.
model = model.to(device).eval()
17
+
18
def process_sequence(sequence, domain_bounds, n):
    """Predict the top-n tokens at every masked position and plot the logits.

    Each position of ``sequence`` is replaced by ``<mask>`` in turn, the
    masked sequence is run through the module-level ``model``, and the logits
    at the masked position are collected.

    Args:
        sequence: Sequence string; every character is treated as one
            position. Leading/trailing whitespace is stripped.
        domain_bounds: Object indexed as ``domain_bounds['start'][0]`` and
            ``domain_bounds['end'][0]`` (presumably the DataFrame produced by
            the ``gr.Dataframe`` input -- confirm against the caller). The
            bounds are 1-based and inclusive.
        n: Number of highest-logit tokens to report per position.

    Returns:
        ``(df, img)`` where ``df`` holds, for the rows inside the requested
        domain, the original residue, the top-n decoded tokens (highest logit
        first) and the 1-based position; ``img`` is a PIL image of a heatmap
        of the min-max normalised logits over all positions.

    Raises:
        ValueError: If ``sequence`` is empty, ``n`` is missing or < 1, or the
            masked position was lost to tokenizer truncation.
    """
    # Validate first so bad input produces a clear message instead of a
    # failure deep inside torch / numpy.
    if sequence is None or not sequence.strip():
        raise ValueError("Please provide a non-empty sequence.")
    sequence = sequence.strip()
    if n is None:
        raise ValueError("Please choose how many tokens to predict per position.")
    n = int(n)
    if n < 1:
        raise ValueError("The number of predicted tokens must be at least 1.")

    # Convert the 1-based inclusive bounds into a 0-based half-open slice.
    # Clamp at 0 so a start of 0 does not turn into a negative slice index.
    start_index = max(int(domain_bounds['start'][0]) - 1, 0)
    end_index = int(domain_bounds['end'][0])

    max_tokens = 2000
    # Vocabulary ids kept for the heatmap. The original code describes this
    # range as "non-amino acid tokens filtered out" -- assumed to be the 20
    # standard residues; TODO confirm against the tokenizer vocabulary.
    amino_acid_ids = list(range(4, 23 + 1))
    # Only the kept ids need decoding; this does not depend on the loop.
    filtered_tokens = [tokenizer.decode([idx]) for idx in amino_acid_ids]

    predictions = []
    all_logits = []

    for i in range(len(sequence)):
        masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
        inputs = tokenizer(
            masked_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_tokens,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits

        mask_token_index = torch.where(
            inputs["input_ids"] == tokenizer.mask_token_id
        )[1]
        if mask_token_index.numel() == 0:
            # truncation=True can cut the mask token off long inputs.
            raise ValueError(
                f"Position {i + 1} was truncated by the {max_tokens}-token "
                "limit; please submit a shorter sequence."
            )
        mask_token_logits = logits[0, mask_token_index, :]

        # Decode top n tokens
        top_n_tokens = torch.topk(mask_token_logits, n, dim=1).indices[0].tolist()
        predictions.append([tokenizer.decode([token]) for token in top_n_tokens])

        all_logits.append(mask_token_logits.cpu().numpy()[:, amino_acid_ids])

    # Min-max normalise over the whole matrix; guard the all-equal case so
    # the division cannot produce NaNs.
    all_logits_array = np.vstack(all_logits)
    lowest = all_logits_array.min()
    spread = all_logits_array.max() - lowest
    if spread > 0:
        normalized_logits_array = (all_logits_array - lowest) / spread
    else:
        normalized_logits_array = np.zeros_like(all_logits_array)
    transposed_logits_array = normalized_logits_array.T

    # Plotting the heatmap: one tick every `step` positions.
    step = 50
    tick_positions = np.arange(0, len(sequence), step)
    tick_labels = [str(pos) for pos in tick_positions]

    buf = BytesIO()
    fig = plt.figure(figsize=(15, 8))
    try:
        sns.heatmap(
            transposed_logits_array,
            cmap='plasma',
            xticklabels=tick_labels,
            yticklabels=filtered_tokens,
        )
        plt.title('Logits for masked per residue tokens')
        plt.ylabel('Token')
        plt.xlabel('Residue Index')
        plt.yticks(rotation=0)
        # Overrides the tick placement chosen by sns.heatmap above.
        plt.xticks(tick_positions, tick_labels, rotation=0)
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    # Convert the in-memory PNG to an image
    buf.seek(0)
    img = Image.open(buf)

    df = pd.DataFrame({
        'Original Residue': list(sequence),
        'Predicted Residues (in order of decreasing likelihood)': predictions,
        'Position': list(range(1, len(sequence) + 1)),
    })

    # Keep only the rows inside the requested domain.
    return df.iloc[start_index:end_index], img
93
+
94
# Input widgets, listed in the positional order of process_sequence's
# parameters (sequence, domain_bounds, n).
_domain_bounds_input = gr.Dataframe(
    headers=["start", "end"],
    datatype=["number", "number"],
    row_count=(1, "fixed"),
    col_count=(2, "fixed"),
)
# Dropdown with the integers 1 through 20.
_top_n_input = gr.Dropdown(list(range(1, 21)))

demo = gr.Interface(
    fn=process_sequence,
    inputs=["text", _domain_bounds_input, _top_n_input],
    outputs=["dataframe", "image"],
    description="Choose a number between 1-20 to predict n tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()