import gradio as gr
from model_loader import load_model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import re
import numpy as np
import os
import pandas as pd
import copy

import transformers
import datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from datasets import Dataset

from scipy.special import expit

import requests

from gradio_molecule3d import Molecule3D

# Biopython imports
from Bio.PDB import PDBParser, Select, PDBIO
from Bio.PDB.DSSP import DSSP
from Bio.PDB import PDBList

from matplotlib import cm  # For color mapping
from matplotlib.colors import Normalize

# Load model and move to device
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Local directory where downloaded PDB files are stored.
PDB_DIR = 'pdb_files'
# Seconds to wait for RCSB before treating a download as failed.
REQUEST_TIMEOUT = 30
# Accepted PDB IDs: classic 4-character IDs (a digit plus 3 alphanumerics) and
# the extended "pdb_" + 8-character form. Restricting the alphabet also keeps
# user input from injecting path separators into the local file names below.
_PDB_ID_RE = re.compile(
    r'(?:[0-9][A-Za-z0-9]{3}|pdb_[0-9]{5}[A-Za-z0-9]{3})', re.IGNORECASE
)

# Residue name -> one-letter code. Besides the 20 standard residue names, a few
# non-standard residue names are mapped onto a standard one-letter code
# (e.g. 'MSE' -> 'M').
AA_DICT = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G',
    'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
    'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V',
    'TRP': 'W', 'TYR': 'Y',
    'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P',
}


def _clean_pdb_id(pdb_id):
    """Return *pdb_id* stripped of surrounding whitespace, or None if it is
    not a string that fully matches ``_PDB_ID_RE``."""
    if not isinstance(pdb_id, str):
        return None
    pdb_id = pdb_id.strip()
    return pdb_id if _PDB_ID_RE.fullmatch(pdb_id) else None


# Function to fetch a PDB file
def fetch_pdb(pdb_id):
    """Download ``<pdb_id>.pdb`` from RCSB into ``PDB_DIR``.

    Args:
        pdb_id: PDB identifier as typed by the user (surrounding whitespace
            is ignored).

    Returns:
        The local file path on success. Returns None when the ID is
        malformed, when the HTTP request raises (network error, timeout), or
        when RCSB does not answer with status 200.
    """
    pdb_id = _clean_pdb_id(pdb_id)
    if pdb_id is None:
        return None
    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'{PDB_DIR}/{pdb_id}.pdb'
    os.makedirs(PDB_DIR, exist_ok=True)
    try:
        response = requests.get(pdb_url, timeout=REQUEST_TIMEOUT)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path


def normalize_scores(scores):
    """Min-max scale *scores* to the range [0, 1].

    Empty or constant input is returned unchanged, because there is no range
    to scale by.
    """
    scores = np.asarray(scores)
    if scores.size == 0:
        return scores
    min_score = np.min(scores)
    max_score = np.max(scores)
    if max_score > min_score:
        return (scores - min_score) / (max_score - min_score)
    return scores


def process_pdb(pdb_id, segment):
    """Score every mapped residue of one chain of a PDB entry with the model.

    Args:
        pdb_id: PDB identifier entered by the user.
        segment: Chain ID, looked up in the first model of the structure.

    Returns:
        A tuple ``(result_text, pdb_path, predictions_path)``.

        ``result_text`` has one line per scored residue: residue name,
        residue number, one-letter code and min-max normalised score.

        On failure, ``result_text`` is an error message and
        ``predictions_path`` is None. ``pdb_path`` is also None when the
        download itself failed.
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None
    # fetch_pdb only succeeds for IDs accepted by _clean_pdb_id, so the
    # cleaned ID is safe to embed in the output file name.
    pdb_id = _clean_pdb_id(pdb_id)

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_path)
    try:
        chain = structure[0][segment]
    except KeyError:
        return f"Chain '{segment}' not found in {pdb_id}", pdb_path, None

    # Keep the residue objects themselves, so the sequence, the model scores
    # and the report rows stay index-aligned. Without this, residues missing
    # from AA_DICT (e.g. waters, ligands) would shift the indices.
    residues = [res for res in chain if res.get_resname().strip() in AA_DICT]
    if not residues:
        return f"No amino acid residues found in chain '{segment}'", pdb_path, None
    sequence = "".join(AA_DICT[res.get_resname().strip()] for res in residues)

    # Prepare input for model prediction (residues separated by spaces)
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        # Batch of one: take its only entry instead of squeeze(), which would
        # also collapse a length-1 token axis.
        logits = model(input_ids).logits[0].detach().cpu().numpy()

    # Drop trailing special tokens (e.g. an appended end-of-sequence token) so
    # they do not take part in the min-max normalisation.
    # NOTE(review): assumes one token per residue and no leading special
    # tokens. Confirm for the tokenizer returned by load_model.
    logits = logits[:len(sequence)]
    # expit(l1 - l0) equals the softmax probability of label 1 when the head
    # has exactly two labels.
    scores = expit(logits[:, 1] - logits[:, 0])
    normalized_scores = normalize_scores(scores)

    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {aa} {score:.2f}"
        for res, aa, score in zip(residues, sequence, normalized_scores)
    )

    # Save predictions to file
    predictions_path = f"{pdb_id}_predictions.txt"
    with open(predictions_path, "w") as f:
        f.write(result_str)
    return result_str, pdb_path, predictions_path


reps = [{"model": 0, "style": "cartoon", "color": "spectrum"}]

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    with gr.Row():
        pdb_input = gr.Textbox(value="2IWI", label="PDB ID", placeholder="Enter PDB ID here...")
        segment_input = gr.Textbox(value="A", label="Chain ID (Segment)", placeholder="Enter Chain ID here...")
        visualize_btn = gr.Button("Visualize Structure")
        prediction_btn = gr.Button("Predict Ligand Binding Site")

    molecule_output = Molecule3D(label="Protein Structure", reps=reps)
    predictions_output = gr.Textbox(label="Binding Site Predictions")
    download_output = gr.File(label="Download Predictions")

    visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)
    prediction_btn.click(
        process_pdb,
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

    gr.Markdown("## Examples")
    # Each example row supplies only a PDB ID (no chain value).
    gr.Examples(
        examples=[
            ["2IWI"],
            ["7RPZ"],
            ["3TJN"]
        ],
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

demo.launch(share=True)