File size: 4,136 Bytes
968e490
 
 
 
 
3f8bb76
 
 
 
 
968e490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image

# Let only ERROR-level records from transformers' model-loading module through.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Prefer a CUDA device when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Fetch the FusOn-pLM tokenizer and masked-LM head from the Hugging Face Hub.
# trust_remote_code=True allows transformers to execute code shipped in the
# model repository.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True).to(device)
model.eval()  # evaluation mode: inference only, no training-time behaviour

def _masked_residue_predictions(sequence, n, token_ids):
    """Mask each residue in turn and score the masked slot with the model.

    Returns ``(predictions, logit_matrix)``. ``predictions[i]`` is the list of
    the ``n`` highest-scoring decoded tokens for position ``i``. ``logit_matrix``
    is a ``(len(sequence), len(token_ids))`` array holding the raw logits of
    ``token_ids`` at each masked position.

    Raises:
        ValueError: if the mask token cannot be found after tokenisation
            (e.g. it was truncated away because the sequence is too long).
    """
    predictions = []
    logit_rows = []
    for i in range(len(sequence)):
        masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
        if mask_positions.numel() == 0:
            raise ValueError(
                f"Could not locate the masked token for position {i + 1}; the sequence "
                "may exceed the tokenizer's max_length of 2000 tokens."
            )
        # Scores over the whole vocabulary at the (single) masked slot.
        mask_logits = logits[0, int(mask_positions[0])]
        top_ids = torch.topk(mask_logits, n).indices.tolist()
        predictions.append([tokenizer.decode([token]) for token in top_ids])
        # .float() keeps .numpy() working even if the model runs in reduced precision.
        logit_rows.append(mask_logits[token_ids].float().cpu().numpy())
    return predictions, np.vstack(logit_rows)


def _render_logit_heatmap(matrix, token_labels, seq_len, step=50):
    """Draw ``matrix`` (tokens x residues) as a heatmap and return a PIL image.

    X ticks are placed every ``step`` residues at cell centres and labelled with
    1-based residue positions, matching the ``Position`` column of the table.
    """
    fig, ax = plt.subplots(figsize=(15, 8))
    try:
        sns.heatmap(matrix, cmap='plasma', xticklabels=False, yticklabels=token_labels, ax=ax)
        ax.set_title('Logits for masked per residue tokens')
        ax.set_ylabel('Token')
        ax.set_xlabel('Residue Index')
        ax.tick_params(axis='y', labelrotation=0)
        tick_residues = np.arange(0, seq_len, step)
        # Heatmap cell j spans [j, j + 1], so +0.5 centres the tick on the residue.
        ax.set_xticks(tick_residues + 0.5)
        ax.set_xticklabels([str(pos + 1) for pos in tick_residues], rotation=0)
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Release this figure explicitly (also on error) instead of relying on
        # pyplot's "current figure", which is shared across concurrent requests.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def process_sequence(sequence, domain_bounds, n):
    """Predict per-residue substitutions for a protein sequence with the masked LM.

    Every residue is masked one at a time. The model's top-``n`` tokens for the
    masked slot are reported in a table. The logits of the amino-acid tokens are
    min-max normalised over the whole sequence and drawn as a heatmap.

    Args:
        sequence: Protein sequence from the UI; surrounding whitespace is ignored.
        domain_bounds: Table with ``start``/``end`` columns whose first row gives
            the 1-based, inclusive range of residues to include in the table.
        n: Number of top predictions to report per position.

    Returns:
        ``(df, img)``. ``df`` has one row per residue in the selected domain, with
        columns ``Original Residue``, the ordered predictions, and the 1-based
        ``Position``. ``img`` is a PIL image of the heatmap for the full sequence.

    Raises:
        ValueError: if the sequence is empty, ``n`` was not chosen, or a masked
            position is lost to tokenizer truncation.
    """
    sequence = (sequence or "").strip()
    if not sequence:
        raise ValueError("Please enter a protein sequence.")
    if n is None:
        raise ValueError("Please choose how many predictions (n) to show per position.")
    n = int(n)

    # 1-based inclusive [start, end] -> 0-based half-open row slice; clamp so a
    # start of 0 does not turn into a negative (wrap-around) slice index.
    start_index = max(int(domain_bounds['start'][0]) - 1, 0)
    end_index = int(domain_bounds['end'][0])

    # Vocabulary ids 4..23 are the amino-acid tokens kept for the heatmap
    # (assumed ESM-style vocabulary layout -- TODO confirm against the tokenizer).
    amino_acid_ids = list(range(4, 24))
    predictions, logit_matrix = _masked_residue_predictions(sequence, n, amino_acid_ids)
    amino_acid_labels = [tokenizer.decode([idx]) for idx in amino_acid_ids]

    lo, hi = logit_matrix.min(), logit_matrix.max()
    span = hi - lo
    if span > 0:
        normalized = (logit_matrix - lo) / span
    else:
        # All logits equal: avoid 0/0 producing a NaN heatmap.
        normalized = np.zeros_like(logit_matrix)
    img = _render_logit_heatmap(normalized.T, amino_acid_labels, len(sequence))

    df = pd.DataFrame({
        'Original Residue': list(sequence),
        'Predicted Residues (in order of decreasing likelihood)': predictions,
        'Position': list(range(1, len(sequence) + 1)),
    })
    return df.iloc[start_index:end_index], img

# Gradio UI: a sequence textbox, a one-row start/end table selecting the domain
# rows to show, and a dropdown for n. Outputs are the per-residue prediction
# table and the logit heatmap image.
demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        gr.Textbox(label="Protein sequence"),
        gr.Dataframe(
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
            label="Domain bounds (1-based, inclusive)",
        ),
        # Integers 1..20. A default is set so the form never submits n=None,
        # which torch.topk cannot handle.
        gr.Dropdown(list(range(1, 21)), value=1, label="n (predictions per position)"),
    ],
    outputs=["dataframe", "image"],
    description="Choose a number between 1-20 to predict n tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
)

if __name__ == "__main__":
    demo.launch()