import logging
from functools import lru_cache
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import transformers
from PIL import Image
from transformers import AutoModelForMaskedLM, AutoTokenizer

MODEL_NAME = "ChatterjeeLab/FusOn-pLM"
# Vocabulary ids kept as heatmap rows (ids 4..23 inclusive, 20 tokens).
# NOTE(review): assumed to be the 20 standard amino acids of an ESM-style
# vocabulary -- confirm against tokenizer.get_vocab() if the model changes.
AMINO_ACID_TOKEN_IDS = list(range(4, 23 + 1))
MAX_LENGTH = 2000
TICK_STEP = 50

# Silence the weight-loading warnings emitted by transformers; set once at import
# instead of on every request.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)


@lru_cache(maxsize=1)
def _load_model():
    """Load the tokenizer and masked-LM model once and reuse them across requests.

    Returns:
        (tokenizer, model, device) with the model in eval mode on ``device``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model.to(device)
    model.eval()
    return tokenizer, model, device


def _masked_logits(sequence, tokenizer, model, device):
    """Score every residue by masking it and reading the model's logits.

    Row i of the result holds the logits for the AMINO_ACID_TOKEN_IDS tokens at
    the masked position when residue i is replaced by the tokenizer's mask token.
    One forward pass is run per residue.

    Returns:
        np.ndarray of shape (len(sequence), len(AMINO_ACID_TOKEN_IDS)).

    Raises:
        gr.Error: if the masked token cannot be located exactly once, e.g. when
            truncation to MAX_LENGTH cut it off.
    """
    rows = []
    for i in range(len(sequence)):
        # Replace residue i with the real mask token (not an empty string,
        # which would delete the residue and leave nothing to predict).
        masked_seq = sequence[:i] + tokenizer.mask_token + sequence[i + 1:]
        inputs = tokenizer(
            masked_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits

        mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
        if mask_positions.numel() != 1:
            raise gr.Error(
                f"Could not locate the masked token for residue {i}; "
                f"the sequence may exceed the {MAX_LENGTH}-token limit."
            )
        position = int(mask_positions[0])
        rows.append(logits[0, position].cpu().numpy()[AMINO_ACID_TOKEN_IDS])
    return np.vstack(rows)


def _min_max_normalize(values):
    """Scale ``values`` to [0, 1]; return zeros when all values are equal."""
    low, high = values.min(), values.max()
    span = high - low
    if span == 0:
        return np.zeros_like(values)
    return (values - low) / span


def _render_heatmap(matrix, token_labels, step=TICK_STEP):
    """Render a (tokens x residues) matrix as a PNG heatmap and return a PIL image.

    An explicit figure is used instead of pyplot's global current figure so that
    concurrent requests do not draw onto each other's plots.
    """
    fig, ax = plt.subplots(figsize=(15, 8))
    try:
        sns.heatmap(matrix, cmap="plasma", xticklabels=False, yticklabels=token_labels, ax=ax)
        positions = np.arange(0, matrix.shape[1], step)
        # Heatmap cell j spans [j, j + 1], so centre each tick on its cell.
        ax.set_xticks(positions + 0.5)
        ax.set_xticklabels([str(pos) for pos in positions], rotation=0)
        ax.tick_params(axis="y", labelrotation=0)
        ax.set_title("Logits for masked per residue tokens")
        ax.set_ylabel("Token")
        ax.set_xlabel("Residue Index")

        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def get_heatmap(sequence):
    """Return a heatmap of masked-LM logits for every residue of ``sequence``.

    Each residue is masked in turn. The model's logits for the amino-acid tokens
    at that position are collected and min-max normalized over the whole matrix.
    They are plotted with tokens on the y-axis and residue index (0-based) on the
    x-axis.

    Args:
        sequence: protein sequence text; leading/trailing whitespace is ignored.

    Returns:
        PIL.Image.Image containing the rendered PNG heatmap.

    Raises:
        gr.Error: if the sequence is empty or a masked position is lost to
            truncation.
    """
    sequence = (sequence or "").strip()
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    tokenizer, model, device = _load_model()
    logits = _masked_logits(sequence, tokenizer, model, device)
    token_labels = [tokenizer.decode([idx]) for idx in AMINO_ACID_TOKEN_IDS]
    normalized = _min_max_normalize(logits)
    # Transpose so that rows are tokens and columns are residues.
    return _render_heatmap(normalized.T, token_labels)


demo = gr.Interface(fn=get_heatmap, inputs="text", outputs="image")

if __name__ == "__main__":
    demo.launch()