import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image

# Suppress transformers model-loading warnings (e.g. uninitialized-weight notices)
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
model.eval()
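
# process_sequence scans the whole input: each residue is masked in turn, the
# masked-LM head predicts what belongs there, the top-n decoded tokens become
# the suggested mutations, and the amino-acid logits per position feed the
# heatmap.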
def process_sequence(sequence, domain_bounds, n):
    # Convert the 1-indexed bounds from the UI into 0-indexed slice bounds
    start_index = int(domain_bounds['start'][0]) - 1
    end_index = int(domain_bounds['end'][0])

    top_n_mutations = {}
    all_logits = []

    for i in range(len(sequence)):
        # Mask residue i and ask the masked-LM head to fill it back in
        masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = logits[0, mask_token_index, :]

        # Decode the top n candidate tokens for the masked position
        top_n_tokens = torch.topk(mask_token_logits, n, dim=1).indices[0].tolist()
        mutation = [tokenizer.decode([token]) for token in top_n_tokens]
        top_n_mutations[(sequence[i], i)] = mutation

        # Filter out non-amino-acid tokens: vocabulary indices 4-23 are the
        # 20 standard amino acids in this tokenizer
        logits_array = mask_token_logits.cpu().numpy()
        filtered_indices = list(range(4, 23 + 1))
        filtered_logits = logits_array[:, filtered_indices]
        all_logits.append(filtered_logits)
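
    # Note: `logits` from the final loop iteration is reused below; only its
    # vocabulary dimension matters, and that is identical on every iteration.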
    token_indices = torch.arange(logits.size(-1))
    tokens = [tokenizer.decode([idx]) for idx in token_indices]
    filtered_tokens = [tokens[i] for i in filtered_indices]

    # Stack the per-position logits, min-max normalize them to [0, 1], and
    # transpose so rows are tokens and columns are residue positions
    all_logits_array = np.vstack(all_logits)
    normalized_logits_array = (all_logits_array - all_logits_array.min()) / (all_logits_array.max() - all_logits_array.min())
    transposed_logits_array = normalized_logits_array.T
    # Plot the heatmap, labeling the x-axis every 50 residues
    step = 50
    x_tick_positions = np.arange(0, len(sequence), step)
    x_tick_labels = [str(pos) for pos in x_tick_positions]

    plt.figure(figsize=(15, 8))
    # xticklabels=False so the sparse labels can be placed by plt.xticks below
    sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=False, yticklabels=filtered_tokens)
    plt.title('Normalized per-residue logits for masked tokens')
    plt.ylabel('Token')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    plt.xticks(x_tick_positions, x_tick_labels, rotation=0)

    # Save the figure to a BytesIO object and reopen it as a PIL image
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close()
    img = Image.open(buf)
    # Build the results table, reporting 1-indexed positions
    original_residues = []
    mutations = []
    positions = []
    for key, value in top_n_mutations.items():
        original_residue, position = key
        original_residues.append(original_residue)
        mutations.append(value)
        positions.append(position + 1)

    df = pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues (in order of decreasing likelihood)': mutations,
        'Position': positions
    })
    # Restrict the table to the requested domain
    df = df[start_index:end_index]

    return df, img
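
# A minimal sketch of calling process_sequence directly, bypassing the Gradio
# UI. The sequence and bounds are illustrative placeholders, not data from this
# project; uncomment to run from a Python session.
#
# example_bounds = pd.DataFrame({"start": [1], "end": [8]})
# example_df, example_img = process_sequence("MKTAYIAKQR", example_bounds, 5)
# print(example_df)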
demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        "text",
        gr.Dataframe(
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
        ),
        gr.Dropdown([i for i in range(1, 21)]),  # number of predictions per position (integers 1-20)
    ],
    outputs=["dataframe", "image"],
    description="Choose a number between 1-20 to predict n tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
)
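
# The two outputs map positionally onto process_sequence's return values:
# the mutation dataframe and the heatmap image.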
if __name__ == "__main__":
    demo.launch()
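    # demo.launch(share=True) would also expose a temporary public URL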