import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import logging
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image
from contextlib import contextmanager
import sys
import os
import zipfile

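# quiet transformers' weight-initialization warnings and pick GPU if available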
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
model.eval()

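# context manager that temporarily redirects stdout to /dev/null;
# used at launch time to hide Gradio's startup banner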
@contextmanager
def suppress_output():
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

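# Core pipeline: mask each residue inside the user-specified domain, run the
# masked language model, and return (1) a table of the top-N predicted residues
# per position, (2) a per-position amino-acid probability heatmap, (3) an
# original-residue conservation heatmap, and (4) a zip bundling all outputs.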
def process_sequence(sequence, domain_bounds, n):
    AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
    AAs_tokens_indices = {'L' : 4, 'A' : 5, 'G' : 6, 'V': 7, 'S' : 8, 'E' : 9, 'R' : 10, 'T' : 11, 'I': 12, 'D' : 13, 'P' : 14,
                          'K' : 15, 'Q' : 16, 'N' : 17, 'F' : 18, 'Y' : 19, 'M' : 20, 'H' : 21, 'W' : 22, 'C' : 23}
    # validate the sequence input
    if not sequence.strip():
        raise gr.Error("Error: The sequence input is empty. Please enter a valid protein sequence.")
    if any(char not in AAs_tokens for char in sequence):
        raise gr.Error("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")

    # validate the domain bounds
    try:
        start = int(domain_bounds['start'][0])
        end = int(domain_bounds['end'][0])
    except ValueError:
        raise gr.Error("Error: Start and end indices must be integers.")
    if start >= end:
        raise gr.Error("Start index must be smaller than end index.")
    if start == 0 and end != 0:
        raise gr.Error("Indexing starts at 1. Please enter valid domain bounds.")
    if start <= 0 or end <= 0:
        raise gr.Error("Domain bounds must be positive integers. Please enter valid domain bounds.")
    if start > len(sequence) or end > len(sequence):
        raise gr.Error("Domain bounds exceed sequence length.")

    # validate the Top N Tokens input
    if n is None:
        raise gr.Error("Choose Top N Tokens from the dropdown menu.")

    # convert the 1-based domain bounds to 0-based slice indices
    start_index = start - 1
    end_index = end

    top_n_mutations = {}
    all_logits = []

    # these 2 lists are for the 2nd heatmap
    originals_logits = []
    conservation_likelihoods = {}

    # iterate only over the residues inside the domain
    for i in range(start_index, end_index):
        original_residue = sequence[i]
        original_residue_index = AAs_tokens_indices[original_residue]

        # mask the current residue and let the model predict it from context
        masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = logits[0, mask_token_index, :]

        # pick the top N tokens, skipping non-amino-acid (special) tokens
        all_tokens_logits = mask_token_logits.squeeze(0)
        top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
        mutation = []
        for token_index in top_tokens_indices:
            decoded_token = tokenizer.decode([token_index.item()])
            if decoded_token in AAs_tokens:
                mutation.append(decoded_token)
                if len(mutation) == n:
                    break
        top_n_mutations[(sequence[i], i)] = mutation

        # collect logits for the first heatmap, keeping only the amino-acid columns
        logits_array = mask_token_logits.cpu().numpy()
        filtered_indices = list(range(4, 23 + 1))  # vocabulary ids 4-23 are the 20 amino acids
        filtered_logits = logits_array[:, filtered_indices]
        all_logits.append(filtered_logits)

        # for the second heatmap: probability the model assigns to the original residue
        normalized_mask_token_logits = F.softmax(mask_token_logits, dim=-1).cpu().numpy()
        normalized_mask_token_logits = np.squeeze(normalized_mask_token_logits)
        originals_logit = normalized_mask_token_logits[original_residue_index]
        originals_logits.append(originals_logit)

        # call a residue "conserved" if its own probability exceeds 0.7
        if originals_logit > 0.7:
            conservation_likelihoods[(original_residue, i)] = 1
        else:
            conservation_likelihoods[(original_residue, i)] = 0

    # Plotting heatmap 2: pick an x-tick step size that scales with domain length
    domain_len = end - start
    if domain_len < 10:
        step_size = 1
    elif domain_len <= 100:
        step_size = 10
    elif domain_len < 500:
        step_size = 50
    else:
        step_size = 100
    x_tick_positions = np.arange(start_index, end_index, step_size)
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]

    originals_logits_array = np.vstack(originals_logits)
    transposed_originals_array = originals_logits_array.T
    conservation_likelihoods_array = np.array(list(conservation_likelihoods.values())).reshape(1, -1)
    # stack per-residue probabilities and binary conservation calls into a 2-row heatmap
    combined_array = np.vstack((transposed_originals_array, conservation_likelihoods_array))

    plt.figure(figsize=(15, 5))
    plt.rcParams.update({'font.size': 16.5})
    sns.heatmap(combined_array, cmap='viridis', xticklabels=x_tick_labels, yticklabels=['Residue \nLogits', 'Residue \nConservation'], cbar=True)
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
    plt.title('Original Residue Probability and Conservation')
    plt.xlabel('Residue Index')
    buf = BytesIO()
    plt.savefig(buf, format='png', dpi=300)
    buf.seek(0)
    plt.close()
    img_2 = Image.open(buf)


    # Plotting heatmap 1: probability distribution over the 20 amino acids per position
    # `logits` still holds the model output from the last loop iteration; only its vocab size is used
    token_indices = torch.arange(logits.size(-1))
    tokens = [tokenizer.decode([idx]) for idx in token_indices]
    filtered_tokens = [tokens[i] for i in filtered_indices]
    all_logits_array = np.vstack(all_logits)
    # softmax across the amino-acid columns so each position sums to 1
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    transposed_logits_array = normalized_logits_array.T

    plt.figure(figsize=(15, 8))
    plt.rcParams.update({'font.size': 16.5})
    sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
    plt.title('Token Probability')
    plt.ylabel('Amino Acid')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)

    buf = BytesIO()
    plt.savefig(buf, format='png', dpi=300)
    buf.seek(0)
    plt.close()
    img_1 = Image.open(buf)

    # store the predicted mutations in a dataframe
    original_residues = []
    mutations = []
    positions = []

    for key, value in top_n_mutations.items():
        original_residue, position = key
        original_residues.append(original_residue)
        mutations.append(value)
        positions.append(position + 1)

    df = pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues': mutations,
        'Position': positions
    })
    df.to_csv("predicted_tokens.csv", index=False)
    img_1.save("heatmap.png", dpi=(300, 300))
    img_2.save("heatmap_2.png", dpi=(300, 300))
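    # bundle the CSV and both heatmaps into one zip for the download button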
    zip_path = "outputs.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        zipf.write("predicted_tokens.csv")
        zipf.write("heatmap.png")
        zipf.write("heatmap_2.png")

    return df, img_1, img_2, zip_path

# launch the demo
demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
        gr.Dataframe(
            value = [[1, 1]],
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
            label="Domain Bounds"
        ),
        gr.Dropdown(list(range(1, 21)), label="Top N Tokens"),
    ],
    outputs=[
        gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
        gr.Image(type="pil", label="Probability Distribution for All Tokens"),
        gr.Image(type="pil", label="Residue Conservation"),
        gr.File(label="Download Outputs"),
    ],
)
if __name__ == "__main__":
    with suppress_output():
        demo.launch()