# Spaces: Running
import gradio as gr | |
from model_loader import load_model | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
import re | |
import numpy as np | |
import os | |
import pandas as pd | |
import copy | |
import transformers, datasets | |
from transformers import AutoTokenizer | |
from transformers import DataCollatorForTokenClassification | |
from datasets import Dataset | |
from scipy.special import expit | |
import requests | |
# Biopython imports | |
from Bio.PDB import PDBParser, Select | |
from Bio.PDB.DSSP import DSSP | |
from gradio_molecule3d import Molecule3D | |
# Settings: the checkpoint identifier passed to load_model and the length cap
# used for tokenization.
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500

# Choose CUDA when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model and tokenizer once at import time, place the model on the
# chosen device and switch it to eval mode.
model, tokenizer = load_model(checkpoint, max_length)
model.to(device)
model.eval()
def create_dataset(tokenizer, seqs, labels, checkpoint):
    """Tokenize ``seqs`` and bundle them with ``labels`` into a ``Dataset``.

    Sequences are tokenized without padding and truncated to the module-level
    ``max_length``. Each label sequence is cut to ``max_length - 2`` entries
    when ``checkpoint`` contains "esm" or "ProstT5", and to ``max_length - 1``
    entries otherwise, before being attached as the ``labels`` column.

    Args:
        tokenizer: Callable tokenizer applied to ``seqs``.
        seqs: Sequences to tokenize.
        labels: One sliceable label sequence per entry of ``seqs``.
        checkpoint: Checkpoint name; only used to pick the label length limit.

    Returns:
        The ``Dataset`` built from the tokenizer output plus a ``labels``
        column.
    """
    encoded = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)

    # The label length limit depends on the checkpoint family.
    is_esm_or_prost = ("esm" in checkpoint) or ("ProstT5" in checkpoint)
    label_limit = max_length - 2 if is_esm_or_prost else max_length - 1
    trimmed_labels = [label[:label_limit] for label in labels]

    return Dataset.from_dict(encoded).add_column("labels", trimmed_labels)
def convert_predictions(input_logits):
    """Convert two-class logits into class-1 probabilities.

    Args:
        input_logits: Iterable of array-likes. Each item is reshaped to
            ``(-1, 2)``, with column 0 holding the class-0 logit and column 1
            the class-1 logit. Nested lists are accepted as well as arrays.

    Returns:
        A 1-D numpy array with one class-1 probability per logit pair, in
        input order. An empty input yields an empty float array instead of
        raising.
    """
    all_probs = []
    for logits in input_logits:
        pairs = np.asarray(logits).reshape(-1, 2)
        # sigmoid(l1 - l0) is the two-class softmax probability of class 1.
        all_probs.append(expit(pairs[:, 1] - pairs[:, 0]))

    # np.concatenate refuses an empty list, so handle "no input" explicitly.
    if not all_probs:
        return np.empty(0, dtype=float)
    return np.concatenate(all_probs)
def normalize_scores(scores):
    """Min-max scale ``scores`` into the range [0, 1].

    Args:
        scores: Array-like of numeric scores (numpy array or plain sequence).

    Returns:
        A numpy array where the smallest input maps to 0 and the largest to 1.
        When all values are equal (or the input is empty) there is no range
        to scale by, so the values are returned unchanged.
    """
    scores = np.asarray(scores)

    # np.min / np.max raise on empty input; there is nothing to scale anyway.
    if scores.size == 0:
        return scores

    min_score = np.min(scores)
    max_score = np.max(scores)
    if max_score > min_score:
        return (scores - min_score) / (max_score - min_score)
    # Constant input: avoid dividing by zero and hand the values back as-is.
    return scores
