"""Gradio demo: per-residue masked-token predictions with FusOn-pLM.

Every residue of the input sequence is masked in turn. The app reports the
model's top-n predicted tokens for each position inside a user-chosen domain,
plus a heatmap of (normalised) amino-acid logits over the whole sequence.
"""
import logging
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from PIL import Image
from transformers import AutoModelForMaskedLM, AutoTokenizer

logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
model.eval()

# Vocabulary ids kept for the heatmap (ids 4..23 inclusive); all other
# vocabulary entries are dropped as non-amino-acid tokens.
# NOTE(review): these ids are hard-coded; confirm against the FusOn-pLM vocab.
AMINO_ACID_TOKEN_IDS = list(range(4, 23 + 1))

# Tokenizer truncation limit, in tokens.
MAX_TOKEN_LENGTH = 2000

# Spacing, in residues, between labelled ticks on the heatmap's x axis.
HEATMAP_TICK_STEP = 50


def _masked_logits(sequence, position):
    """Return vocabulary logits at ``position`` after masking that residue.

    Args:
        sequence: The protein sequence string.
        position: 0-based index of the residue to replace with the mask token.

    Returns:
        A 1-D tensor of logits over the full vocabulary for the masked position.

    Raises:
        gr.Error: if no mask token remains after tokenization (e.g. the
            sequence was truncated at MAX_TOKEN_LENGTH before ``position``).
    """
    # Insert the tokenizer's own mask token so the position is encoded as
    # tokenizer.mask_token_id and can be located below.
    masked_seq = sequence[:position] + tokenizer.mask_token + sequence[position + 1:]
    inputs = tokenizer(
        masked_seq,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_TOKEN_LENGTH,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits

    mask_hits = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
    if mask_hits.numel() == 0:
        raise gr.Error(
            f"Residue {position + 1} could not be masked; the sequence is "
            f"probably longer than the {MAX_TOKEN_LENGTH}-token limit."
        )
    return logits[0, int(mask_hits[0]), :]


def _plot_logits_heatmap(logits_matrix, token_labels, seq_len):
    """Render a (tokens x residues) matrix as a heatmap and return a PIL image."""
    tick_positions = np.arange(0, seq_len, HEATMAP_TICK_STEP)
    tick_labels = [str(pos) for pos in tick_positions]

    # Explicit Figure/Axes instead of pyplot's global "current figure", and
    # always close the figure so repeated requests do not leak figures.
    fig, ax = plt.subplots(figsize=(15, 8))
    try:
        sns.heatmap(
            logits_matrix,
            cmap='plasma',
            xticklabels=False,
            yticklabels=token_labels,
            ax=ax,
        )
        ax.set_title('Logits for masked per residue tokens')
        ax.set_ylabel('Token')
        ax.set_xlabel('Residue Index')
        ax.tick_params(axis='y', labelrotation=0)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=0)

        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


def _read_bound(domain_bounds, column):
    """Read the first-row value of ``column`` from the bounds table as an int."""
    try:
        return int(domain_bounds[column][0])
    except (KeyError, IndexError, TypeError, ValueError) as err:
        raise gr.Error(f"Please enter an integer '{column}' index.") from err


def process_sequence(sequence, domain_bounds, n):
    """Predict the top-n tokens for every residue of ``sequence``.

    Each residue is masked in turn and the model is run once per position.

    Args:
        sequence: Protein sequence string; leading/trailing whitespace is ignored.
        domain_bounds: Table with 'start' and 'end' columns. Its first row gives
            the 1-based, inclusive domain whose rows appear in the output table.
        n: Number of top predictions to report per position.

    Returns:
        ``(df, img)``:
        - ``df`` has one row per residue in the domain. Its columns are the
          original residue, the top-n decoded tokens (most likely first) and
          the 1-based position.
        - ``img`` is a PIL heatmap of min-max normalised amino-acid logits for
          every position of the sequence.

    Raises:
        gr.Error: on an empty sequence, missing ``n``, or invalid domain bounds.
    """
    sequence = (sequence or "").strip()
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")
    if n is None:
        raise gr.Error("Please choose how many predictions to return per position.")
    n = int(n)

    start = _read_bound(domain_bounds, 'start')
    end = _read_bound(domain_bounds, 'end')
    if start < 1 or end < start:
        raise gr.Error("Domain bounds must satisfy 1 <= start <= end (indexing starts at 1).")

    top_n_mutations = []
    per_residue_logits = []
    for i in range(len(sequence)):
        mask_logits = _masked_logits(sequence, i)

        top_ids = torch.topk(mask_logits, n).indices.tolist()
        top_n_mutations.append([tokenizer.decode([token]) for token in top_ids])

        # Keep only the amino-acid rows of the vocabulary for the heatmap.
        per_residue_logits.append(mask_logits.float().cpu().numpy()[AMINO_ACID_TOKEN_IDS])

    logits_matrix = np.vstack(per_residue_logits)  # (residues, amino-acid tokens)
    lo, hi = logits_matrix.min(), logits_matrix.max()
    span = hi - lo
    # Min-max scale to [0, 1]; a constant matrix would otherwise divide by zero.
    if span > 0:
        normalized = (logits_matrix - lo) / span
    else:
        normalized = np.zeros_like(logits_matrix)

    token_labels = [tokenizer.decode([idx]) for idx in AMINO_ACID_TOKEN_IDS]
    img = _plot_logits_heatmap(normalized.T, token_labels, len(sequence))

    df = pd.DataFrame({
        'Original Residue': list(sequence),
        'Predicted Residues (in order of decreasing likelihood)': top_n_mutations,
        'Position': list(range(1, len(sequence) + 1)),
    })
    # Rows are 0-based; bounds are 1-based and inclusive.
    df = df.iloc[start - 1:end]
    return df, img


demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        "text",
        gr.Dataframe(
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
        ),
        gr.Dropdown(list(range(1, 21))),  # Dropdown with numbers from 1 to 20 as integers
    ],
    outputs=["dataframe", "image"],
    description="Choose a number between 1-20 to predict n tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
)

if __name__ == "__main__":
    demo.launch()