"""Gradio demo: per-residue masked-token predictions from FusOn-pLM.

For every residue inside a user-selected domain, the residue is replaced by the
tokenizer's mask token and the model predicts the masked position. The app
returns the top-N amino-acid predictions per position, a probability heatmap,
and a zip archive containing both.
"""

import logging
import os
import sys
import tempfile
import warnings
import zipfile
from contextlib import contextmanager
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Only surface errors from transformers' model-loading module.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model once at import time; every request reuses them.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
model.eval()

# The 20 canonical amino-acid tokens, in the order used for heatmap rows.
AA_TOKENS = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
             'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
# Vocabulary ids of AA_TOKENS, used both for the top-N ranking and the heatmap
# so the two always agree. Earlier code hard-coded ids 4..23 for the heatmap;
# presumably these are the same ids for this tokenizer — verify if the model changes.
AA_TOKEN_IDS = tokenizer.convert_tokens_to_ids(AA_TOKENS)
if None in AA_TOKEN_IDS or (
    tokenizer.unk_token_id is not None and tokenizer.unk_token_id in AA_TOKEN_IDS
):
    raise RuntimeError(
        f"Tokenizer for {model_name} has no single-letter token for every amino acid."
    )

# Tokenizer length limit; residues truncated past it cannot be masked.
MAX_LENGTH = 2000


def _parse_domain_bounds(domain_bounds, seq_len):
    """Convert the 1-based inclusive bounds table into 0-based half-open indices.

    Args:
        domain_bounds: Table with 'start' and 'end' columns; only the first row is read.
        seq_len: Length of the (cleaned) protein sequence.

    Returns:
        (start_index, end_index) such that the domain is sequence[start_index:end_index].

    Raises:
        gr.Error: if the values are missing/non-numeric or fall outside the sequence.
    """
    try:
        start = int(domain_bounds["start"].iloc[0])
        end = int(domain_bounds["end"].iloc[0])
    except (KeyError, IndexError, TypeError, ValueError) as err:
        raise gr.Error("Domain bounds must contain numeric 'start' and 'end' values.") from err
    if not 1 <= start <= end <= seq_len:
        raise gr.Error(
            f"Domain bounds must satisfy 1 <= start <= end <= {seq_len} "
            f"(sequence length); got start={start}, end={end}."
        )
    return start - 1, end


def _masked_logits(sequence, position):
    """Return the model's vocabulary logits (1-D) for `sequence` with `position` masked."""
    masked_seq = sequence[:position] + tokenizer.mask_token + sequence[position + 1:]
    inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True,
                       max_length=MAX_LENGTH)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
    if mask_positions.numel() == 0:
        # The mask was cut off by truncation, so there is nothing to predict.
        raise gr.Error(
            f"Residue {position + 1} lies beyond the {MAX_LENGTH}-token limit and cannot be masked."
        )
    return logits[0, mask_positions[0], :]


def _plot_heatmap(probs, start_index, end_index):
    """Render a (len(AA_TOKENS), n_positions) probability matrix and return it as a PIL image."""
    x_tick_positions = np.arange(start_index, end_index, 10)
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]
    buf = BytesIO()
    # rc_context scopes the font size to this figure instead of mutating global rcParams.
    with plt.rc_context({'font.size': 18}):
        fig, ax = plt.subplots(figsize=(15, 8))
        try:
            sns.heatmap(probs, cmap='plasma', xticklabels=False, yticklabels=AA_TOKENS, ax=ax)
            ax.set_title('Token Probability Heatmap')
            ax.set_ylabel('Token')
            ax.set_xlabel('Residue Index')
            # Heatmap column 0 is residue start_index; +0.5 centres each tick on its cell.
            ax.set_xticks(x_tick_positions - start_index + 0.5)
            ax.set_xticklabels(x_tick_labels, rotation=0)
            ax.tick_params(axis='y', labelrotation=0)
            fig.savefig(buf, format='png', dpi=300)
        finally:
            plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _write_outputs(df, img):
    """Write the CSV, the PNG and a zip of both to a fresh temp dir; return the zip path.

    A per-request directory keeps concurrent users from overwriting each other's files.
    """
    out_dir = tempfile.mkdtemp(prefix="fuson_plm_")
    csv_path = os.path.join(out_dir, "predicted_tokens.csv")
    png_path = os.path.join(out_dir, "heatmap.png")
    zip_path = os.path.join(out_dir, "outputs.zip")
    df.to_csv(csv_path, index=False)
    # Pillow's PNG writer expects dpi as an (x, y) tuple.
    img.save(png_path, dpi=(300, 300))
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        zipf.write(csv_path, arcname="predicted_tokens.csv")
        zipf.write(png_path, arcname="heatmap.png")
    return zip_path


def process_sequence(sequence, domain_bounds, n):
    """Predict the top-N amino acids for every residue of a domain by masking it.

    Args:
        sequence: Protein sequence in one-letter codes; all whitespace is removed.
        domain_bounds: Table with 'start' and 'end' columns whose first row holds
            the 1-based, inclusive domain bounds.
        n: Number of top amino-acid predictions to report per position.

    Returns:
        (df, img, zip_path):
            - df: DataFrame with one row per domain residue and columns
              'Original Residue', 'Predicted Residues' (list, most likely first)
              and 'Position' (1-based).
            - img: the per-position probability heatmap as a PIL image.
            - zip_path: path to a zip containing the CSV and the PNG.

    Raises:
        gr.Error: on an empty sequence, a missing N, or invalid domain bounds.
    """
    sequence = "".join((sequence or "").split())
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")
    if n is None:
        raise gr.Error("Please select the number of top tokens (N).")
    n = int(n)
    start_index, end_index = _parse_domain_bounds(domain_bounds, len(sequence))

    rows = []
    aa_logit_rows = []
    for i in range(start_index, end_index):
        # Restrict to amino-acid tokens so special/non-AA tokens are never predicted.
        aa_logits = _masked_logits(sequence, i)[AA_TOKEN_IDS]
        ranked = torch.argsort(aa_logits, descending=True)[:n]
        rows.append({
            'Original Residue': sequence[i],
            'Predicted Residues': [AA_TOKENS[j] for j in ranked.tolist()],
            'Position': i + 1,
        })
        aa_logit_rows.append(aa_logits.float().cpu())

    # Softmax over the 20 amino-acid logits only, so each position's column sums to 1.
    probs = F.softmax(torch.stack(aa_logit_rows), dim=-1).numpy().T
    img = _plot_heatmap(probs, start_index, end_index)

    df = pd.DataFrame(rows, columns=['Original Residue', 'Predicted Residues', 'Position'])
    zip_path = _write_outputs(df, img)
    return df, img, zip_path


demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
        gr.Dataframe(
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
            label="Domain Bounds",
        ),
        gr.Dropdown(list(range(1, 21)), label="Top N Tokens"),
    ],
    outputs=[
        gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
        gr.Image(type="pil", label="Heatmap"),
        gr.File(label="Download Outputs"),
    ],
)

if __name__ == "__main__":
    demo.launch()