Kseniia-Kholina's picture
Update app.py
37ffe44 verified
raw
history blame
2.58 kB
import functools
import logging

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Hugging Face Hub id of the masked protein language model used for scoring.
MODEL_NAME = "ChatterjeeLab/FusOn-pLM"
# Vocabulary ids kept as heatmap rows (ids 4..23 inclusive, 20 tokens).
# These are the amino-acid tokens; all other vocabulary entries are dropped.
# NOTE(review): confirm this id range against the tokenizer's vocab.
AMINO_ACID_TOKEN_IDS = list(range(4, 23 + 1))
# max_length passed to the tokenizer together with truncation=True.
MAX_TOKENS = 2000
# Spacing, in residues, between labelled ticks on the x axis.
TICK_STEP = 50


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the tokenizer and model once and reuse them across requests.

    Returns:
        A ``(tokenizer, model, device)`` tuple. The model has been moved to
        ``device`` and switched to eval mode.
    """
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model.to(device)
    model.eval()
    return tokenizer, model, device


def _masked_logits(sequence, tokenizer, model, device):
    """Score every residue of *sequence* by masking it in turn.

    Returns:
        An array with one row per residue. Row ``i`` holds the model's logits,
        restricted to ``AMINO_ACID_TOKEN_IDS``, at the masked position when
        residue ``i`` is replaced by ``<mask>``.

    Raises:
        ValueError: if the masked residue does not survive tokenization as
            exactly one mask token (e.g. it was truncated away because the
            sequence exceeds ``MAX_TOKENS`` tokens).
    """
    rows = []
    for i in range(len(sequence)):
        masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
        inputs = tokenizer(
            masked_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_TOKENS,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
        # Without this check a truncated mask yields an empty slice, which
        # would either crash obscurely or silently misalign rows and residues.
        if mask_positions.numel() != 1:
            raise ValueError(
                f"Residue {i} could not be scored: expected exactly one mask "
                f"token after tokenization, found {mask_positions.numel()} "
                f"(the sequence may exceed max_length={MAX_TOKENS} tokens)."
            )
        mask_logits = logits[0, mask_positions, :].cpu().numpy()
        rows.append(mask_logits[:, AMINO_ACID_TOKEN_IDS])
    return np.vstack(rows)


def _min_max_normalize(values):
    """Scale *values* globally to [0, 1]; a constant array maps to all zeros."""
    low = values.min()
    span = values.max() - low
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return (values - low) / span


def _plot_heatmap(matrix, token_labels, n_residues):
    """Draw *matrix* (tokens x residues) as a heatmap and return the Figure."""
    fig, ax = plt.subplots(figsize=(15, 8))
    # Tick labels are set manually below; seaborn's own would need one label
    # per column.
    sns.heatmap(matrix, cmap='plasma', xticklabels=False,
                yticklabels=token_labels, ax=ax)
    ax.set_title('Logits for masked per residue tokens')
    ax.set_ylabel('Token')
    ax.set_xlabel('Residue Index')
    plt.setp(ax.get_yticklabels(), rotation=0)
    x_tick_positions = np.arange(0, n_residues, TICK_STEP)
    # Heatmap cell j spans [j, j+1]; +0.5 centres the label on its residue.
    ax.set_xticks(x_tick_positions + 0.5)
    ax.set_xticklabels([str(pos) for pos in x_tick_positions], rotation=0)
    # Release pyplot's reference; the returned Figure stays renderable.
    plt.close(fig)
    return fig


def get_heatmap(sequence):
    """Build a per-residue heatmap of masked-token logits for a protein sequence.

    Each residue is replaced by ``<mask>`` in turn and the model's logits at
    that position are collected for the amino-acid tokens. The full matrix is
    min-max normalized to [0, 1] and plotted with tokens on the y axis and
    0-based residue indices on the x axis.

    Args:
        sequence: Protein sequence text. All whitespace (spaces, line breaks
            from pasted FASTA bodies) is removed before scoring, so residue
            indices refer to the cleaned sequence.

    Returns:
        A matplotlib ``Figure`` containing the heatmap.

    Raises:
        ValueError: if the sequence is empty after removing whitespace, or a
            residue cannot be scored (see ``_masked_logits``).
    """
    sequence = "".join((sequence or "").split())
    if not sequence:
        raise ValueError("Please enter a non-empty protein sequence.")
    tokenizer, model, device = _load_model()
    logits = _masked_logits(sequence, tokenizer, model, device)
    token_labels = [tokenizer.decode([idx]) for idx in AMINO_ACID_TOKEN_IDS]
    normalized = _min_max_normalize(logits)
    return _plot_heatmap(normalized.T, token_labels, len(sequence))
# Gradio UI: a text box for the protein sequence. get_heatmap returns a
# matplotlib Figure, which the "plot" output component renders ("image"
# accepts only arrays, PIL images or file paths).
demo = gr.Interface(fn=get_heatmap, inputs="text", outputs="plot")

# Start the server only when run as a script, so the module stays importable.
if __name__ == "__main__":
    demo.launch()