Kseniia-Kholina committed on
Commit f3096bf · verified · 1 Parent(s): e7999a2

Delete app_latest.py

Files changed (1)
  1. app_latest.py +0 -134
app_latest.py DELETED
@@ -1,134 +0,0 @@
-import gradio as gr
-import pandas as pd
-import torch
-from transformers import AutoTokenizer, AutoModelForMaskedLM
-import torch.nn.functional as F
-import logging
-import numpy as np
-import matplotlib.pyplot as plt
-import seaborn as sns
-from io import BytesIO
-from PIL import Image
-
-logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
-
-# Load the tokenizer and model
-model_name = "ChatterjeeLab/FusOn-pLM"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
-model.to(device)
-model.eval()
-
-def process_sequence(sequence, domain_bounds, n):
-    start_index = int(domain_bounds['start'][0]) - 1
-    end_index = int(domain_bounds['end'][0])
-
-    top_n_mutations = {}
-    all_logits = []
-
-    for i in range(len(sequence)):
-        if start_index <= i <= (end_index - 1):
-            masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
-            inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
-            inputs = {k: v.to(device) for k, v in inputs.items()}
-            with torch.no_grad():
-                logits = model(**inputs).logits
-            mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
-            mask_token_logits = logits[0, mask_token_index, :]
-
-            # Define amino acid tokens
-            AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
-            all_tokens_logits = mask_token_logits.squeeze(0)
-            top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
-            top_tokens_logits = all_tokens_logits[top_tokens_indices]
-            mutation = []
-            # make sure we don't include non-AA tokens
-            for token_index in top_tokens_indices:
-                decoded_token = tokenizer.decode([token_index.item()])
-                if decoded_token in AAs_tokens:
-                    mutation.append(decoded_token)
-                if len(mutation) == n:
-                    break
-            top_n_mutations[(sequence[i], i)] = mutation
-
-            # collecting logits for the heatmap
-            logits_array = mask_token_logits.cpu().numpy()
-            # filter out non-amino acid tokens
-            filtered_indices = list(range(4, 23 + 1))
-            filtered_logits = logits_array[:, filtered_indices]
-            all_logits.append(filtered_logits)
-
-
-    token_indices = torch.arange(logits.size(-1))
-    tokens = [tokenizer.decode([idx]) for idx in token_indices]
-    filtered_tokens = [tokens[i] for i in filtered_indices]
-
-    all_logits_array = np.vstack(all_logits)
-    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
-    transposed_logits_array = normalized_logits_array.T
-
-    # Plotting the heatmap
-    x_tick_positions = np.arange(start_index, end_index, 10)
-    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
-
-    plt.figure(figsize=(15, 8))
-    plt.rcParams.update({'font.size': 18})
-
-    sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
-    plt.title('Token Probability Heatmap')
-    plt.ylabel('Token')
-    plt.xlabel('Residue Index')
-    plt.yticks(rotation=0)
-    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
-
-    # Save the figure to a BytesIO object
-    buf = BytesIO()
-    plt.savefig(buf, format='png', dpi=300)
-    buf.seek(0)
-    plt.close()
-
-    # Convert BytesIO object to an image
-    img = Image.open(buf)
-
-    original_residues = []
-    mutations = []
-    positions = []
-
-    for key, value in top_n_mutations.items():
-        original_residue, position = key
-        original_residues.append(original_residue)
-        mutations.append(value)
-        positions.append(position + 1)
-
-    df = pd.DataFrame({
-        'Original Residue': original_residues,
-        'Predicted Residues': mutations,
-        'Position': positions
-    })
-
-    return df, img
-
-demo = gr.Interface(
-    fn=process_sequence,
-    inputs=[
-        gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
-        gr.Dataframe(
-            headers=["start", "end"],
-            datatype=["number", "number"],
-            row_count=(1, "fixed"),
-            col_count=(2, "fixed"),
-            label="Domain Bounds"
-        ),
-        gr.Dropdown([i for i in range(1, 21)], label="Top N Tokens"),
-    ],
-    outputs=[
-        gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
-        gr.Image(type="pil", label="Heatmap"),
-    ],
-    description="Choose a number from the dropdown to predict N tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
-)
-
-if __name__ == "__main__":
-    demo.launch()
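
For reference, a minimal sketch of how the deleted process_sequence function could be driven without the Gradio UI, assuming the file above is still importable. The sequence, domain bounds, and top-N value are illustrative assumptions, not values taken from this repository.

# Illustrative sketch only: the example inputs below are assumptions, not repository data.
import pandas as pd

example_sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"        # hypothetical protein sequence
domain_bounds = pd.DataFrame({"start": [5], "end": [20]})    # 1-indexed, inclusive bounds
top_n = 3                                                    # number of predicted residues per position

# Returns a DataFrame of top-N predicted residues per masked position and a PIL heatmap image.
df, heatmap = process_sequence(example_sequence, domain_bounds, top_n)
print(df.head())
heatmap.save("heatmap.png")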