test_webpage / app.py
ThorbenF's picture
Update
01ff8b6
raw
history blame
7.63 kB
import gradio as gr
from model_loader import load_model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader
import re
import numpy as np
import os
import pandas as pd
import copy
import transformers, datasets
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, set_seed
from transformers import DataCollatorForTokenClassification
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
# for custom DataCollator
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from datasets import Dataset
from scipy.special import expit
import requests
from gradio_molecule3d import Molecule3D
#import peft
#from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
# --- Configuration ---------------------------------------------------------
# Checkpoint identifier handed to load_model(); later code also inspects this
# string to pick model-family-specific preprocessing.
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
# Passed to load_model() and used as the tokenizer's truncation length.
max_length = 1500

# Default representation list given to the Molecule3D component: a single
# entry that is visible, uses the "cartoon" style, and the "spectrum" colour.
reps = [
    dict(
        model=0,
        chain="",
        resname="",
        style="cartoon",
        color="spectrum",
        residue_range="",
        around=0,
        byres=False,
        visible=True,
    ),
]

# --- Model setup -----------------------------------------------------------
# Load the model/tokenizer pair once at import time, move the model to the GPU
# when one is available (CPU otherwise), and switch it to evaluation mode.
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
def create_dataset(tokenizer, seqs, labels, checkpoint):
    """Tokenize *seqs* and attach clipped *labels* as a ``labels`` column.

    Sequences are tokenized without padding and truncated to the module-level
    ``max_length``. Each label sequence is clipped to ``max_length - 2`` when
    the checkpoint name contains ``"esm"`` or ``"ProstT5"``, and to
    ``max_length - 1`` otherwise.

    Args:
        tokenizer: Callable tokenizer returning a mapping of feature lists.
        seqs: List of input sequences.
        labels: List of per-position label sequences, one per entry in *seqs*.
        checkpoint: Checkpoint name used to choose the label clip length.

    Returns:
        A ``datasets.Dataset`` holding the tokenizer output plus ``labels``.
    """
    encoded = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)

    # NOTE(review): the differing limits presumably mirror how many special
    # tokens each tokenizer adds (two vs. one) -- confirm against load_model.
    reserves_two = ("esm" in checkpoint) or ("ProstT5" in checkpoint)
    label_limit = max_length - 2 if reserves_two else max_length - 1
    clipped_labels = [label[:label_limit] for label in labels]

    return Dataset.from_dict(encoded).add_column("labels", clipped_labels)
def convert_predictions(input_logits):
    """Turn two-class logits into a flat array of class-1 probabilities.

    Every item of *input_logits* is reshaped to ``(-1, 2)``; the probability
    of class 1 is the logistic sigmoid of ``logit_1 - logit_0`` (equivalent to
    a two-way softmax). The per-item results are concatenated into one 1-D
    array.

    Args:
        input_logits: Iterable of arrays whose size is a multiple of two.

    Returns:
        1-D ``numpy.ndarray`` of probabilities, in input order.
    """
    per_item_probs = [
        expit(pairs[:, 1] - pairs[:, 0])
        for pairs in (item.reshape(-1, 2) for item in input_logits)
    ]
    return np.concatenate(per_item_probs)
def normalize_scores(scores):
    """Min-max scale *scores* into the range [0, 1].

    Args:
        scores: Array-like of numeric scores (an ndarray, list, or tuple).

    Returns:
        ``numpy.ndarray`` with ``(scores - min) / (max - min)`` applied.
        Empty input is returned unchanged, and so is input whose values are
        all equal (there is no range to scale by).
    """
    # Accept plain sequences too; an ndarray passes through without copying.
    scores = np.asarray(scores)
    # np.min/np.max raise ValueError on empty arrays, so bail out first.
    if scores.size == 0:
        return scores

    min_score = scores.min()
    max_score = scores.max()
    if max_score > min_score:
        return (scores - min_score) / (max_score - min_score)
    return scores
# Single-pass mapping of the residue codes O, B, U, Z and J to "X".
_RARE_RESIDUE_TABLE = str.maketrans("OBUZJ", "XXXXX")


