File size: 5,272 Bytes
25d7145 26ac9e7 406e1bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import logging
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image
from contextlib import contextmanager
import warnings
import sys
import os
import zipfile
# Only ERROR-level (and above) records from transformers.modeling_utils are emitted.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Fetch the tokenizer and masked-LM weights once at import time.
# NOTE(review): trust_remote_code=True lets code shipped in the model
# repository run locally - keep the repository pinned to a trusted source.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True).to(device)
# Inference only: switch the module to evaluation mode.
model.eval()
def process_sequence(sequence, domain_bounds, n):
    """Predict the most likely residues for every position of a domain.

    Each position inside the requested domain is masked in turn, the masked
    sequence is run through the module-level ``model`` and the logits at the
    mask are ranked. The logits of vocabulary ids 4..23 are additionally
    collected, soft-maxed per position and drawn as a heatmap.

    Args:
        sequence: Protein sequence as a string. Whitespace (e.g. line breaks
            from a pasted FASTA record) is removed before processing.
        domain_bounds: Object indexable as ``domain_bounds['start'][0]`` and
            ``domain_bounds['end'][0]``; both values are 1-based and
            inclusive. Bounds reaching outside the sequence are clamped.
        n: Maximum number of amino-acid tokens to report per position. If it
            never equals the running count (e.g. ``None``), every amino-acid
            token is reported in ranked order.

    Returns:
        A tuple ``(df, img, zip_path)``: a DataFrame with the original
        residue, the ranked predictions and the 1-based position; the heatmap
        as a PIL image; and the path of a zip archive holding
        ``predicted_tokens.csv`` and ``heatmap.png``.

    Raises:
        ValueError: If the (clamped) domain is empty, or if a masked position
            is cut off by the tokenizer's truncation limit.
    """
    max_length = 2000

    # Loop invariants: the tokens accepted as predictions and the vocabulary
    # ids plotted in the heatmap.
    aa_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
    aa_token_set = frozenset(aa_tokens)
    # NOTE(review): ids 4..23 are assumed to be the 20 standard residues in
    # this tokenizer's vocabulary - confirm if the model is ever swapped.
    filtered_indices = list(range(4, 23 + 1))

    # Drop whitespace so string indices line up with residue positions.
    sequence = "".join((sequence or "").split())

    # Convert the 1-based inclusive bounds to a 0-based half-open range and
    # clamp it to the sequence so tick labels match the plotted columns.
    start_index = max(int(domain_bounds['start'][0]) - 1, 0)
    end_index = min(int(domain_bounds['end'][0]), len(sequence))
    if start_index >= end_index:
        raise ValueError(
            "Domain bounds do not select any residue: start must be <= end "
            f"and both must fall within the sequence (length {len(sequence)})."
        )

    top_n_mutations = {}
    all_logits = []
    for i in range(start_index, end_index):
        masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        if len(mask_token_index) == 0:
            # Truncation removed the mask; continuing would silently yield an
            # empty prediction and shift every later heatmap column.
            raise ValueError(
                f"Position {i + 1} lies beyond the {max_length}-token limit "
                "of the tokenizer and cannot be masked."
            )
        mask_token_logits = logits[0, mask_token_index, :]

        # Walk the vocabulary from most to least likely, keeping only
        # amino-acid tokens, until n of them have been collected.
        all_tokens_logits = mask_token_logits.squeeze(0)
        top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
        mutation = []
        for token_index in top_tokens_indices:
            decoded_token = tokenizer.decode([token_index.item()])
            if decoded_token in aa_token_set:
                mutation.append(decoded_token)
                if len(mutation) == n:
                    break
        top_n_mutations[(sequence[i], i)] = mutation

        # Keep only the plotted vocabulary columns for the heatmap.
        logits_array = mask_token_logits.cpu().numpy()
        all_logits.append(logits_array[:, filtered_indices])

    # Row labels of the heatmap: decode just the ids that are plotted.
    filtered_tokens = [tokenizer.decode([idx]) for idx in filtered_indices]

    # Softmax over the plotted tokens per position, then put tokens on rows.
    all_logits_array = np.vstack(all_logits)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    transposed_logits_array = normalized_logits_array.T

    # Plotting the heatmap; one x tick every 10 residues, labelled 1-based.
    x_tick_positions = np.arange(start_index, end_index, 10)
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
    buf = BytesIO()
    plt.figure(figsize=(15, 8))
    try:
        plt.rcParams.update({'font.size': 18})
        sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
        plt.title('Token Probability Heatmap')
        plt.ylabel('Token')
        plt.xlabel('Residue Index')
        plt.yticks(rotation=0)
        plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
        # Save the figure to an in-memory buffer.
        plt.savefig(buf, format='png', dpi=300)
    finally:
        # Always release the figure, even if drawing fails.
        plt.close()
    buf.seek(0)
    img = Image.open(buf)

    df = pd.DataFrame({
        'Original Residue': [residue for residue, _ in top_n_mutations],
        'Predicted Residues': list(top_n_mutations.values()),
        'Position': [position + 1 for _, position in top_n_mutations],
    })

    df.to_csv("predicted_tokens.csv", index=False)
    # Pillow expects dpi as an (x, y) pair; a bare int is not a valid value.
    img.save("heatmap.png", dpi=(300, 300))
    zip_path = "outputs.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        zipf.write("predicted_tokens.csv")
        zipf.write("heatmap.png")

    return df, img, zip_path
# Gradio front end: a sequence box, a single-row start/end table and a
# top-N selector feed process_sequence; its three return values fill the
# table, image and download widgets in that order.
demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        gr.Textbox(placeholder="Enter the protein sequence here", label="Sequence"),
        gr.Dataframe(
            label="Domain Bounds",
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
        ),
        gr.Dropdown(list(range(1, 21)), label="Top N Tokens"),
    ],
    outputs=[
        gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
        gr.Image(label="Heatmap", type="pil"),
        gr.File(label="Download Outputs"),
    ],
)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()