import gradio as gr
from model_loader import load_model
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
from transformers import DataCollatorForTokenClassification
from datasets import Dataset
from scipy.special import expit
import requests
from gradio_molecule3d import Molecule3D
# Configuration
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500

# Default representations for molecule rendering
reps = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": "spectrum",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": True
    }
]
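# Each entry above appears to follow the gradio_molecule3d representation
# schema (a 3Dmol.js-style selection plus "style" and "color"); here a single
# spectrum-colored cartoon is applied to the whole model.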
# Load model and move to device
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
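# model_loader.load_model is assumed to return the fine-tuned per-residue
# classifier and its matching tokenizer; eval() disables dropout for inference.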
def create_dataset(tokenizer, seqs, labels, checkpoint):
    tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
    dataset = Dataset.from_dict(tokenized)

    # Truncate labels to leave room for the special tokens the tokenizer adds:
    # ESM and ProstT5 add two (e.g. CLS and EOS), plain T5 adds one (EOS)
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        labels = [l[:max_length-2] for l in labels]
    else:
        labels = [l[:max_length-1] for l in labels]

    dataset = dataset.add_column("labels", labels)
    return dataset
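# For a two-class head, softmax(logits)[1] equals sigmoid(logit_1 - logit_0),
# so convert_predictions below computes the per-residue probability of the
# positive (binding) class directly from the logit difference.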
def convert_predictions(input_logits):
    all_probs = []
    for logits in input_logits:
        logits = logits.reshape(-1, 2)
        probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
        all_probs.append(probabilities_class1)
    return np.concatenate(all_probs)
def normalize_scores(scores):
    # Min-max scale to [0, 1]; return the scores unchanged if they are all equal
    min_score = np.min(scores)
    max_score = np.max(scores)
    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
def predict_protein_sequence(test_one_letter_sequence):
    # Map rare or ambiguous residue codes (O, B, U, Z, J) to the unknown token X
    test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
        .replace("B", "X").replace("U", "X") \
        .replace("Z", "X").replace("J", "X")

    # T5-family models expect space-separated residues; ProstT5 also needs a task prefix
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        test_one_letter_sequence = " ".join(test_one_letter_sequence)
    if "ProstT5" in checkpoint:
        test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence

    # Dummy labels: only needed to satisfy create_dataset; the values are never used
    dummy_labels = [np.zeros(len(test_one_letter_sequence))]

    test_dataset = create_dataset(tokenizer,
                                  [test_one_letter_sequence],
                                  dummy_labels,
                                  checkpoint)
    # DataCollatorForTokenClassification pads input_ids, attention_mask,
    # and labels to the longest sequence in the batch
    data_collator = DataCollatorForTokenClassification(tokenizer)

    # Create data loader
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
    # Predict
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits.detach().cpu().numpy()

    # Drop the logits for the trailing special token (the </s> the T5 tokenizer appends)
    logits = logits[:, :-1]
    logits = convert_predictions(logits)
    # Normalize and format results
    normalized_scores = normalize_scores(logits)
    test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")

    result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
    return result_str
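# Example (hypothetical session, assuming the checkpoint above is loaded):
#   print(predict_protein_sequence("MKVLWAALLVTFLAGCQAK"))
# prints one "residue: score" pair per line, scores min-max scaled to [0, 1].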
def fetch_pdb(pdb_id):
    try:
        # Create a directory to store PDB files if it doesn't exist
        os.makedirs('pdb_files', exist_ok=True)

        # Fetch the PDB structure from RCSB
        pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
        pdb_path = f'pdb_files/{pdb_id}.pdb'

        # Download the file; time out rather than hang on a dead connection
        response = requests.get(pdb_url, timeout=10)
        if response.status_code == 200:
            with open(pdb_path, 'wb') as f:
                f.write(response.content)
            return pdb_path
        else:
            return None
    except Exception as e:
        print(f"Error fetching PDB: {e}")
        return None
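# For example, fetch_pdb("1CRN") downloads crambin to "pdb_files/1CRN.pdb"
# and returns that path, or None if the download fails.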
def process_input(sequence, pdb_id):
    # Predict binding sites
    binding_site_predictions = predict_protein_sequence(sequence)

    # Fetch PDB file
    pdb_path = fetch_pdb(pdb_id)

    return binding_site_predictions, pdb_path
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    with gr.Row():
        with gr.Column():
            # Sequence input
            sequence_input = gr.Textbox(
                lines=2,
                placeholder="Enter protein sequence here...",
                label="Protein Sequence"
            )

            # PDB ID input
            pdb_input = gr.Textbox(
                lines=1,
                placeholder="Enter PDB ID here...",
                label="PDB ID for 3D Visualization"
            )

            # Predict button
            predict_btn = gr.Button("Predict Binding Sites")

        with gr.Column():
            # Binding site predictions output
            predictions_output = gr.Textbox(
                label="Binding Site Predictions"
            )

            # 3D Molecule visualization
            molecule_output = Molecule3D(
                label="Protein Structure",
                reps=reps
            )

    # Prediction logic
    predict_btn.click(
        process_input,
        inputs=[sequence_input, pdb_input],
        outputs=[predictions_output, molecule_output]
    )

    # Add some example inputs
    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["MKVLWAALLVTFLAGCQAKVEQAVETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQELRALMDETMKELKAYKSELEEQLTPVAEETRARLSKELQAAQARLGADMEDVCGRLVQYRGEVQAMLGQSTEELRVRLASHLRKLRKRLLRDADDLQKRLAVYQAGAREGAERGLSAIRERLGPLVEQGRVRAATVGSLAGQPLQERAQAWGERLRARMEEMGSRTRDRLDEVKEQVAEVRAKLEEQAQQRL", "1ABC"],
        ],
        inputs=[sequence_input, pdb_input],
        outputs=[predictions_output, molecule_output]
    )

demo.launch()