"""Per-residue masked-token logit heatmap for a protein sequence (FusOn-pLM).

Each position of the input sequence is replaced in turn by the tokenizer's
mask token and the masked sequence is run through the masked language model.
The logits predicted at the masked position are kept for a fixed range of
vocabulary ids. The stacked logits are min-max normalised over the whole
array and drawn as a (token x residue) heatmap, which is saved to disk and
then shown.
"""
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import gradio as gr  # NOTE(review): imported but not used yet in this script.

logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
model.eval()

# Vocabulary ids whose logits are kept ("filter out non-amino acid tokens").
# NOTE(review): ids 4..23 are assumed to be the 20 standard amino-acid tokens
# of this tokenizer's vocabulary -- confirm against tokenizer.get_vocab().
AMINO_ACID_TOKEN_IDS = tuple(range(4, 23 + 1))


def masked_position_logits(sequence, token_ids=AMINO_ACID_TOKEN_IDS, max_length=2000):
    """Return the logits predicted at each residue when that residue is masked.

    Args:
        sequence: Protein sequence as a string of one-letter residue codes.
        token_ids: Vocabulary ids whose logits are kept (result columns).
        max_length: Truncation length passed to the tokenizer.

    Returns:
        np.ndarray of shape (len(sequence), len(token_ids)). Row i holds the
        logits for the masked position when residue i is masked.

    Raises:
        ValueError: if ``sequence`` is empty, or if the mask token is absent
            from the tokenized input (e.g. it was cut off by truncation).
    """
    if not sequence:
        raise ValueError("sequence must be a non-empty string")
    kept_ids = list(token_ids)
    rows = []
    for i in range(len(sequence)):
        # Use the tokenizer's own mask string so it tokenizes to mask_token_id.
        masked_seq = sequence[:i] + tokenizer.mask_token + sequence[i + 1:]
        inputs = tokenizer(
            masked_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
        if mask_positions.numel() == 0:
            raise ValueError(
                f"mask token not found for residue {i}; the sequence may exceed max_length"
            )
        mask_logits = logits[0, mask_positions[0], :].cpu().numpy()
        rows.append(mask_logits[kept_ids])
    return np.vstack(rows)


def min_max_normalize(values):
    """Scale ``values`` linearly into [0, 1] using the global min and max.

    A constant array (max == min) maps to all zeros instead of dividing by 0.
    """
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return (values - lo) / span


def plot_logits_heatmap(logits, token_labels, tick_step=50, out_path="heatmap.png"):
    """Draw ``logits`` (residues x tokens) as a token-by-residue heatmap.

    The figure is saved to ``out_path`` and then shown.

    Args:
        logits: 2-D array, one row per residue and one column per token.
        token_labels: Labels for the token axis (one per column of ``logits``).
        tick_step: Spacing between labelled residue indices on the x axis.
        out_path: Destination file for ``plt.savefig``.
    """
    n_residues = logits.shape[0]
    tick_positions = np.arange(0, n_residues, tick_step)
    tick_labels = [str(pos) for pos in tick_positions]

    plt.figure(figsize=(15, 8))
    sns.heatmap(logits.T, cmap='plasma', xticklabels=False, yticklabels=token_labels)
    plt.title('Logits for masked per residue tokens')
    plt.ylabel('Token')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    # Heatmap cell k spans [k, k + 1); the +0.5 centres each tick on its column.
    plt.xticks(tick_positions + 0.5, tick_labels, rotation=0)
    # Save before show(): once show() closes the figure, savefig may write a blank canvas.
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()


# TODO: take the sequence from user input (e.g. a gradio UI) instead of hard-coding it.
sequence = 'MCNTNMS'
all_logits_array = masked_position_logits(sequence)
filtered_tokens = [tokenizer.decode([idx]) for idx in AMINO_ACID_TOKEN_IDS]
normalized_logits_array = min_max_normalize(all_logits_array)
# Output name keeps the previous convention: index of the last residue.
plot_logits_heatmap(
    normalized_logits_array,
    filtered_tokens,
    tick_step=50,
    out_path=f'heatmap_{len(sequence) - 1}.png',
)