def predict_protein_sequence(test_one_letter_sequence):
    """Score every residue of one protein sequence with the module-level model.

    Args:
        test_one_letter_sequence: Amino-acid sequence in one-letter code.
            Whitespace is removed and the codes O, B, U, Z and J are
            replaced by X before prediction.

    Returns:
        A tuple ``(sequence, normalized_scores)``: the cleaned sequence
        (without spaces or any model-specific prefix) and a 1-D numpy array
        of min-max normalized class-1 probabilities. If the tokenizer
        truncated the input, the score array is shorter than the sequence.
    """
    # Drop all whitespace, then map O, B, U, Z and J to X in a single pass.
    clean_sequence = "".join(test_one_letter_sequence.split()).translate(
        str.maketrans("OBUZJ", "XXXXX")
    )
    if not clean_sequence:
        # Nothing to score; normalizing an empty array would raise downstream.
        return clean_sequence, np.empty(0, dtype=float)

    # Build the model-specific input separately so the cleaned sequence can be
    # returned untouched; the "<AA2fold>" prefix must not leak into it.
    model_input = clean_sequence
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        model_input = " ".join(model_input)
    if "ProstT5" in checkpoint:
        model_input = "<AA2fold> " + model_input

    # Placeholder labels: create_dataset attaches a "labels" column, but the
    # forward pass below only receives input_ids and attention_mask.
    dummy_labels = [np.zeros(len(model_input))]
    test_dataset = create_dataset(tokenizer, [model_input], dummy_labels, checkpoint)

    # The same collator is used for every checkpoint family.
    data_collator = DataCollatorForTokenClassification(tokenizer)
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)

    # The dataset holds exactly one sequence, so there is exactly one batch.
    batch = next(iter(test_loader))
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits.detach().cpu().numpy()

    # Drop the last position (trailing special token for prot_t5).
    # NOTE(review): other checkpoint families may need different trimming
    # (e.g. a leading token as well) - confirm before switching checkpoints.
    logits = logits[:, :-1]

    probabilities = convert_predictions(logits)
    normalized_scores = normalize_scores(probabilities)
    return clean_sequence, normalized_scores
def fetch_pdb(pdb_id):
    """Download a PDB entry from RCSB into ``pdb_files/``.

    Args:
        pdb_id: Four-character alphanumeric PDB identifier (surrounding
            whitespace is ignored).

    Returns:
        The local path of the downloaded ``.pdb`` file, or ``None`` when the
        ID is malformed, the server does not answer with HTTP 200, or the
        download / file write fails.
    """
    # The ID comes straight from the UI and is interpolated into both a URL
    # and a file path, so reject anything that is not a plain 4-character ID
    # (this also blocks path traversal such as "../../x").
    if not isinstance(pdb_id, str):
        return None
    pdb_id = pdb_id.strip()
    if not re.fullmatch(r"[A-Za-z0-9]{4}", pdb_id):
        print(f"Error fetching PDB: invalid PDB ID {pdb_id!r}")
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'

    try:
        # Create the download directory if it doesn't exist yet.
        os.makedirs('pdb_files', exist_ok=True)
        # Without a timeout a stalled connection would block the app forever.
        response = requests.get(pdb_url, timeout=30)
        if response.status_code != 200:
            return None
        with open(pdb_path, 'wb') as f:
            f.write(response.content)
    except (requests.RequestException, OSError) as e:
        print(f"Error fetching PDB: {e}")
        return None
    return pdb_path
def extract_protein_sequence(pdb_path):
    """Extract the longest protein sequence from a PDB file.

    Every chain of every model is converted to a one-letter sequence. Hetero
    residues (waters, ligands) are skipped, and residue names outside the 20
    standard amino acids become 'X'. Chains without a single standard amino
    acid are ignored, as are sequences that are not strictly between 10 and
    1500 residues long.

    Args:
        pdb_path: Path to the PDB file to parse.

    Returns:
        A tuple ``(sequence, chain)`` with the longest accepted one-letter
        sequence and the Biopython chain it came from, or ``("", None)`` when
        no chain qualifies.
    """
    # Three-letter to one-letter codes for the 20 standard amino acids.
    three_to_one = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }

    structure = PDBParser(QUIET=True).get_structure('protein', pdb_path)

    longest_sequence = ""
    longest_chain = None
    for structure_model in structure:
        for chain in structure_model:
            # Translate residue by residue. Joining the three-letter names
            # into one string first would make the lookup run per character.
            # residue.id[0] is the hetero flag; it is blank for ATOM records.
            one_letter_sequence = "".join(
                three_to_one.get(residue.get_resname().strip(), 'X')
                for residue in chain
                if residue.id[0] == " "
            )

            # Skip chains with no recognizable amino acid (e.g. nucleic acids).
            if not any(aa != 'X' for aa in one_letter_sequence):
                continue

            # Keep the longest chain whose length is within the accepted window.
            if (10 < len(one_letter_sequence) < 1500
                    and len(one_letter_sequence) > len(longest_sequence)):
                longest_sequence = one_letter_sequence
                longest_chain = chain
    return longest_sequence, longest_chain
