import gradio as gr
import numpy as np
import requests  # needed by fetch_and_display_pdb below
import py3Dmol   # needed by fetch_and_display_pdb below
import torch
from torch.utils.data import DataLoader

from datasets import Dataset
from scipy.special import expit
from transformers import DataCollatorForTokenClassification

from model_loader import load_model
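
# NOTE: predict_protein_sequence() below references DataCollatorForTokenClassificationESM,
# which is not defined in this file. The placeholder here is a minimal sketch that reuses
# the standard token-classification collator's padding behavior; swap in the project's
# real custom collator for ESM/ProstT5 checkpoints if one exists elsewhere.
class DataCollatorForTokenClassificationESM(DataCollatorForTokenClassification):
    pass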

checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500

model, tokenizer = load_model(checkpoint, max_length)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

def create_dataset(tokenizer, seqs, labels, checkpoint):
    tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
    dataset = Dataset.from_dict(tokenized)

    # ESM and ProstT5 tokenizers add two special tokens per sequence, T5-style
    # tokenizers only a trailing EOS, so truncate the labels to match.
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        labels = [l[:max_length-2] for l in labels]
    else:
        labels = [l[:max_length-1] for l in labels]

    dataset = dataset.add_column("labels", labels)
    return dataset
def convert_predictions(input_logits):
    all_probs = []
    for logits in input_logits:
        logits = logits.reshape(-1, 2)

        # Per-residue probability of class 1 (binding) from the two-class logits:
        # the sigmoid of the logit difference is equivalent to a two-class softmax.
        probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
        all_probs.append(probabilities_class1)

    return np.concatenate(all_probs)
def normalize_scores(scores):
min_score = np.min(scores)
max_score = np.max(scores)
return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
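
# Example (hypothetical values): normalize_scores(np.array([0.2, 0.5, 0.8]))
# min-max rescales to array([0. , 0.5, 1. ]); a constant input is returned unchanged.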
def predict_protein_sequence(test_one_letter_sequence):
    # Dummy labels so the training-style dataset/collator pipeline can be reused for inference
    dummy_labels = [np.zeros(len(test_one_letter_sequence))]

    # Replace uncommon amino acids with "X"
    test_one_letter_sequence = (
        test_one_letter_sequence
        .replace("O", "X").replace("B", "X").replace("U", "X")
        .replace("Z", "X").replace("J", "X")
    )

    # Add spaces between each amino acid for ProtT5 and ProstT5 models
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        test_one_letter_sequence = " ".join(test_one_letter_sequence)

    # Add <AA2fold> for the ProstT5 model input format
    if "ProstT5" in checkpoint:
        test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence

    test_dataset = create_dataset(tokenizer, [test_one_letter_sequence], dummy_labels, checkpoint)

    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        data_collator = DataCollatorForTokenClassificationESM(tokenizer)
    else:
        data_collator = DataCollatorForTokenClassification(tokenizer)

    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)

    # One sequence at batch_size=1, so this loop runs exactly once
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        logits = outputs.logits.detach().cpu().numpy()

    # For prot_t5, drop the last position: it corresponds to the special EOS token
    logits = logits[:, :-1]
    logits = convert_predictions(logits)
    normalized_scores = normalize_scores(logits)

    # Strip the spaces inserted for T5-style tokenization before pairing residues with scores
    test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")

    result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
    return result_str
# Function to fetch and visualize the PDB structure using py3Dmol
def fetch_and_display_pdb(pdb_id):
# Fetch the PDB structure from the RCSB
pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
response = requests.get(pdb_url)
if response.status_code == 200:
pdb_structure = response.text
else:
return "Failed to load PDB structure. Please check the PDB ID."
# Initialize the viewer
viewer = py3Dmol.view(width=800, height=400)
viewer.addModel(pdb_structure, "pdb")
viewer.setStyle({}, {"cartoon": {"color": "spectrum"}})
viewer.zoomTo()
return viewer._make_html()
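
# The interface below wires two inputs (sequence, PDB ID) to two outputs (predictions,
# viewer HTML), but predict_protein_sequence alone only covers the first of each.
# A minimal wrapper (the name predict_and_visualize is our addition), assuming the
# two steps are independent:
def predict_and_visualize(sequence, pdb_id):
    predictions = predict_protein_sequence(sequence)
    viewer_html = fetch_and_display_pdb(pdb_id) if pdb_id else ""
    return predictions, viewer_html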
# Define the Gradio interface
interface = gr.Interface(
    fn=predict_and_visualize,
inputs=[
gr.Textbox(lines=2, placeholder="Enter protein sequence here...", label="Protein Sequence"),
gr.Textbox(lines=1, placeholder="Enter PDB ID here...", label="PDB ID for 3D Visualization")
],
outputs=[
gr.Textbox(label="Binding Site Predictions"),
gr.HTML(label="3Dmol Viewer") # HTML output to render the 3Dmol viewer
],
title="Protein Binding Site Prediction and 3D Structure Viewer",
description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
)
# Launch the Gradio app
interface.launch()
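
# launch() serves on http://127.0.0.1:7860 by default; pass share=True for a
# temporary public link (standard Gradio options).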