import gradio as gr
from model_loader import load_model

import torch
from torch.utils.data import DataLoader

import numpy as np
import os

from transformers import DataCollatorForTokenClassification
from datasets import Dataset
from scipy.special import expit

import requests
from gradio_molecule3d import Molecule3D

# Biopython imports
from Bio.PDB import PDBParser, PDBIO, PDBList

from matplotlib import cm  # For color mapping
from matplotlib.colors import Normalize

# Configuration
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500

# Load model and move to device
model, tokenizer = load_model(checkpoint, max_length)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

def is_valid_sequence_length(length: int) -> bool:
    """Check if sequence length is within the valid range."""
    return 100 <= length <= 1500

def is_nucleic_acid_chain(chain) -> bool:
    """Check if a chain contains nucleic-acid (or unknown) residues."""
    nucleic_acids = {'A', 'C', 'G', 'T', 'U', 'DA', 'DC', 'DG', 'DT', 'DU', 'UNK'}
    return any(residue.get_resname().strip() in nucleic_acids for residue in chain)
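# Example: a DNA chain whose residues are named DA/DT/DG/DC is flagged here and
# skipped downstream, while a chain of ALA/GLY/... residues is kept.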

def extract_protein_sequence(pdb_path):
    """Extract the longest protein sequence from a PDB file."""
    parser = PDBParser(QUIET=1)
    structure = parser.get_structure('protein', pdb_path)
    # Comprehensive amino acid mapping
    aa_dict = {
        # Standard amino acids (20 canonical)
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
        # Modified amino acids and alternative names
        'MSE': 'M',  # Selenomethionine
        'SEP': 'S',  # Phosphoserine
        'TPO': 'T',  # Phosphothreonine
        'CSO': 'C',  # S-hydroxycysteine
        'PTR': 'Y',  # Phosphotyrosine
        'HYP': 'P',  # Hydroxyproline
    }
    # Ligand and nucleic acid exclusion set
    ligand_exclusion_set = {'HOH', 'WAT', 'DOD', 'SO4', 'PO4', 'GOL', 'ACT', 'EDO'}

    # Find the longest protein chain
    longest_sequence = ""
    longest_chain = None

    # Use a distinct name so the loop doesn't shadow the global neural model
    for structure_model in structure:
        for chain in structure_model:
            # Skip nucleic acid chains
            if is_nucleic_acid_chain(chain):
                continue

            # Extract and convert the sequence
            sequence = ""
            for residue in chain:
                # Keep residues that are standard or known modified amino acids
                res_name = residue.get_resname().strip()
                if res_name in aa_dict:
                    sequence += aa_dict[res_name]

            # Keep the longest chain within the valid length range
            if (10 < len(sequence) < 1500 and
                    len(sequence) > len(longest_sequence)):
                longest_sequence = sequence
                longest_chain = chain
    if not longest_sequence:
        return None, None, pdb_path

    # Save a filtered copy of the structure next to the original file
    if longest_chain:
        io = PDBIO()
        io.set_structure(longest_chain.get_parent().get_parent())
        # splitext also handles the .ent extension produced by PDBList
        root, ext = os.path.splitext(pdb_path)
        filtered_pdb_path = f"{root}_filtered{ext}"
        io.save(filtered_pdb_path)
        return longest_sequence, longest_chain, filtered_pdb_path

    return longest_sequence, longest_chain, pdb_path

def create_dataset(tokenizer, seqs, labels, checkpoint):
    tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
    dataset = Dataset.from_dict(tokenized)

    # Adjust label length to the tokenizer's special-token budget
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        labels = [l[:max_length - 2] for l in labels]
    else:
        labels = [l[:max_length - 1] for l in labels]

    dataset = dataset.add_column("labels", labels)
    return dataset
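# ESM-style tokenizers add two special tokens (cls and eos), so labels above
# are truncated to max_length-2; T5-style tokenizers append only a single
# trailing </s>, hence max_length-1.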

def convert_predictions(input_logits):
    all_probs = []
    for logits in input_logits:
        logits = logits.reshape(-1, 2)
        # Sigmoid of the logit difference gives the class-1 probability
        probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
        all_probs.append(probabilities_class1)
    return np.concatenate(all_probs)
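# Note: expit(l1 - l0) equals softmax([l0, l1])[1], so each value is the
# model's probability that a residue belongs to the positive (binding) class.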

def normalize_scores(scores):
    min_score = np.min(scores)
    max_score = np.max(scores)
    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
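# Example: normalize_scores(np.array([0.2, 0.5, 0.8])) -> array([0., 0.5, 1.])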

def predict_protein_sequence(test_one_letter_sequence):
    # Map rare/ambiguous residues (O, B, U, Z, J) to the unknown token X
    test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
        .replace("B", "X").replace("U", "X") \
        .replace("Z", "X").replace("J", "X")

    # Dummy labels are placeholders for the collator: one zero per residue,
    # created before any spacing is added to the sequence
    dummy_labels = [np.zeros(len(test_one_letter_sequence))]

    # prot_t5/ProstT5 expect space-separated residues; ProstT5 additionally
    # needs the <AA2fold> prefix
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        test_one_letter_sequence = " ".join(test_one_letter_sequence)
        if "ProstT5" in checkpoint:
            test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence

    # Create dataset
    test_dataset = create_dataset(tokenizer,
                                  [test_one_letter_sequence],
                                  dummy_labels,
                                  checkpoint)

    # Pad inputs and labels together at batch time
    data_collator = DataCollatorForTokenClassification(tokenizer)

    # Create data loader
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
    # Predict
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits.detach().cpu().numpy()

    # Drop the final special-token position (</s> for prot_t5)
    logits = logits[:, :-1]
    logits = convert_predictions(logits)

    # Normalize scores and strip the spaces added for T5-style tokenization
    normalized_scores = normalize_scores(logits)
    test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")

    return test_one_letter_sequence, normalized_scores
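# Illustrative usage (hypothetical sequence):
#   seq, scores = predict_protein_sequence("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
# returns the sanitized sequence alongside one normalized score per residue.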

def fetch_pdb(pdb_id):
    try:
        # Create a directory to store PDB files if it doesn't exist
        os.makedirs('pdb_files', exist_ok=True)

        # Fetch the PDB structure from RCSB
        pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
        pdb_path = f'pdb_files/{pdb_id}.pdb'

        # Download the file
        response = requests.get(pdb_url)
        if response.status_code == 200:
            with open(pdb_path, 'wb') as f:
                f.write(response.content)
            return pdb_path
        else:
            return None
    except Exception as e:
        print(f"Error fetching PDB: {e}")
        return None
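# Example: fetch_pdb("2IWI") downloads https://files.rcsb.org/download/2IWI.pdb
# and returns 'pdb_files/2IWI.pdb', or None on failure.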

def score_to_color(score):
    norm = Normalize(vmin=0, vmax=1)  # Normalize scores between 0 and 1
    color_map = cm.coolwarm  # Any matplotlib colormap works, e.g. 'cividis'
    rgba = color_map(norm(score))  # Get RGBA values
    hex_color = '#{:02x}{:02x}{:02x}'.format(int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255))
    return hex_color
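# Example: score_to_color(0.0) maps to coolwarm's blue end (about '#3a4cc0')
# and score_to_color(1.0) to its red end (about '#b40426').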

def process_pdb(pdb_id):
    # Fetch the PDB file (PDBList saves it as pdb_files/pdb<id>.ent)
    pdbl = PDBList()
    pdb_path = pdbl.retrieve_pdb_file(pdb_id, pdir='pdb_files', file_format='pdb')

    if not pdb_path or not os.path.exists(pdb_path):
        return "Failed to fetch PDB file", None

    # Extract protein sequence and chain
    protein_sequence, chain, filtered_pdb_path = extract_protein_sequence(pdb_path)

    if not protein_sequence:
        return "No suitable protein sequence found", None

    # Predict binding sites
    sequence, normalized_scores = predict_protein_sequence(protein_sequence)

    # Prepare result string
    result_str = "\n".join(f"{aa}: {score:.2f}" for aa, score in zip(sequence, normalized_scores))
    # Color-coded residue representations could be appended to the viewer's
    # reps based on the binding-site scores, e.g.:
    # for i, score in enumerate(normalized_scores):
    #     if score > 0.7:  # adjustable threshold
    #         reps.append({
    #             "model": 0,
    #             "chain": chain.get_id(),
    #             "style": "stick",
    #             "color": score_to_color(score),
    #             "residue_range": f"{i+1}-{i+1}",
    #             "byres": True,
    #             "visible": True
    #         })

    # Return the per-residue scores and the filtered PDB path;
    # gradio_molecule3d renders the structure from the file path
    return result_str, filtered_pdb_path

# Default cartoon representation for the Molecule3D viewer
reps = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": "spectrum",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": True
    }
]

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    with gr.Row():
        with gr.Column():
            pdb_input = gr.Textbox(
                value="2IWI",
                label="PDB ID",
                placeholder="Enter PDB ID here..."
            )
            predict_btn = gr.Button("Predict Binding Sites")
        with gr.Column():
            predictions_output = gr.Textbox(
                label="Binding Site Predictions"
            )
            molecule_output = Molecule3D(label="Protein Structure", reps=reps)
    # Prediction logic
    predict_btn.click(
        process_pdb,
        inputs=[pdb_input],
        outputs=[predictions_output, molecule_output]
    )

    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["2IWI"],
            ["7RPZ"],
            ["3TJN"]
        ],
        inputs=[pdb_input],
        outputs=[predictions_output, molecule_output]
    )

demo.launch()