def process_pdb(pdb_id):
    """Fetch a PDB entry, predict per-residue scores and build viewer styling.

    Args:
        pdb_id: PDB identifier entered by the user.

    Returns:
        A tuple ``(result_text, reps, pdb_path)``. ``result_text`` has one
        "<residue>: <score>" line per residue, ``reps`` is a list with one
        representation dict per residue (coloured from blue for low scores to
        red for high scores) and ``pdb_path`` is the downloaded file. On
        failure the tuple is ``(error_message, None, None)``.
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    protein_sequence, chain = extract_protein_sequence(pdb_path)
    if not protein_sequence:
        return "No suitable protein sequence found", None, None

    sequence, normalized_scores = predict_protein_sequence(protein_sequence)

    # Use the residue numbers stored in the structure rather than 1..N: chains
    # often start at another number or have gaps, and a running index would
    # then point at the wrong residues. residue.id is (hetero flag, number,
    # insertion code). Fall back to a running index when the chain's
    # non-hetero residues do not line up one-to-one with the sequence.
    polymer_residues = [residue for residue in chain if residue.id[0] == " "]
    if len(polymer_residues) == len(sequence):
        residue_numbers = [residue.id[1] for residue in polymer_residues]
    else:
        residue_numbers = list(range(1, len(sequence) + 1))

    # Map each score to a 0-255 red intensity. NaNs become 0 and values are
    # clamped to [0, 1] so int conversion cannot fail or leave the RGB range.
    intensities = (
        np.clip(np.nan_to_num(normalized_scores), 0.0, 1.0) * 255
    ).astype(int).tolist()

    # NOTE(review): "resname" carries the one-letter code, as before. If the
    # viewer treats it as a residue-name filter it would need the three-letter
    # name (or an empty string) - confirm against the Molecule3D documentation.
    reps = [
        {
            "model": 0,
            "chain": chain.id,
            "resname": res,
            "resnum": number,
            "style": "cartoon",
            "color": f'rgb({intensity}, 0, {255 - intensity})',
            "residue_range": f"{number}-{number}",
            "around": 0,
            "byres": True,
            "visible": True,
        }
        for res, number, intensity in zip(sequence, residue_numbers, intensities)
    ]

    result_str = "\n".join(
        f"{aa}: {score:.2f}" for aa, score in zip(sequence, normalized_scores)
    )
    return result_str, reps, pdb_path
# Gradio front end: a PDB ID box and a button on the left, the textual
# predictions and the 3D structure viewer on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    with gr.Row():
        with gr.Column():
            # Input side: PDB ID (pre-filled with a suggestion) and the trigger.
            pdb_input = gr.Textbox(
                value="2IWI",
                label="PDB ID",
                placeholder="Enter PDB ID here...",
            )
            predict_btn = gr.Button("Predict Binding Sites")

        with gr.Column():
            # Output side: per-residue scores as text plus the 3D view, which
            # starts without any representations.
            predictions_output = gr.Textbox(label="Binding Site Predictions")
            molecule_output = Molecule3D(label="Protein Structure", reps=[])

    # process_pdb returns three values, so three output slots are listed.
    # NOTE(review): molecule_output appears twice and receives both the reps
    # list and the file path - confirm this is what Molecule3D expects.
    predict_btn.click(
        process_pdb,
        inputs=[pdb_input],
        outputs=[predictions_output, molecule_output, molecule_output],
    )

    # Clickable example PDB IDs.
    gr.Markdown("## Examples")
    gr.Examples(
        examples=[["2IWI"], ["1ABC"], ["4HHB"]],
        inputs=[pdb_input],
        outputs=[predictions_output, molecule_output, molecule_output],
    )

demo.launch()