import logging
import zipfile
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoTokenizer, AutoModelForMaskedLM

logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
model.eval()
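
# Note: the amino-acid filtering below assumes an ESM-style vocabulary in which
# ids 4-23 decode to the 20 standard residues. An optional sanity check:
#   print([tokenizer.decode([i]) for i in range(4, 24)])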


def process_sequence(sequence, domain_bounds, n):
    # Domain bounds arrive as a one-row dataframe of 1-based positions;
    # convert to a 0-based start (inclusive) and end (exclusive) index.
    start_index = int(domain_bounds['start'][0]) - 1
    end_index = int(domain_bounds['end'][0])

    # The 20 standard amino acids; predictions are restricted to these tokens.
    AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
                  'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
    # Vocabulary ids 4-23 hold the amino acid tokens (see the note above);
    # special tokens fall outside this range and are excluded from the heatmap.
    filtered_indices = list(range(4, 24))

    top_n_mutations = {}
    all_logits = []

    for i in range(start_index, min(end_index, len(sequence))):
        # Mask one residue at a time and let the model predict it.
        masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True,
                           truncation=True, max_length=2000)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = logits[0, mask_token_index, :]

        # Rank every token by logit, then keep the top n amino acid tokens,
        # skipping special and other non-amino-acid tokens.
        all_tokens_logits = mask_token_logits.squeeze(0)
        top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
        mutation = []
        for token_index in top_tokens_indices:
            decoded_token = tokenizer.decode([token_index.item()])
            if decoded_token in AAs_tokens:
                mutation.append(decoded_token)
                if len(mutation) == n:
                    break
        top_n_mutations[(sequence[i], i)] = mutation

        # Collect the amino acid logits at this position for the heatmap.
        logits_array = mask_token_logits.cpu().numpy()
        all_logits.append(logits_array[:, filtered_indices])

    # Decode the amino acid labels for the heatmap's y-axis.
    filtered_tokens = [tokenizer.decode([idx]) for idx in filtered_indices]

    # Stack the per-position logits and normalize each row into probabilities.
    all_logits_array = np.vstack(all_logits)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    transposed_logits_array = normalized_logits_array.T

    # Plot the heatmap: amino acid tokens on y, residue positions on x.
    x_tick_positions = np.arange(start_index, end_index, 10)
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]

    plt.figure(figsize=(15, 8))
    plt.rcParams.update({'font.size': 18})

    sns.heatmap(transposed_logits_array, cmap='plasma',
                xticklabels=False, yticklabels=filtered_tokens)
    plt.title('Token Probability Heatmap')
    plt.ylabel('Token')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    # Label every 10th residue with its 1-based position in the full sequence.
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)

    # Save the figure into an in-memory buffer, then reopen it as a PIL image.
    buf = BytesIO()
    plt.savefig(buf, format='png', dpi=300)
    buf.seek(0)
    plt.close()

    img = Image.open(buf)

    original_residues = []
    mutations = []
    positions = []

    for key, value in top_n_mutations.items():
        original_residue, position = key
        original_residues.append(original_residue)
        mutations.append(value)
        positions.append(position + 1)

    df = pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues': mutations,
        'Position': positions
    })
    df.to_csv("predicted_tokens.csv", index=False)
    # PIL expects dpi as an (x, y) tuple, unlike matplotlib's scalar dpi.
    img.save("heatmap.png", dpi=(300, 300))

    # Bundle both outputs into a single downloadable archive.
    zip_path = "outputs.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        zipf.write("predicted_tokens.csv")
        zipf.write("heatmap.png")

    return df, img, zip_path
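
# A minimal sketch of calling process_sequence outside the UI (the sequence
# below is a hypothetical placeholder; real inputs are full protein sequences):
#   bounds = pd.DataFrame({"start": [2], "end": [8]})
#   df, heatmap, archive = process_sequence("MKTAYIAKQRQISFVKSHFSRQ", bounds, 3)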

demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
        gr.Dataframe(
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
            label="Domain Bounds"
        ),
        gr.Dropdown(list(range(1, 21)), label="Top N Tokens"),
    ],
    outputs=[
        gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
        gr.Image(type="pil", label="Heatmap"),
        gr.File(label="Download Outputs"),
    ],
)
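# For long sequences the per-residue masking loop can take a while; calling
# demo.queue() before launch may help avoid request timeouts (optional).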
if __name__ == "__main__":
    demo.launch()