import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import logging
import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image

logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
model.eval()
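
# Note (assumption): FusOn-pLM appears to ship an ESM-2-style tokenizer, so the
# literal '<mask>' string spliced into sequences below should match
# tokenizer.mask_token; if unsure, verify with:
#   assert tokenizer.mask_token == '<mask>'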

def process_sequence(sequence, domain_bounds, n):
    # Convert the 1-based, inclusive bounds from the UI into a 0-based,
    # half-open [start_index, end_index) range
    start_index = int(domain_bounds['start'][0]) - 1
    end_index = int(domain_bounds['end'][0])

    top_n_mutations = {}
    all_logits = []

    # Token ids 4-23 in the ESM-style vocabulary are the 20 standard amino
    # acids; everything else (special tokens, non-standard residues) is
    # excluded from the heatmap. Defined before the loop because it is also
    # used after it.
    filtered_indices = list(range(4, 23 + 1))

    # Mask one residue at a time within the domain and predict it
    for i in range(start_index, min(end_index, len(sequence))):
        masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = logits[0, mask_token_index, :]
        # Decode the top n most likely tokens at the masked position
        top_n_tokens = torch.topk(mask_token_logits, n, dim=1).indices[0].tolist()
        mutation = [tokenizer.decode([token]) for token in top_n_tokens]
        top_n_mutations[(sequence[i], i)] = mutation

        # Keep only the amino-acid columns of the logits for the heatmap
        logits_array = mask_token_logits.cpu().numpy()
        filtered_logits = logits_array[:, filtered_indices]
        all_logits.append(filtered_logits)
    
    # Row labels for the heatmap: the amino-acid tokens themselves. Decoding
    # filtered_indices directly avoids depending on the loop's last `logits`.
    filtered_tokens = [tokenizer.decode([idx]) for idx in filtered_indices]

    # Stack the per-position logits and softmax them so each column is a
    # comparable probability distribution over residues
    all_logits_array = np.vstack(all_logits)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    transposed_logits_array = normalized_logits_array.T

    # Plot the heatmap, labeling every tenth residue with its 1-based index
    x_tick_positions = np.arange(start_index, end_index, 10)
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]

    plt.figure(figsize=(15, 8))
    # xticklabels=False here: the tick-label list is shorter than the number
    # of columns, so the labels are applied below via plt.xticks instead
    sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=False, yticklabels=filtered_tokens)
    plt.title('Predicted residue probabilities at masked positions')
    plt.ylabel('Token')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
    
    # Save the figure to a BytesIO object
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close()
    
    # Convert BytesIO object to an image
    img = Image.open(buf)

    original_residues = []
    mutations = []
    positions = []

    for key, value in top_n_mutations.items():
        original_residue, position = key
        original_residues.append(original_residue)
        mutations.append(value)
        positions.append(position + 1)

    df = pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues': mutations,
        'Position': positions
    })
            
    return df, img
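
# Minimal sketch of calling process_sequence directly, outside the Gradio UI.
# The sequence and bounds below are illustrative placeholders, not real data:
#   bounds = pd.DataFrame({"start": [2], "end": [8]})
#   df, img = process_sequence("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", bounds, n=3)
#   print(df)
#   img.save("heatmap.png")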

demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
        gr.Dataframe(
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
            label="Domain Bounds"
        ),
        gr.Dropdown(list(range(1, 21)), label="Top N Tokens"),
    ],
    outputs=[
        gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
        gr.Image(type="pil", label="Heatmap"),
    ],
    description="Choose a number N from the dropdown to predict the top N tokens at each position. Enter the start and end indices of the domain of interest (1-based, both bounds inclusive).",
)

if __name__ == "__main__":
    demo.launch()