Kseniia-Kholina committed on
Commit 84afffe · verified · 1 Parent(s): f3096bf

Delete app_all_seq.py

Files changed (1)
  1. app_all_seq.py +0 -116
app_all_seq.py DELETED
@@ -1,116 +0,0 @@
- import gradio as gr
- import pandas as pd
- import torch
- from transformers import AutoTokenizer, AutoModelForMaskedLM
- import logging
- import numpy as np
- import matplotlib.pyplot as plt
- import seaborn as sns
- from io import BytesIO
- from PIL import Image
-
- logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- print(f"Using device: {device}")
-
- # Load the tokenizer and model
- model_name = "ChatterjeeLab/FusOn-pLM"
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
- model.to(device)
- model.eval()
-
- def process_sequence(sequence, domain_bounds, n):
-     start_index = int(domain_bounds['start'][0]) - 1
-     end_index = int(domain_bounds['end'][0])
-
-     top_n_mutations = {}
-     all_logits = []
-
-     for i in range(len(sequence)):
-         masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
-         inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
-         inputs = {k: v.to(device) for k, v in inputs.items()}
-         with torch.no_grad():
-             logits = model(**inputs).logits
-         mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
-         mask_token_logits = logits[0, mask_token_index, :]
-         # Decode top n tokens
-         top_n_tokens = torch.topk(mask_token_logits, n, dim=1).indices[0].tolist()
-         mutation = [tokenizer.decode([token]) for token in top_n_tokens]
-         top_n_mutations[(sequence[i], i)] = mutation
-
-         logits_array = mask_token_logits.cpu().numpy()
-         # filter out non-amino acid tokens
-         filtered_indices = list(range(4, 23 + 1))
-         filtered_logits = logits_array[:, filtered_indices]
-         all_logits.append(filtered_logits)
-
-     token_indices = torch.arange(logits.size(-1))
-     tokens = [tokenizer.decode([idx]) for idx in token_indices]
-     filtered_tokens = [tokens[i] for i in filtered_indices]
-
-     all_logits_array = np.vstack(all_logits)
-     normalized_logits_array = (all_logits_array - all_logits_array.min()) / (all_logits_array.max() - all_logits_array.min())
-     transposed_logits_array = normalized_logits_array.T
-
-     # Plotting the heatmap
-     step = 50
-     y_tick_positions = np.arange(0, len(sequence), step)
-     y_tick_labels = [str(pos) for pos in y_tick_positions]
-
-     plt.figure(figsize=(15, 8))
-     sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=y_tick_labels, yticklabels=filtered_tokens)
-     plt.title('Logits for masked per residue tokens')
-     plt.ylabel('Token')
-     plt.xlabel('Residue Index')
-     plt.yticks(rotation=0)
-     plt.xticks(y_tick_positions, y_tick_labels, rotation=0)
-
-     # Save the figure to a BytesIO object
-     buf = BytesIO()
-     plt.savefig(buf, format='png')
-     buf.seek(0)
-     plt.close()
-
-     # Convert BytesIO object to an image
-     img = Image.open(buf)
-
-     original_residues = []
-     mutations = []
-     positions = []
-
-     for key, value in top_n_mutations.items():
-         original_residue, position = key
-         original_residues.append(original_residue)
-         mutations.append(value)
-         positions.append(position + 1)
-
-     df = pd.DataFrame({
-         'Original Residue': original_residues,
-         'Predicted Residues (in order of decreasing likelihood)': mutations,
-         'Position': positions
-     })
-
-     df = df[start_index:end_index]
-
-     return df, img
-
- demo = gr.Interface(
-     fn=process_sequence,
-     inputs=[
-         "text",
-         gr.Dataframe(
-             headers=["start", "end"],
-             datatype=["number", "number"],
-             row_count=(1, "fixed"),
-             col_count=(2, "fixed"),
-         ),
-         gr.Dropdown([i for i in range(1, 21)]),  # Dropdown with numbers from 1 to 20 as integers
-     ],
-     outputs=["dataframe", "image"],
-     description="Choose a number between 1-20 to predict n tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
- )
-
- if __name__ == "__main__":
-     demo.launch()