def predict_protein_sequence(test_one_letter_sequence):
    """Predict a normalized per-residue score for a one-letter protein sequence.

    Uses the module-level ``checkpoint``, ``tokenizer``, ``model`` and
    ``device`` together with ``create_dataset``, ``convert_predictions`` and
    ``normalize_scores``.

    Args:
        test_one_letter_sequence: Amino-acid sequence in one-letter code.
            Whitespace (including newlines from a multi-line textbox) is
            removed and letters are upper-cased before prediction.

    Returns:
        A string with one ``"<residue>: <score>"`` line per residue (scores
        formatted to two decimals), or ``""`` when the input contains no
        residues.
    """
    # Sanitize: drop whitespace, upper-case, and map O/B/U/Z/J to X.
    residues = "".join(test_one_letter_sequence.split()).upper()
    residues = residues.translate(_RARE_RESIDUE_TABLE)
    # Nothing to predict; also avoids min/max on an empty score array later.
    if not residues:
        return ""

    # Prepare the model input for the different model types.
    uses_t5_spacing = ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint)
    model_input = " ".join(residues) if uses_t5_spacing else residues
    if "ProstT5" in checkpoint:
        model_input = "<AA2fold> " + model_input

    # Dummy labels are required by create_dataset; one per residue (not per
    # character of the space-joined model input).
    dummy_labels = [np.zeros(len(residues))]
    test_dataset = create_dataset(tokenizer, [model_input], dummy_labels, checkpoint)

    data_collator = DataCollatorForTokenClassification(tokenizer)
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)

    # The dataset holds exactly one sequence, so there is exactly one batch.
    batch = next(iter(test_loader))
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits.detach().cpu().numpy()

    # Drop positions that belong to special tokens rather than residues.
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        # Mirrors create_dataset, which reserves two positions for these
        # checkpoints. NOTE(review): assumes one leading and one trailing
        # special token -- confirm against the tokenizer from load_model.
        logits = logits[:, 1:-1]
    else:
        logits = logits[:, :-1]  # Remove last element for prot_t5

    normalized_scores = normalize_scores(convert_predictions(logits))

    # zip stops at the shorter side, so a truncated sequence only lists the
    # residues that actually received a score.
    return "\n".join(
        f"{aa}: {score:.2f}" for aa, score in zip(residues, normalized_scores)
    )
def fetch_pdb(pdb_id):
    """Download a PDB structure file from RCSB into ``pdb_files/``.

    Args:
        pdb_id: Four-character alphanumeric PDB identifier (surrounding
            whitespace is ignored).

    Returns:
        Path of the saved file (``pdb_files/<pdb_id>.pdb``), or ``None`` when
        the identifier is invalid, the server does not answer with HTTP 200,
        or a network/filesystem error occurs.
    """
    # pdb_id comes straight from a user textbox and is interpolated into both
    # a filesystem path and a URL, so reject anything that is not a plain
    # 4-character alphanumeric ID (blocks path traversal such as "../x").
    pdb_id = pdb_id.strip() if isinstance(pdb_id, str) else ""
    if not re.fullmatch(r"[A-Za-z0-9]{4}", pdb_id):
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'

    try:
        # Create a directory to store PDB files if it doesn't exist.
        os.makedirs('pdb_files', exist_ok=True)
        # A timeout keeps an unresponsive server from hanging the request.
        response = requests.get(pdb_url, timeout=30)
        if response.status_code != 200:
            return None
        with open(pdb_path, 'wb') as f:
            f.write(response.content)
        return pdb_path
    except (requests.RequestException, OSError) as e:
        print(f"Error fetching PDB: {e}")
        return None
def process_input(sequence, pdb_id):
    """Gradio callback: run the prediction, then fetch the structure file.

    Args:
        sequence: Protein sequence passed to ``predict_protein_sequence``.
        pdb_id: PDB identifier passed to ``fetch_pdb``.

    Returns:
        Tuple ``(prediction_text, pdb_path)`` in the order of the two outputs
        wired to the predict button.
    """
    # Tuple elements are evaluated left to right: prediction first, download
    # second.
    return predict_protein_sequence(sequence), fetch_pdb(pdb_id)
# Build the Gradio UI: inputs on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column():
            sequence_input = gr.Textbox(
                lines=2,
                placeholder="Enter protein sequence here...",
                label="Protein Sequence",
            )
            pdb_input = gr.Textbox(
                lines=1,
                placeholder="Enter PDB ID here...",
                label="PDB ID for 3D Visualization",
            )
            predict_btn = gr.Button("Predict Binding Sites")

        # Right column: textual predictions and the 3D structure viewer.
        with gr.Column():
            predictions_output = gr.Textbox(
                label="Binding Site Predictions",
            )
            molecule_output = Molecule3D(
                label="Protein Structure",
                reps=reps,
            )

    # Clicking the button runs process_input on both inputs and fills both
    # outputs.
    predict_btn.click(
        fn=process_input,
        inputs=[sequence_input, pdb_input],
        outputs=[predictions_output, molecule_output],
    )

    # One example row (sequence, PDB ID) for the two inputs.
    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            [
                "MKVLWAALLVTFLAGCQAKVEQAVETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQELRALMDETMKELKAYKSELEEQLTPVAEETRARLSKELQAAQARLGADMEDVCGRLVQYRGEVQAMLGQSTEELRVRLASHLRKLRKRLLRDADDLQKRLAVYQAGAREGAERGLSAIRERLGPLVEQGRVRAATVGSLAGQPLQERAQAWGERLRARMEEMGSRTRDRLDEVKEQVAEVRAKLEEQAQQRL",
                "1ABC",
            ],
        ],
        inputs=[sequence_input, pdb_input],
        outputs=[predictions_output, molecule_output],
    )

# Start the web server for the app.
demo.launch()