File size: 4,386 Bytes
2aebbc2 257341a 2aebbc2 a18e3ef 2aebbc2 c578c17 257341a c578c17 2aebbc2 af69adc 257341a 2aebbc2 c578c17 257341a c578c17 af69adc 2aebbc2 c578c17 2aebbc2 c578c17 299ebbf c578c17 2aebbc2 299ebbf 2aebbc2 299ebbf 2aebbc2 299ebbf 2aebbc2 299ebbf 2aebbc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import logging
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image
# Only surface errors from the transformers model-loading logger.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Fetch the tokenizer and masked-LM checkpoint, then move the model onto the
# selected device and switch it to evaluation mode.
# NOTE(review): trust_remote_code=True allows the checkpoint repository to
# supply its own Python code -- confirm that is intended for this deployment.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(
    model_name, trust_remote_code=True
).to(device).eval()
# Vocabulary ids kept for the heatmap. The original code used this range to
# "filter out non-amino acid tokens"; presumably these are the 20 standard
# residues of the tokenizer's vocabulary -- TODO confirm against the tokenizer.
_RESIDUE_TOKEN_IDS = list(range(4, 23 + 1))


def _resolve_domain(sequence_length, domain_bounds):
    """Turn the 1-based, inclusive bounds table into a 0-based half-open range.

    The bounds are clamped to the sequence so that an ``end`` past the last
    residue (or a ``start`` below 1) behaves as "up to the edge", matching what
    the per-residue loop effectively did before.

    Raises:
        ValueError: if the bounds are missing/non-numeric or select no residue.
    """
    try:
        first = int(domain_bounds['start'][0])
        last = int(domain_bounds['end'][0])
    except (KeyError, IndexError, TypeError, ValueError) as err:
        raise ValueError(
            "Domain bounds must provide integer 'start' and 'end' values."
        ) from err

    start_index = max(first - 1, 0)
    end_index = min(last, sequence_length)
    if start_index >= end_index:
        raise ValueError(
            f"Domain bounds {first}-{last} do not select any residue of a "
            f"sequence of length {sequence_length}."
        )
    return start_index, end_index


def _render_heatmap(probabilities, token_labels, start_index):
    """Draw the token-by-position heatmap and return it as a PIL image.

    ``probabilities`` has one row per token label and one column per masked
    position; column 0 corresponds to residue ``start_index + 1``.
    """
    tick_offsets = np.arange(0, probabilities.shape[1], 10)
    tick_labels = [str(start_index + offset + 1) for offset in tick_offsets]

    fig, ax = plt.subplots(figsize=(15, 8))
    try:
        # x tick labels are set explicitly below, so seaborn draws none itself.
        sns.heatmap(
            probabilities,
            cmap='plasma',
            xticklabels=False,
            yticklabels=token_labels,
            ax=ax,
        )
        ax.set_title('Logits for masked per residue tokens')
        ax.set_ylabel('Token')
        ax.set_xlabel('Residue Index')
        plt.setp(ax.get_yticklabels(), rotation=0)
        # Heatmap cell j spans [j, j + 1]; +0.5 puts the tick on the cell centre.
        ax.set_xticks(tick_offsets + 0.5)
        ax.set_xticklabels(tick_labels, rotation=0)

        buffer = BytesIO()
        fig.savefig(buffer, format='png')
        buffer.seek(0)
    finally:
        # Close this specific figure, even on failure, so figures do not pile up.
        plt.close(fig)
    return Image.open(buffer)


def process_sequence(sequence, domain_bounds, n):
    """Predict replacement tokens for every residue inside a domain.

    Each position in the domain is replaced by ``<mask>`` in turn, the masked
    sequence is run through the module-level ``model``, and the logits at the
    mask position are used twice: the ``n`` highest-scoring tokens (over the
    whole vocabulary) go into the result table, and the logits of the residue
    tokens are softmax-normalised per position for the heatmap.

    Args:
        sequence: Protein sequence string; surrounding whitespace is ignored.
        domain_bounds: Table with ``start`` and ``end`` columns whose first row
            holds the 1-based, inclusive domain limits.
        n: Number of top-scoring tokens to report per position (>= 1).

    Returns:
        A ``(DataFrame, PIL image)`` pair: the table has the columns
        ``Original Residue``, ``Predicted Residues`` (highest score first) and
        ``Position`` (1-based); the image is the rendered heatmap.

    Raises:
        ValueError: if ``n`` is missing or < 1, the sequence is empty, the
            bounds select no residue, or a position cannot be located in the
            tokenized input (e.g. it was truncated away).
    """
    # Validate everything up front so bad input fails before any model call.
    if n is None:
        raise ValueError("Select how many top tokens to report.")
    try:
        n = int(n)
    except (TypeError, ValueError) as err:
        raise ValueError("The number of top tokens must be an integer.") from err
    if n < 1:
        raise ValueError("The number of top tokens must be at least 1.")

    sequence = (sequence or "").strip()
    if not sequence:
        raise ValueError("Enter a protein sequence.")

    start_index, end_index = _resolve_domain(len(sequence), domain_bounds)

    # Loop-invariant: labels for the heatmap rows.
    residue_tokens = [tokenizer.decode([idx]) for idx in _RESIDUE_TOKEN_IDS]

    original_residues = []
    mutations = []
    positions = []
    residue_logits = []
    for i in range(start_index, end_index):
        masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
        inputs = tokenizer(
            masked_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2000,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits

        mask_token_index = torch.where(
            inputs["input_ids"] == tokenizer.mask_token_id
        )[1]
        # Exactly one mask is expected; zero means it was truncated away and
        # several means the input already contained a mask token.
        if mask_token_index.numel() != 1:
            raise ValueError(
                f"Could not locate a single mask token for position {i + 1}; "
                "the sequence may exceed the 2000-token limit or already "
                "contain '<mask>'."
            )
        mask_token_logits = logits[0, mask_token_index, :]

        # Never ask topk for more entries than the vocabulary holds.
        k = min(n, mask_token_logits.size(-1))
        top_tokens = torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()

        original_residues.append(sequence[i])
        mutations.append([tokenizer.decode([token]) for token in top_tokens])
        positions.append(i + 1)
        residue_logits.append(mask_token_logits[:, _RESIDUE_TOKEN_IDS].float().cpu())

    # (positions, tokens) -> per-position softmax -> (tokens, positions).
    probabilities = torch.cat(residue_logits, dim=0).softmax(dim=-1).numpy().T
    img = _render_heatmap(probabilities, residue_tokens, start_index)

    df = pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues': mutations,
        'Position': positions,
    })
    return df, img
# Gradio front end: a sequence textbox, a one-row start/end table and a top-N
# dropdown feed process_sequence; its table and heatmap are shown as outputs.
demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
        gr.Dataframe(
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
            label="Domain Bounds",
        ),
        # Preselect a value so submitting without touching the dropdown never
        # hands None to process_sequence as the token count.
        gr.Dropdown(list(range(1, 21)), value=1, label="Top N Tokens"),
    ],
    outputs=[
        gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
        gr.Image(type="pil", label="Heatmap"),
    ],
    description="Choose a number from the dropdown to predict N tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
)

if __name__ == "__main__":
    demo.launch()